1717
1818__all__ = ['WebSocketProtocol' ]
1919
20- class WebSocketProtocol (asyncio .Protocol ):
20+ class WebSocketProtocol (asyncio .StreamReaderProtocol ):
2121 def __init__ (self , handler_cls ):
22- asyncio .Protocol .__init__ (self )
22+ asyncio .StreamReaderProtocol .__init__ (self , asyncio .StreamReader (),
23+ self ._pseudo_connected )
2324 self .ws = handler_cls (self )
25+
26+ def _pseudo_connected (self , reader , writer ):
27+ pass
2428
2529 def connection_made (self , transport ):
2630 """
@@ -31,10 +35,34 @@ def connection_made(self, transport):
3135 and the transport is associated before the
3236 initial HTTP handshake is undertaken.
3337 """
34- self .transport = transport
35- self .stream = asyncio .StreamReader ()
36- self .stream .set_transport (transport )
37- asyncio .async (self .handle_initial_handshake ())
38+ #self.transport = transport
39+ #self.stream = asyncio.StreamReader()
40+ #self.stream.set_transport(transport)
41+ asyncio .StreamReaderProtocol .connection_made (self , transport )
42+ # Let make it concurrent for others to tag along
43+ f = asyncio .async (self .handle_initial_handshake ())
44+ f .add_done_callback (self .terminated )
45+
46+ @property
47+ def writer (self ):
48+ return self ._stream_writer
49+
50+ @property
51+ def reader (self ):
52+ return self ._stream_reader
53+
54+ def terminated (self , f ):
55+ if f .done () and not f .cancelled ():
56+ ex = f .exception ()
57+ if ex :
58+ print (ex )
59+ response = [b'HTTP/1.0 400 Bad Request' ]
60+ response .append (b'Content-Length: 0' )
61+ response .append (b'Connection: close' )
62+ response .append (b'' )
63+ response .append (b'' )
64+ self .writer .write (CRLF .join (response ))
65+ self .ws .close_connection ()
3866
3967 def close (self ):
4068 """
@@ -57,13 +85,11 @@ def connection_lost(self, exc):
5785 be aware of it by calling its `closed`
5886 method.
5987 """
60- self .ws .close_connection ()
61- if self .ws .started :
62- self .ws .closed (1002 , "Peer connection was lost" )
88+ if exc is not None :
89+ self .ws .close_connection ()
90+ if self .ws .started :
91+ self .ws .closed (1002 , "Peer connection was lost" )
6392
64- def data_received (self , data ):
65- self .stream .feed_data (data )
66-
6793 @asyncio .coroutine
6894 def handle_initial_handshake (self ):
6995 """
@@ -81,6 +107,8 @@ def handle_initial_handshake(self):
81107 raise HandshakeError ('HTTP method must be a GET' )
82108
83109 headers = yield from self .read_headers ()
110+ if req_protocol == b'HTTP/1.1' and 'Host' not in headers :
111+ raise ValueError ("Missing host header" )
84112
85113 for key , expected_value in [('Upgrade' , 'websocket' ),
86114 ('Connection' , 'upgrade' )]:
@@ -113,30 +141,37 @@ def handle_initial_handshake(self):
113141 raise HandshakeError ("WebSocket key's length is invalid" )
114142
115143 protocols = []
144+ ws_protocols = []
116145 subprotocols = headers .get ('Sec-WebSocket-Protocol' )
117146 if subprotocols :
118- ws_protocols = []
119147 for s in subprotocols .split (',' ):
120148 s = s .strip ()
121149 if s in protocols :
122150 ws_protocols .append (s )
123151
124152 exts = []
153+ ws_extensions = []
125154 extensions = headers .get ('Sec-WebSocket-Extensions' )
126155 if extensions :
127156 for ext in extensions .split (',' ):
128157 ext = ext .strip ()
129158 if ext in exts :
130159 ws_extensions .append (ext )
131160
132- self .transport .write (('%s 101 Switching Protocols\r \n ' % req_protocol ).encode ('utf-8' ))
133- self .transport .write (b'Upgrade: websocket\r \n ' )
134- self .transport .write (b'Content-Length: 0\r \n ' )
135- self .transport .write (b'Connection: Upgrade\r \n ' )
136- self .transport .write (b'Sec-WebSocket-Version:' + bytes (str (version ), 'utf-8' ) + CRLF )
137- self .transport .write (b'Sec-WebSocket-Accept:' + base64 .b64encode (sha1 (key .encode ('utf-8' ) + WS_KEY ).digest ()) + CRLF )
138- self .transport .write (CRLF )
139-
161+ response = [req_protocol + b' 101 Switching Protocols' ]
162+ response .append (b'Upgrade: websocket' )
163+ response .append (b'Content-Type: text/plain' )
164+ response .append (b'Content-Length: 0' )
165+ response .append (b'Connection: Upgrade' )
166+ response .append (b'Sec-WebSocket-Version:' + bytes (str (version ), 'utf-8' ))
167+ response .append (b'Sec-WebSocket-Accept:' + base64 .b64encode (sha1 (key .encode ('utf-8' ) + WS_KEY ).digest ()))
168+ if ws_protocols :
169+ response .append (b'Sec-WebSocket-Protocol:' + b', ' .join (ws_protocols ))
170+ if ws_extensions :
171+ response .append (b'Sec-WebSocket-Extensions:' + b',' .join (ws_extensions ))
172+ response .append (b'' )
173+ response .append (b'' )
174+ self .writer .write (CRLF .join (response ))
140175 yield from self .handle_websocket ()
141176
142177 @asyncio .coroutine
@@ -167,19 +202,19 @@ def next_line(self):
167202 Reads data until \r \n is met and then return all read
168203 bytes.
169204 """
170- line = yield from self .stream .readline ()
205+ line = yield from self .reader .readline ()
171206 if not line .endswith (CRLF ):
172207 raise ValueError ("Missing mandatory trailing CRLF" )
173208 return line
174209
175210if __name__ == '__main__' :
176- from ws4py .websocket import AsyncEchoWebSocket
211+ from ws4py .async_websocket import EchoWebSocket
177212
178213 loop = asyncio .get_event_loop ()
179214
180215 def start_server ():
181- proto_factory = lambda : WebSocketProtocol (AsyncEchoWebSocket )
182- return loop .create_server (proto_factory , '' , 7002 )
216+ proto_factory = lambda : WebSocketProtocol (EchoWebSocket )
217+ return loop .create_server (proto_factory , '' , 9007 )
183218
184219 s = loop .run_until_complete (start_server ())
185220 print ('serving on' , s .sockets [0 ].getsockname ())
0 commit comments