Skip to content

Commit 4152604

Browse files
websocket: use asyncio instead of threads
1 parent 26d9fed commit 4152604

File tree

3 files changed

+86
-58
lines changed

3 files changed

+86
-58
lines changed

internal_filesystem/lib/websocket.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def send_bytes(self, data):
130130
"""Send binary data."""
131131
self.send(data, ABNF.OPCODE_BINARY)
132132

133-
def close(self, **kwargs):
133+
async def close(self, **kwargs):
134134
"""Close the WebSocket connection."""
135135
_log_debug("Close requested")
136136
self.running = False
@@ -184,7 +184,7 @@ def ready(self):
184184
_log_debug(f"Connection status: ready={status}")
185185
return status
186186

187-
def run_forever(
187+
async def run_forever(
188188
self,
189189
sockopt=None,
190190
sslopt=None,
@@ -230,7 +230,7 @@ def run_forever(
230230
self.close()
231231
return False
232232
except Exception as e:
233-
_log_error(f"run_forever's _loop.run_until_complete() got general exception: {e}")
233+
_log_error(f"run_forever's _loop.run_until_complete() for {self.url} got general exception: {e}")
234234
self.has_errored = True
235235
self.running = False
236236
#return True
@@ -262,7 +262,7 @@ async def _async_main(self):
262262
try:
263263
await self._connect_and_run() # keep waiting for it, until finished
264264
except Exception as e:
265-
_log_error(f"_async_main got exception: {e}")
265+
_log_error(f"_async_main's await self._connect_and_run() got exception: {e}")
266266
self.has_errored = True
267267
_run_callback(self.on_error, self, e)
268268
if not reconnect:
@@ -298,6 +298,11 @@ async def _connect_and_run(self):
298298

299299
self.session = aiohttp.ClientSession(headers=self.header)
300300
async with self.session.ws_connect(self.url, ssl=ssl_context) as ws:
301+
if not ws:
302+
print("ERROR: ws_connect got None instead of ws object!")
303+
_run_callback(self.on_error, self, str(e))
304+
return
305+
301306
self.ws = ws
302307
_log_debug("WebSocket connected, running on_open callback")
303308
_run_callback(self.on_open, self)

tests/test_websocket.py

Lines changed: 76 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import unittest
23
import _thread
34
import time
@@ -7,87 +8,109 @@
78

89
from websocket import WebSocketApp
910

10-
class TestWebsocket(unittest.TestCase):
11+
class TestMutlipleWebsocketsAsyncio(unittest.TestCase):
1112

12-
ws = None
13+
max_allowed_connections = 3 # max that echo.websocket.org allows
1314

14-
on_open_called = None
15-
on_message_called = None
16-
on_ping_called = None
17-
on_close_called = None
15+
#relays = ["wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org" ] # more gives "too many requests" error
16+
relays = ["wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org", "wss://echo.websocket.org" ] # more might give "too many requests" error
17+
wslist = []
18+
19+
on_open_called = 0
20+
on_message_called = 0
21+
on_ping_called = 0
22+
on_close_called = 0
23+
on_error_called = 0
1824

1925
def on_message(self, wsapp, message: str):
2026
print(f"on_message received: {message}")
2127
self.on_message_called = True
2228

2329
def on_open(self, wsapp):
2430
print(f"on_open called: {wsapp}")
25-
self.on_open_called = True
26-
self.ws.send('{"type": "subscribe","product_ids": ["BTC-USD"],"channels": ["ticker_batch"]}')
31+
self.on_open_called += 1
32+
#wsapp.send('{"type": "subscribe","product_ids": ["BTC-USD"],"channels": ["ticker_batch"]}')
2733

2834
def on_ping(wsapp, message):
2935
print("Got a ping!")
3036
self.on_ping_called = True
3137

3238
def on_close(self, wsapp, close_status_code, close_msg):
3339
print(f"on_close called: {wsapp}")
34-
self.on_close_called = True
40+
self.on_close_called += 1
3541

36-
def websocket_thread(self):
37-
wsurl = "wss://ws-feed.exchange.coinbase.com"
42+
def on_error(self, wsapp, arg1):
43+
print(f"on_error called: {wsapp}, {arg1}")
44+
self.on_error_called += 1
3845

39-
self.ws = WebSocketApp(
40-
wsurl,
41-
on_open=self.on_open,
42-
on_close=self.on_close,
43-
on_message=self.on_message,
44-
on_ping=self.on_ping
45-
) # maybe add other callbacks to reconnect when disconnected etc.
46-
self.ws.run_forever()
46+
async def closeall(self):
47+
await asyncio.sleep(1)
4748

48-
def wait_for_ping(self):
49-
self.on_ping_called = False
50-
for _ in range(60):
51-
print("Waiting for on_ping to be called...")
52-
if self.on_ping_called:
53-
print("yes, it was called!")
54-
break
55-
time.sleep(1)
56-
self.assertTrue(self.on_ping_called)
49+
self.on_close_called = 0
50+
print("disconnecting...")
51+
for ws in self.wslist:
52+
await ws.close()
5753

58-
def test_it(self):
59-
on_open_called = False
60-
_thread.stack_size(mpos.apps.good_stack_size())
61-
_thread.start_new_thread(self.websocket_thread, ())
62-
63-
self.on_open_called = False
64-
self.on_message_called = False # message might be received very quickly, before we expect it
65-
for _ in range(5):
54+
async def main(self) -> None:
55+
tasks = []
56+
self.wslist = []
57+
for idx, wsurl in enumerate(self.relays):
58+
print(f"creating WebSocketApp for {wsurl}")
59+
ws = WebSocketApp(
60+
wsurl,
61+
on_open=self.on_open,
62+
on_close=self.on_close,
63+
on_message=self.on_message,
64+
on_ping=self.on_ping,
65+
on_error=self.on_error
66+
)
67+
print(f"creating task for {wsurl}")
68+
tasks.append(asyncio.create_task(ws.run_forever(),))
69+
print(f"created task for {wsurl}")
70+
self.wslist.append(ws)
71+
72+
print(f"Starting {len(tasks)} concurrent WebSocket connections…")
73+
await asyncio.sleep(2)
74+
await self.closeall()
75+
76+
for _ in range(10):
6677
print("Waiting for on_open to be called...")
67-
if self.on_open_called:
78+
if self.on_open_called == min(len(self.relays),self.max_allowed_connections):
6879
print("yes, it was called!")
6980
break
70-
time.sleep(1)
71-
self.assertTrue(self.on_open_called)
81+
await asyncio.sleep(1)
82+
self.assertTrue(self.on_open_called == min(len(self.relays),self.max_allowed_connections))
7283

73-
self.on_message_called = False # message might be received very quickly, before we expect it
74-
for _ in range(5):
75-
print("Waiting for on_message to be called...")
76-
if self.on_message_called:
84+
for _ in range(10):
85+
print("Waiting for on_close to be called...")
86+
if self.on_close_called == min(len(self.relays),self.max_allowed_connections):
7787
print("yes, it was called!")
7888
break
79-
time.sleep(1)
80-
self.assertTrue(self.on_message_called)
89+
await asyncio.sleep(1)
90+
self.assertTrue(self.on_close_called == min(len(self.relays),self.max_allowed_connections))
8191

82-
# Disabled because not all servers send pings:
83-
# self.wait_for_ping()
92+
self.assertTrue(self.on_error_called == min(len(self.relays),self.max_allowed_connections))
8493

85-
self.on_close_called = False
86-
self.ws.close()
87-
for _ in range(5):
88-
print("Waiting for on_close to be called...")
89-
if self.on_close_called:
94+
# Wait for *all* of them to finish (or be cancelled)
95+
# If this hangs, it's also a failure:
96+
await asyncio.gather(*tasks, return_exceptions=True)
97+
98+
def wait_for_ping(self):
99+
self.on_ping_called = False
100+
for _ in range(60):
101+
print("Waiting for on_ping to be called...")
102+
if self.on_ping_called:
90103
print("yes, it was called!")
91104
break
92105
time.sleep(1)
93-
self.assertTrue(self.on_close_called)
106+
self.assertTrue(self.on_ping_called)
107+
108+
def test_it_loop(self):
109+
for testnr in range(1):
110+
print(f"starting iteration {testnr}")
111+
asyncio.run(self.do_two())
112+
print(f"finished iteration {testnr}")
113+
114+
def do_two(self):
115+
await self.main()
116+

0 commit comments

Comments
 (0)