@@ -294,7 +294,7 @@ async def prepare_ciphers_and_headers(self, reader_remote, writer_remote, host,
294294 whost , wport = self .jump .destination (host , port )
295295 await self .rproto .connect (reader_remote = reader_remote , writer_remote = writer_remote , rauth = self .auth , host_name = whost , port = wport , writer_cipher_r = writer_cipher_r , myhost = self .host_name , sock = writer_remote .get_extra_info ('socket' ))
296296 return await self .jump .prepare_ciphers_and_headers (reader_remote , writer_remote , host , port )
297- def start_server (self , args ):
297+ def start_server (self , args , stream_handler = stream_handler ):
298298 handler = functools .partial (stream_handler , ** vars (self ), ** args )
299299 if self .unix :
300300 return asyncio .start_unix_server (handler , path = self .bind )
@@ -307,15 +307,20 @@ def __init__(self, quicserver, quicclient, **kw):
307307 self .quicserver = quicserver
308308 self .quicclient = quicclient
309309 self .handshake = None
310- self .streams = {}
311310 def patch_writer (self , writer ):
312311 async def drain ():
313312 writer ._transport .protocol .transmit ()
314313 #print('stream_id', writer.get_extra_info("stream_id"))
315314 remote_addr = writer ._transport .protocol ._quic ._network_paths [0 ].addr
316315 writer .get_extra_info = dict (peername = remote_addr , sockname = remote_addr ).get
317316 writer .drain = drain
317+ closed = False
318+ def is_closing ():
319+ return closed
320+ writer .is_closing = is_closing
318321 def close ():
322+ nonlocal closed
323+ closed = True
319324 try :
320325 writer .write_eof ()
321326 except Exception :
@@ -345,7 +350,7 @@ def quic_event_received(s, event):
345350 reader , writer = conn ._create_stream (stream_id )
346351 self .patch_writer (writer )
347352 return reader , writer
348- async def start_server (self , args ):
353+ async def start_server (self , args , stream_handler = stream_handler ):
349354 import aioquic .asyncio
350355 def handler (reader , writer ):
351356 self .patch_writer (writer )
@@ -405,16 +410,17 @@ async def channel():
405410 return await self .jump .prepare_ciphers_and_headers (reader_remote , writer_remote , host , port )
406411
407412class ProxyBackward (ProxySimple ):
408- def __init__ (self , backward , ** kw ):
413+ def __init__ (self , backward , backward_num , ** kw ):
409414 super ().__init__ (** kw )
410- self .backward_num = backward
415+ self .backward = backward
416+ self .backward_num = backward_num
411417 self .closed = False
412418 self .writers = set ()
413419 self .conn = asyncio .Queue ()
414420 async def wait_open_connection (self , * args ):
415421 while True :
416422 reader , writer = await self .conn .get ()
417- if not writer .transport . is_closing ():
423+ if not writer .is_closing ():
418424 return reader , writer
419425 def close (self ):
420426 self .closed = True
@@ -423,24 +429,21 @@ def close(self):
423429 self .writer .close ()
424430 except Exception :
425431 pass
426- async def start_server (self , args ):
432+ async def start_server (self , args , stream_handler = stream_handler ):
427433 handler = functools .partial (stream_handler , ** vars (self ), ** args )
428434 for _ in range (self .backward_num ):
429435 asyncio .ensure_future (self .start_server_run (handler ))
430436 return self
431437 async def start_server_run (self , handler ):
432438 errwait = 0
433439 while not self .closed :
434- if self .unix :
435- wait = asyncio .open_unix_connection (path = self .bind )
436- else :
437- wait = asyncio .open_connection (host = self .host_name , port = self .port , local_addr = (self .lbind , 0 ) if self .lbind else None )
440+ wait = self .backward .open_connection (self .host_name , self .port , self .lbind , None )
438441 try :
439442 reader , writer = await asyncio .wait_for (wait , timeout = SOCKET_TIMEOUT )
440443 if self .closed :
441444 writer .close ()
442445 break
443- writer .write (self .auth )
446+ writer .write (self .auth or b' \x01 ' )
444447 self .writers .add (writer )
445448 try :
446449 data = await reader .read_n (1 )
@@ -462,17 +465,14 @@ async def start_server_run(self, handler):
462465 await asyncio .sleep (errwait )
463466 errwait = min (errwait * 1.3 + 0.1 , 30 )
464467 def start_backward_client (self , args ):
465- async def handler (reader , writer ):
466- if self .auth :
467- try :
468- assert self . auth == (await reader .read_n (len (self . auth )))
469- except Exception :
470- return
468+ async def handler (reader , writer , ** kw ):
469+ auth = self .auth or b' \x01 '
470+ try :
471+ assert auth == (await reader .read_n (len (auth )))
472+ except Exception :
473+ return
471474 await self .conn .put ((reader , writer ))
472- if self .unix :
473- return asyncio .start_unix_server (handler , path = self .bind )
474- else :
475- return asyncio .start_server (handler , host = self .host_name , port = self .port , reuse_port = args .get ('ruport' ))
475+ return self .backward .start_server (args , handler )
476476
477477
478478def compile_rule (filename ):
@@ -548,18 +548,20 @@ def proxy_by_uri(uri):
548548 else :
549549 auth = url .fragment .encode ()
550550 users = [i .rstrip () for i in auth .split (b'\n ' )] if auth else None
551- params = dict (protos = protos , cipher = cipher , users = users , rule = url .query , bind = loc or urlpath ,
552- host_name = host_name , port = port , unix = not loc , lbind = lbind , sslclient = sslclient , sslserver = sslserver )
553551 if 'direct' in protonames :
554552 return ProxyDirect (lbind = lbind )
555- elif 'in' in rawprotos :
556- return ProxyBackward (rawprotos .count ('in' ), ** params )
557- elif 'quic' in rawprotos :
558- return ProxyQUIC (quicserver , quicclient , ** params )
559- elif 'ssh' in protonames :
560- return ProxySSH (** params )
561553 else :
562- return ProxySimple (** params )
554+ params = dict (protos = protos , cipher = cipher , users = users , rule = url .query , bind = loc or urlpath ,
555+ host_name = host_name , port = port , unix = not loc , lbind = lbind , sslclient = sslclient , sslserver = sslserver )
556+ if 'quic' in rawprotos :
557+ proxy = ProxyQUIC (quicserver , quicclient , ** params )
558+ elif 'ssh' in protonames :
559+ proxy = ProxySSH (** params )
560+ else :
561+ proxy = ProxySimple (** params )
562+ if 'in' in rawprotos :
563+ proxy = ProxyBackward (proxy , rawprotos .count ('in' ), ** params )
564+ return proxy
563565
564566async def test_url (url , rserver ):
565567 url = urllib .parse .urlparse (url )
@@ -626,6 +628,8 @@ def main():
626628 for option in args .listen + args .rserver :
627629 if isinstance (option , ProxyQUIC ):
628630 option .quicserver .load_cert_chain (* sslfile )
631+ if isinstance (option , ProxyBackward ) and isinstance (option .backward , ProxyQUIC ):
632+ option .backward .quicserver .load_cert_chain (* sslfile )
629633 elif any (map (lambda o : o .sslclient or isinstance (o , ProxyQUIC ), args .listen )):
630634 print ('You must specify --ssl to listen in ssl mode' )
631635 return
0 commit comments