Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions integration_tests/samples/socket_mode/aiohttp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ async def process(client: SocketModeClient, req: SocketModeRequest):
if req.type == "events_api":
response = SocketModeResponse(envelope_id=req.envelope_id)
await client.send_socket_mode_response(response)

await client.web_client.reactions_add(
name="eyes",
channel=req.payload["event"]["channel"],
timestamp=req.payload["event"]["ts"],
)
if req.payload["event"]["type"] == "message":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed the issue where this example may fail if the app subscribes to other events like reaction_added

await client.web_client.reactions_add(
name="eyes",
channel=req.payload["event"]["channel"],
timestamp=req.payload["event"]["ts"],
)

client.socket_mode_request_listeners.append(process)
await client.connect()
Expand Down
11 changes: 10 additions & 1 deletion slack_sdk/socket_mode/aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import asyncio
import logging
from asyncio import Future
from asyncio import Future, Lock
from asyncio import Queue
from logging import Logger
from typing import Union, Optional, List, Callable, Awaitable
Expand Down Expand Up @@ -58,6 +58,7 @@ class SocketModeClient(AsyncBaseSocketModeClient):
auto_reconnect_enabled: bool
default_auto_reconnect_enabled: bool
closed: bool
connect_operation_lock: Lock

on_message_listeners: List[Callable[[WSMessage], Awaitable[None]]]
on_error_listeners: List[Callable[[WSMessage], Awaitable[None]]]
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
self.logger = logger or logging.getLogger(__name__)
self.web_client = web_client or AsyncWebClient()
self.closed = False
self.connect_operation_lock = Lock()
self.proxy = proxy
if self.proxy is None or len(self.proxy.strip()) == 0:
env_variable = load_http_proxy_from_env(self.logger)
Expand Down Expand Up @@ -185,6 +187,13 @@ async def receive_messages(self) -> None:
else:
await asyncio.sleep(consecutive_error_count)

async def is_connected(self) -> bool:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For easier access to the state and better alignment with the sync SocketModeClient, I've added this flag.

return (
not self.closed
and self.current_session is not None
and not self.current_session.closed
)

async def connect(self):
old_session = None if self.current_session is None else self.current_session
if self.wss_uri is None:
Expand Down
21 changes: 16 additions & 5 deletions slack_sdk/socket_mode/async_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import logging
from asyncio import Queue
from asyncio import Queue, Lock
from asyncio.futures import Future
from logging import Logger
from typing import Dict, Union, Any, Optional, List, Callable, Awaitable
Expand All @@ -23,6 +23,8 @@ class AsyncBaseSocketModeClient:
wss_uri: str
auto_reconnect_enabled: bool
closed: bool
connect_operation_lock: Lock

message_queue: Queue
message_listeners: List[
Union[
Expand Down Expand Up @@ -58,15 +60,24 @@ async def issue_new_wss_url(self) -> str:
self.logger.error(f"Failed to retrieve WSS URL: {e}")
raise e

async def is_connected(self) -> bool:
return False

async def connect(self):
raise NotImplementedError()

async def disconnect(self):
raise NotImplementedError()

async def connect_to_new_endpoint(self):
self.wss_uri = await self.issue_new_wss_url()
await self.connect()
async def connect_to_new_endpoint(self, force: bool = False):
try:
await self.connect_operation_lock.acquire()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lock acquisition and release inside this method is the main change in this PR. With this way, other concurrent method calls are synchronized. The client can safely check sef.is_connected() inside and decide if the client still needs to replace the WSS URL and the underlying WebSocket session.

if force or not await self.is_connected():
self.wss_uri = await self.issue_new_wss_url()
await self.connect()
finally:
if self.connect_operation_lock.locked() is True:
self.connect_operation_lock.release()

async def close(self):
self.closed = True
Expand Down Expand Up @@ -116,7 +127,7 @@ async def run_message_listeners(self, message: dict, raw_message: str) -> None:
)
try:
if message.get("type") == "disconnect":
await self.connect_to_new_endpoint()
await self.connect_to_new_endpoint(force=True)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sync client does the same. When we receive a disconnect type message from the Socket Mode server-side, we should immediately reconnect even if the current session is still active.

return

for listener in self.message_listeners:
Expand Down
11 changes: 10 additions & 1 deletion slack_sdk/socket_mode/websockets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import asyncio
import logging
from asyncio import Future
from asyncio import Future, Lock
from logging import Logger
from asyncio import Queue
from typing import Union, Optional, List, Callable, Awaitable
Expand Down Expand Up @@ -56,6 +56,7 @@ class SocketModeClient(AsyncBaseSocketModeClient):
auto_reconnect_enabled: bool
default_auto_reconnect_enabled: bool
closed: bool
connect_operation_lock: Lock
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the only changes that are necessary in other client implementations. I'll have a migration guide section in the next version's release notes.


def __init__(
self,
Expand All @@ -78,6 +79,7 @@ def __init__(
self.logger = logger or logging.getLogger(__name__)
self.web_client = web_client or AsyncWebClient()
self.closed = False
self.connect_operation_lock = Lock()
self.default_auto_reconnect_enabled = auto_reconnect_enabled
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
self.ping_interval = ping_interval
Expand Down Expand Up @@ -130,6 +132,13 @@ async def receive_messages(self) -> None:
else:
await asyncio.sleep(consecutive_error_count)

async def is_connected(self) -> bool:
return (
not self.closed
and self.current_session is not None
and not self.current_session.closed
)

async def connect(self):
if self.wss_uri is None:
self.wss_uri = await self.issue_new_wss_url()
Expand Down