22from base64 import b64encode , b64encode
33from hashlib import sha1
44import os
5+ import socket
6+ import ssl
57import types
68from urlparse import urlsplit
79import json
810
911from ws4py import WS_KEY
10- from ws4py .streaming import Stream
1112from ws4py .exc import HandshakeError
13+ from ws4py .websocket import WebSocket , WS_VERSION
1214
1315__all__ = ['WebSocketBaseClient' ]
1416
15- class WebSocketBaseClient (object ):
16- def __init__ (self , url , protocols = None , version = '13' ):
17- self . stream = Stream ( always_mask = True , expect_masking = False )
18- self . url = url
19- self .protocols = protocols
20- self .version = version
17+ class WebSocketBaseClient (WebSocket ):
18+ def __init__ (self , url , protocols , extensions ):
19+ sock = socket . socket ( socket . AF_INET , socket . SOCK_STREAM , 0 )
20+ WebSocket . __init__ ( self , sock , protocols = protocols , extensions = extensions )
21+ self .stream . always_mask = True
22+ self .stream . expect_masking = False
2123 self .key = b64encode (os .urandom (16 ))
22- self .client_terminated = False
23- self .server_terminated = False
24+ self .url = url
25+
26+ def connect (self ):
27+ #self.sock.settimeout(3)
28+ parts = urlsplit (self .url )
29+ host , port = parts .netloc , 80
30+ if ':' in host :
31+ host , port = parts .netloc .split (':' )
32+ self .sock .connect ((host , int (port )))
33+
34+ if parts .scheme == "wss" :
35+ self .sock = ssl .wrap_socket (self .sock )
2436
37+ self .sender (self .handshake_request )
38+
39+ response = ''
40+ while True :
41+ bytes = self .sock .recv (128 )
42+ if not bytes :
43+ break
44+ response += bytes
45+ if '\r \n \r \n ' in response :
46+ break
47+
48+ if not response :
49+ self .close_connection ()
50+ raise HandshakeError ("Invalid response" )
51+
52+ headers , _ , body = response .partition ('\r \n \r \n ' )
53+ response_line , _ , headers = headers .partition ('\r \n ' )
54+
55+ self .__buffer = body
56+
57+ try :
58+ self .process_response_line (response_line )
59+ self .protocols , self .extensions = self .process_handshake_header (headers )
60+ except HandshakeError :
61+ self .close_connection ()
62+ raise
63+
64+ self .handshake_ok ()
65+
2566 @property
2667 def handshake_headers (self ):
2768 parts = urlsplit (self .url )
@@ -35,7 +76,7 @@ def handshake_headers(self):
3576 ('Upgrade' , 'websocket' ),
3677 ('Sec-WebSocket-Key' , self .key ),
3778 ('Sec-WebSocket-Origin' , self .url ),
38- ('Sec-WebSocket-Version' , self . version )
79+ ('Sec-WebSocket-Version' , WS_VERSION )
3980 ]
4081
4182 if self .protocols :
@@ -89,58 +130,3 @@ def process_handshake_header(self, headers):
89130 extensions = ',' .join (value )
90131
91132 return protocols , extensions
92-
93- def opened (self , protocols , extensions ):
94- pass
95-
96- def received_message (self , m ):
97- pass
98-
99- def closed (self , code , reason = None ):
100- pass
101-
102- @property
103- def terminated (self ):
104- return self .client_terminated is True and self .server_terminated is True
105-
106- def close (self , reason = '' , code = 1000 ):
107- if not self .client_terminated :
108- self .client_terminated = True
109- self .write_to_connection (self .stream .close (code = code , reason = reason ).single (mask = True ))
110-
111- def connect (self ):
112- raise NotImplemented ()
113-
114- def write_to_connection (self , bytes ):
115- raise NotImplemented ()
116-
117- def read_from_connection (self , amount ):
118- raise NotImplemented ()
119-
120- def close_connection (self ):
121- raise NotImplemented ()
122-
123- def send (self , payload , binary = False ):
124- message_sender = self .stream .binary_message if binary else self .stream .text_message
125-
126- if isinstance (payload , basestring ):
127- self .write_to_connection (message_sender (payload ).single (mask = True ))
128-
129- elif isinstance (payload , dict ):
130- self .write_to_connection (message_sender (json .dumps (payload )).single (mask = True ))
131-
132- elif type (payload ) == types .GeneratorType :
133- first = True
134- last = False
135- bytes = payload .next ()
136-
137- while not last :
138- try :
139- peeked_bytes = payload .next ()
140- except StopIteration :
141- last = True
142-
143- self .write_to_connection (message_sender (bytes ).fragment (first = first , last = last , mask = True ))
144- first = False
145- bytes = peeked_bytes
146-
0 commit comments