Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ coverage.xml
.cache
.pytest_cache/
.python-version
pip
pip
.mypy_cache/
99 changes: 31 additions & 68 deletions slack/rtm/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
import os
import logging
import random
import json
import collections
import functools
import inspect
import signal
import concurrent.futures
from typing import Optional, Callable
from typing import Optional, Callable, DefaultDict
from ssl import SSLContext

# ThirdParty Imports
Expand Down Expand Up @@ -69,13 +66,12 @@ class RTMClient(object):
Example:
```python
import os
import slack
from slack import RTMClient

@slack.RTMClient.run_on(event='message')
@RTMClient.run_on(event="message")
def say_hello(**payload):
data = payload['data']
web_client = payload['web_client']
rtm_client = payload['rtm_client']
if 'Hello' in data['text']:
channel_id = data['channel']
thread_ts = data['ts']
Expand All @@ -88,7 +84,7 @@ def say_hello(**payload):
)

slack_token = os.environ["SLACK_API_TOKEN"]
rtm_client = slack.RTMClient(token=slack_token)
rtm_client = RTMClient(token=slack_token)
rtm_client.start()
```

Expand All @@ -102,7 +98,7 @@ def say_hello(**payload):
removed at anytime.
"""

_callbacks = collections.defaultdict(list)
_callbacks: DefaultDict = collections.defaultdict(list)

def __init__(
self,
Expand Down Expand Up @@ -141,11 +137,7 @@ def run_on(*, event: str):
"""A decorator to store and link a callback to an event."""

def decorator(callback):
@functools.wraps(callback)
def decorator_wrapper():
RTMClient.on(event=event, callback=callback)

return decorator_wrapper()
RTMClient.on(event=event, callback=callback)

return decorator

Expand Down Expand Up @@ -196,7 +188,7 @@ def start(self) -> asyncio.Future:

future = asyncio.ensure_future(self._connect_and_read(), loop=self._event_loop)

if self.run_async or self._event_loop.is_running():
if self.run_async:
return future

return self._event_loop.run_until_complete(future)
Expand Down Expand Up @@ -231,17 +223,19 @@ def send_over_websocket(self, *, payload: dict):
Raises:
SlackClientNotConnectedError: Websocket connection is closed.
"""
return asyncio.ensure_future(self._send_json(payload))

async def _send_json(self, payload):
if self._websocket is None or self._event_loop is None:
raise client_err.SlackClientNotConnectedError(
"Websocket connection is closed."
)
if "id" not in payload:
payload["id"] = self._next_msg_id()
asyncio.ensure_future(
self._websocket.send_str(json.dumps(payload)), loop=self._event_loop
)

def ping(self):
return await self._websocket.send_json(payload)

async def ping(self):
"""Sends a ping message over the websocket to Slack.

Not all web browsers support the WebSocket ping spec,
Expand All @@ -251,9 +245,9 @@ def ping(self):
SlackClientNotConnectedError: Websocket connection is closed.
"""
payload = {"id": self._next_msg_id(), "type": "ping"}
self.send_over_websocket(payload=payload)
await self._send_json(payload=payload)

def typing(self, *, channel: str):
async def typing(self, *, channel: str):
"""Sends a typing indicator to the specified channel.

This indicates that this app is currently
Expand All @@ -266,7 +260,7 @@ def typing(self, *, channel: str):
SlackClientNotConnectedError: Websocket connection is closed.
"""
payload = {"id": self._next_msg_id(), "type": "typing", "channel": channel}
self.send_over_websocket(payload=payload)
await self._send_json(payload=payload)

@staticmethod
def _validate_callback(callback):
Expand Down Expand Up @@ -307,9 +301,9 @@ def _next_msg_id(self):
return self._last_message_id

async def _connect_and_read(self):
"""Retreives and connects to Slack's RTM API.
"""Retreives the WS url and connects to Slack's RTM API.

Makes an authenticated call to Slack's RTM API to retrieve
Makes an authenticated call to Slack's Web API to retrieve
a websocket URL. Then connects to the message server and
reads event messages as they come in.

Expand Down Expand Up @@ -338,15 +332,15 @@ async def _connect_and_read(self):
) as websocket:
self._logger.debug("The Websocket connection has been opened.")
self._websocket = websocket
self._dispatch_event(event="open", data=data)
await self._dispatch_event(event="open", data=data)
await self._read_messages()
except (
client_err.SlackClientNotConnectedError,
client_err.SlackApiError,
# TODO: Catch websocket exceptions thrown by aiohttp.
) as exception:
self._logger.debug(str(exception))
self._dispatch_event(event="error", data=exception)
await self._dispatch_event(event="error", data=exception)
if self.auto_reconnect and not self._stopped:
await self._wait_exponentially(exception)
continue
Expand All @@ -366,11 +360,11 @@ async def _read_messages(self):
if message.type == aiohttp.WSMsgType.TEXT:
payload = message.json()
event = payload.pop("type", "Unknown")
self._dispatch_event(event, data=payload)
await self._dispatch_event(event, data=payload)
elif message.type == aiohttp.WSMsgType.ERROR:
break

def _dispatch_event(self, event, data=None):
async def _dispatch_event(self, event, data=None):
"""Dispatches the event and executes any associated callbacks.

Note: To prevent the app from crashing due to callback errors. We
Expand Down Expand Up @@ -399,52 +393,19 @@ def _dispatch_event(self, event, data=None):
# Don't run callbacks if client was stopped unless they're close/error callbacks.
break

if self.run_async:
self._execute_callback_async(callback, data)
if inspect.iscoroutinefunction(callback):
await callback(
rtm_client=self, web_client=self._web_client, data=data
)
else:
self._execute_callback(callback, data)
callback(rtm_client=self, web_client=self._web_client, data=data)
except Exception as err:
name = callback.__name__
module = callback.__module__
msg = f"When calling '#{name}()' in the '{module}' module the following error was raised: {err}"
self._logger.error(msg)
raise

def _execute_callback_async(self, callback, data):
"""Execute the callback asynchronously.

If the callback is not a coroutine, convert it.

Note: The WebClient passed into the callback is running in "async" mode.
This means all responses will be futures.
"""
if asyncio.iscoroutine(callback):
asyncio.ensure_future(
callback(rtm_client=self, web_client=self._web_client, data=data)
)
else:
asyncio.ensure_future(
asyncio.coroutine(callback)(
rtm_client=self, web_client=self._web_client, data=data
)
)

def _execute_callback(self, callback, data):
"""Execute the callback in another thread. Wait for and return the results."""
web_client = WebClient(
token=self.token, base_url=self.base_url, ssl=self.ssl, proxy=self.proxy
)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
# Execute the callback on a separate thread,
future = executor.submit(
callback, rtm_client=self, web_client=web_client, data=data
)

while future.running():
pass

future.result()

async def _retreive_websocket_info(self):
"""Retreives the WebSocket info from Slack.

Expand Down Expand Up @@ -491,7 +452,7 @@ async def _wait_exponentially(self, exception, max_wait_time=300):
"""Wait exponentially longer for each connection attempt.

Calculate the number of seconds to wait and then add
a random number of milliseconds to avoid coincendental
a random number of milliseconds to avoid coincidental
synchronized client retries. Wait up to the maximium amount
of wait time specified via 'max_wait_time'. However,
if Slack returned how long to wait use that.
Expand All @@ -512,4 +473,6 @@ def _close_websocket(self):
if callable(close_method):
asyncio.ensure_future(close_method(), loop=self._event_loop)
self._websocket = None
self._dispatch_event(event="close")
asyncio.ensure_future(
self._dispatch_event(event="close"), loop=self._event_loop
)
18 changes: 9 additions & 9 deletions slack/web/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import sys
import logging
import asyncio
from typing import Optional, Union
import inspect

# ThirdParty Imports
import aiohttp
from aiohttp import FormData

# Internal Imports
from slack.web.slack_response import SlackResponse
Expand All @@ -25,7 +27,7 @@ def __init__(
token,
base_url=BASE_URL,
timeout=30,
loop=None,
loop: Optional[asyncio.AbstractEventLoop] = None,
ssl=None,
proxy=None,
run_async=False,
Expand All @@ -43,19 +45,19 @@ def __init__(

def _set_event_loop(self):
if self.run_async:
self._event_loop = asyncio.get_event_loop()
return asyncio.get_event_loop()
else:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._event_loop = loop
return loop

def api_call(
self,
api_method: str,
*,
http_verb: str = "POST",
files: dict = None,
data: dict = None,
data: Union[dict, FormData] = None,
params: dict = None,
json: dict = None,
):
Expand Down Expand Up @@ -99,14 +101,14 @@ def api_call(
"Authorization": "Bearer {}".format(self.token),
}
if files is not None:
form_data = aiohttp.FormData()
form_data = FormData()
for k, v in files.items():
if isinstance(v, str):
form_data.add_field(k, open(v, "rb"))
else:
form_data.add_field(k, v)

if data is not None:
if isinstance(data, dict):
for k, v in data.items():
form_data.add_field(k, str(v))

Expand All @@ -122,7 +124,7 @@ def api_call(
}

if self._event_loop is None:
self._set_event_loop()
self._event_loop = self._set_event_loop()

future = asyncio.ensure_future(
self._send(http_verb=http_verb, api_url=api_url, req_args=req_args),
Expand Down Expand Up @@ -195,7 +197,6 @@ async def _request(self, *, http_verb, api_url, req_args):
"""
if self.session and not self.session.closed:
async with self.session.request(http_verb, api_url, **req_args) as res:
self._logger.debug("Ran the request with existing session.")
return {
"data": await res.json(),
"headers": res.headers,
Expand All @@ -205,7 +206,6 @@ async def _request(self, *, http_verb, api_url, req_args):
loop=self._event_loop, timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:
async with session.request(http_verb, api_url, **req_args) as res:
self._logger.debug("Ran the request with a new session.")
return {
"data": await res.json(),
"headers": res.headers,
Expand Down
7 changes: 3 additions & 4 deletions slack/web/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,10 +656,9 @@ def files_upload(

if file:
return self.api_call("files.upload", files={"file": file}, data=kwargs)
elif content:
data = kwargs.copy()
data.update({"content": content})
return self.api_call("files.upload", data=data)
data = kwargs.copy()
data.update({"content": content})
return self.api_call("files.upload", data=data)

def groups_archive(self, *, channel: str, **kwargs) -> SlackResponse:
"""Archives a private channel.
Expand Down
4 changes: 3 additions & 1 deletion tests/rtm/test_rtm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import collections
import unittest
from unittest import mock
import asyncio

# Internal Imports
import slack
Expand Down Expand Up @@ -65,7 +66,8 @@ def invalid_cb():

def test_send_over_websocket_raises_when_not_connected(self):
with self.assertRaises(e.SlackClientError) as context:
self.client.send_over_websocket(payload={})
loop = asyncio.get_event_loop()
loop.run_until_complete(self.client.send_over_websocket(payload={}))

expected_error = "Websocket connection is closed."
error = str(context.exception)
Expand Down
10 changes: 5 additions & 5 deletions tests/rtm/test_rtm_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def websocket_handler(self, request):
async for msg in ws:
await ws.send_json({"type": "message", "message_sent": msg.json()})
finally:
request.app["websockets"].discard(ws)
request.app["websockets"].remove(ws)
return ws

async def on_shutdown(self, app):
Expand Down Expand Up @@ -170,9 +170,9 @@ def check_message(**payload):

def test_ping_sends_expected_message(self, mock_rtm_response):
@slack.RTMClient.run_on(event="open")
def ping_message(**payload):
async def ping_message(**payload):
rtm_client = payload["rtm_client"]
rtm_client.ping()
await rtm_client.ping()

@slack.RTMClient.run_on(event="message")
def check_message(**payload):
Expand All @@ -185,9 +185,9 @@ def check_message(**payload):

def test_typing_sends_expected_message(self, mock_rtm_response):
@slack.RTMClient.run_on(event="open")
def typing_message(**payload):
async def typing_message(**payload):
rtm_client = payload["rtm_client"]
rtm_client.typing(channel="C01234567")
await rtm_client.typing(channel="C01234567")

@slack.RTMClient.run_on(event="message")
def check_message(**payload):
Expand Down