Skip to content

Commit 98238d6

Browse files
committed
Hopefully this will catch most of the use cases where a write operation is attempted after the connection was closed Lawouach#73
1 parent 3f9b2bf commit 98238d6

1 file changed

Lines changed: 19 additions & 9 deletions

File tree

ws4py/websocket.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def close(self, code=1000, reason=''):
173173
"""
174174
if not self.server_terminated:
175175
self.server_terminated = True
176-
self.sock.sendall(self.stream.close(code=code, reason=reason).single(mask=self.stream.always_mask))
176+
self._write(self.stream.close(code=code, reason=reason).single(mask=self.stream.always_mask))
177177

178178
def closed(self, code, reason=None):
179179
"""
@@ -229,6 +229,19 @@ def received_message(self, message):
229229
"""
230230
pass
231231

232+
def _write(self, b):
233+
"""
234+
Trying to prevent a write operation
235+
on an already closed websocket stream.
236+
237+
This cannot be bullet proof but hopefully
238+
will catch almost all use cases.
239+
"""
240+
if self.terminated or self.sock is None:
241+
raise RuntimeError("Cannot send on a terminated websocket")
242+
243+
self.sock.sendall(b)
244+
232245
def send(self, payload, binary=False):
233246
"""
234247
Sends the given ``payload`` out.
@@ -241,28 +254,25 @@ def send(self, payload, binary=False):
241254
242255
If ``binary`` is set, handles the payload as a binary message.
243256
"""
244-
if self.terminated:
245-
raise RuntimeError("Cannot send on a terminated websocket")
246-
247257
message_sender = self.stream.binary_message if binary else self.stream.text_message
248258

249259
if isinstance(payload, basestring) or isinstance(payload, bytearray):
250260
m = message_sender(payload).single(mask=self.stream.always_mask)
251-
self.sock.sendall(m)
261+
self._write(m)
252262

253263
elif isinstance(payload, Message):
254264
data = payload.single(mask=self.stream.always_mask)
255-
self.sock.sendall(data)
265+
self._write(data)
256266

257267
elif type(payload) == types.GeneratorType:
258268
bytes = next(payload)
259269
first = True
260270
for chunk in payload:
261-
self.sock.sendall(message_sender(bytes).fragment(first=first, mask=self.stream.always_mask))
271+
self._write(message_sender(bytes).fragment(first=first, mask=self.stream.always_mask))
262272
bytes = chunk
263273
first = False
264274

265-
self.sock.sendall(message_sender(bytes).fragment(last=True, mask=self.stream.always_mask))
275+
self._write(message_sender(bytes).fragment(last=True, mask=self.stream.always_mask))
266276

267277
else:
268278
raise ValueError("Unsupported type '%s' passed to send()" % type(payload))
@@ -377,7 +387,7 @@ def process(self, bytes):
377387

378388
if s.pings:
379389
for ping in s.pings:
380-
self.sock.sendall(s.pong(ping.data))
390+
self._write(s.pong(ping.data))
381391
s.pings = []
382392

383393
if s.pongs:

0 commit comments

Comments
 (0)