@@ -71,7 +71,7 @@ async def stream_handler(reader, writer, unix, lbind, protos, rserver, cipher, s
7171 remote_text = f'{ remote_ip } :{ remote_port } '
7272 local_addr = None if server_ip in ('127.0.0.1' , '::1' , None ) else (server_ip , 0 )
7373 reader_cipher , _ = await prepare_ciphers (cipher , reader , writer , server_side = False )
74- lproto , user , host_name , port , lbuf , rbuf = await proto .accept (protos , reader = reader , writer = writer , authtable = AuthTable (remote_ip , authtime ), reader_cipher = reader_cipher , sock = writer .get_extra_info ('socket' ), ** kwargs )
74+ lproto , user , host_name , port , client_connected = await proto .accept (protos , reader = reader , writer = writer , authtable = AuthTable (remote_ip , authtime ), reader_cipher = reader_cipher , sock = writer .get_extra_info ('socket' ), ** kwargs )
7575 if host_name == 'echo' :
7676 asyncio .ensure_future (lproto .channel (reader , writer , DUMMY , DUMMY ))
7777 elif host_name == 'empty' :
@@ -87,13 +87,12 @@ async def stream_handler(reader, writer, unix, lbind, protos, rserver, cipher, s
8787 raise Exception (f'Connection timeout { roption .bind } ' )
8888 try :
8989 reader_remote , writer_remote = await roption .prepare_connection (reader_remote , writer_remote , host_name , port )
90- writer .write (lbuf )
91- writer_remote .write (rbuf )
90+ use_http = (await client_connected (writer_remote )) if client_connected else None
9291 except Exception :
9392 writer_remote .close ()
9493 raise Exception ('Unknown remote protocol' )
9594 m = modstat (user , remote_ip , host_name )
96- lchannel = lproto .http_channel if rbuf else lproto .channel
95+ lchannel = lproto .http_channel if use_http else lproto .channel
9796 asyncio .ensure_future (lproto .channel (reader_remote , writer , m (2 + roption .direct ), m (4 + roption .direct )))
9897 asyncio .ensure_future (lchannel (reader , writer_remote , m (roption .direct ), roption .connection_change ))
9998 except Exception as ex :
@@ -304,6 +303,126 @@ def start_server(self, args, stream_handler=stream_handler):
304303 else :
305304 return asyncio .start_server (handler , host = self .host_name , port = self .port , reuse_port = args .get ('ruport' ))
306305
306+ class ProxyH2 (ProxySimple ):
307+ def __init__ (self , sslserver , sslclient , ** kw ):
308+ super ().__init__ (sslserver = None , sslclient = None , ** kw )
309+ self .handshake = None
310+ self .h2sslserver = sslserver
311+ self .h2sslclient = sslclient
312+ async def handler (self , reader , writer , client_side = True , stream_handler = None , ** kw ):
313+ import h2 .connection , h2 .config , h2 .events
314+ reader , writer = proto .sslwrap (reader , writer , self .h2sslclient if client_side else self .h2sslserver , not client_side , None )
315+ config = h2 .config .H2Configuration (client_side = client_side )
316+ conn = h2 .connection .H2Connection (config = config )
317+ streams = {}
318+ conn .initiate_connection ()
319+ writer .write (conn .data_to_send ())
320+ while not reader .at_eof () and not writer .is_closing ():
321+ try :
322+ data = await reader .read (65636 )
323+ if not data :
324+ break
325+ events = conn .receive_data (data )
326+ except Exception :
327+ pass
328+ writer .write (conn .data_to_send ())
329+ for event in events :
330+ if isinstance (event , h2 .events .RequestReceived ) and not client_side :
331+ if event .stream_id not in streams :
332+ stream_reader , stream_writer = self .get_stream (conn , writer , event .stream_id )
333+ streams [event .stream_id ] = (stream_reader , stream_writer )
334+ asyncio .ensure_future (stream_handler (stream_reader , stream_writer ))
335+ else :
336+ stream_reader , stream_writer = streams [event .stream_id ]
337+ stream_writer .headers .set_result (event .headers )
338+ elif isinstance (event , h2 .events .SettingsAcknowledged ) and client_side :
339+ self .handshake .set_result ((conn , streams , writer ))
340+ elif isinstance (event , h2 .events .DataReceived ):
341+ stream_reader , stream_writer = streams [event .stream_id ]
342+ stream_reader .feed_data (event .data )
343+ conn .acknowledge_received_data (len (event .data ), event .stream_id )
344+ writer .write (conn .data_to_send ())
345+ elif isinstance (event , h2 .events .StreamEnded ) or isinstance (event , h2 .events .StreamReset ):
346+ stream_reader , stream_writer = streams [event .stream_id ]
347+ stream_reader .feed_eof ()
348+ if not stream_writer .closed :
349+ stream_writer .close ()
350+ elif isinstance (event , h2 .events .ConnectionTerminated ):
351+ break
352+ elif isinstance (event , h2 .events .WindowUpdated ):
353+ if event .stream_id in streams :
354+ stream_reader , stream_writer = streams [event .stream_id ]
355+ stream_writer .window_update ()
356+ writer .write (conn .data_to_send ())
357+ writer .close ()
358+ def get_stream (self , conn , writer , stream_id ):
359+ reader = asyncio .StreamReader ()
360+ write_buffer = bytearray ()
361+ write_wait = asyncio .Event ()
362+ write_full = asyncio .Event ()
363+ class StreamWriter ():
364+ def __init__ (self ):
365+ self .closed = False
366+ self .headers = asyncio .get_event_loop ().create_future ()
367+ def get_extra_info (self , key ):
368+ return writer .get_extra_info (key )
369+ def write (self , data ):
370+ write_buffer .extend (data )
371+ write_wait .set ()
372+ def drain (self ):
373+ writer .write (conn .data_to_send ())
374+ return writer .drain ()
375+ def is_closing (self ):
376+ return self .closed
377+ def close (self ):
378+ self .closed = True
379+ write_wait .set ()
380+ def window_update (self ):
381+ write_full .set ()
382+ def send_headers (self , headers ):
383+ conn .send_headers (stream_id , headers )
384+ writer .write (conn .data_to_send ())
385+ stream_writer = StreamWriter ()
386+ async def write_job ():
387+ while not stream_writer .closed :
388+ while len (write_buffer ) > 0 :
389+ while conn .local_flow_control_window (stream_id ) <= 0 :
390+ write_full .clear ()
391+ await write_full .wait ()
392+ if stream_writer .closed :
393+ break
394+ chunk_size = min (conn .local_flow_control_window (stream_id ), len (write_buffer ), conn .max_outbound_frame_size )
395+ conn .send_data (stream_id , write_buffer [:chunk_size ])
396+ writer .write (conn .data_to_send ())
397+ del write_buffer [:chunk_size ]
398+ if not stream_writer .closed :
399+ write_wait .clear ()
400+ await write_wait .wait ()
401+ conn .send_data (stream_id , b'' , end_stream = True )
402+ writer .write (conn .data_to_send ())
403+ asyncio .ensure_future (write_job ())
404+ return reader , stream_writer
405+ async def wait_h2_connection (self , local_addr , family ):
406+ if self .handshake is not None :
407+ if not self .handshake .done ():
408+ await self .handshake
409+ else :
410+ self .handshake = asyncio .get_event_loop ().create_future ()
411+ reader , writer = await super ().wait_open_connection (None , None , local_addr , family )
412+ asyncio .ensure_future (self .handler (reader , writer ))
413+ await self .handshake
414+ return self .handshake .result ()
415+ async def wait_open_connection (self , host , port , local_addr , family ):
416+ conn , streams , writer = await self .wait_h2_connection (local_addr , family )
417+ stream_id = conn .get_next_available_stream_id ()
418+ conn ._begin_new_stream (stream_id , stream_id % 2 )
419+ stream_reader , stream_writer = self .get_stream (conn , writer , stream_id )
420+ streams [stream_id ] = (stream_reader , stream_writer )
421+ return stream_reader , stream_writer
422+ def start_server (self , args , stream_handler = stream_handler ):
423+ handler = functools .partial (stream_handler , ** vars (self ), ** args )
424+ return super ().start_server (args , functools .partial (self .handler , client_side = False , stream_handler = handler ))
425+
307426class ProxyQUIC (ProxySimple ):
308427 def __init__ (self , quicserver , quicclient , ** kw ):
309428 super ().__init__ (** kw )
@@ -544,6 +663,8 @@ def proxies_by_uri(uri_jumps):
544663 jump = proxy_by_uri (uri , jump )
545664 return jump
546665
666+ sslcontexts = []
667+
547668def proxy_by_uri (uri , jump ):
548669 scheme , _ , uri = uri .partition ('://' )
549670 url = urllib .parse .urlparse ('s://' + uri )
@@ -558,17 +679,25 @@ def proxy_by_uri(uri, jump):
558679 if 'ssl' in rawprotos :
559680 sslclient .check_hostname = False
560681 sslclient .verify_mode = ssl .CERT_NONE
682+ sslcontexts .append (sslserver )
683+ sslcontexts .append (sslclient )
561684 else :
562685 sslserver = sslclient = None
563686 if 'quic' in rawprotos :
564687 try :
565688 import ssl , aioquic .quic .configuration
566689 except Exception :
567690 raise Exception ('Missing library: "pip3 install aioquic"' )
568- import logging
569691 quicserver = aioquic .quic .configuration .QuicConfiguration (is_client = False )
570692 quicclient = aioquic .quic .configuration .QuicConfiguration ()
571693 quicclient .verify_mode = ssl .CERT_NONE
694+ sslcontexts .append (quicserver )
695+ sslcontexts .append (quicclient )
696+ if 'h2' in rawprotos :
697+ try :
698+ import h2
699+ except Exception :
700+ raise Exception ('Missing library: "pip3 install h2"' )
572701 protonames = [i .name for i in protos ]
573702 urlpath , _ , plugins = url .path .partition (',' )
574703 urlpath , _ , lbind = urlpath .partition ('@' )
@@ -611,6 +740,8 @@ def proxy_by_uri(uri, jump):
611740 host_name = host_name , port = port , unix = not loc , lbind = lbind , sslclient = sslclient , sslserver = sslserver )
612741 if 'quic' in rawprotos :
613742 proxy = ProxyQUIC (quicserver , quicclient , ** params )
743+ elif 'h2' in protonames :
744+ proxy = ProxyH2 (** params )
614745 elif 'ssh' in protonames :
615746 proxy = ProxySSH (** params )
616747 else :
@@ -646,7 +777,7 @@ async def test_url(url, rserver):
646777 print (headers .decode ()[:- 4 ])
647778 print (f'--------------------------------' )
648779 body = bytearray ()
649- while 1 :
780+ while not reader . at_eof () :
650781 s = await reader .read (65536 )
651782 if not s :
652783 break
@@ -677,15 +808,8 @@ def main():
677808 args = parser .parse_args ()
678809 if args .sslfile :
679810 sslfile = args .sslfile .split (',' )
680- for option in args .listen :
681- if option .sslclient :
682- option .sslclient .load_cert_chain (* sslfile )
683- option .sslserver .load_cert_chain (* sslfile )
684- for option in args .listen + args .ulisten + args .rserver + args .urserver :
685- if isinstance (option , ProxyQUIC ):
686- option .quicserver .load_cert_chain (* sslfile )
687- if isinstance (option , ProxyBackward ) and isinstance (option .backward , ProxyQUIC ):
688- option .backward .quicserver .load_cert_chain (* sslfile )
811+ for context in sslcontexts :
812+ context .load_cert_chain (* sslfile )
689813 elif any (map (lambda o : o .sslclient or isinstance (o , ProxyQUIC ), args .listen + args .ulisten )):
690814 print ('You must specify --ssl to listen in ssl mode' )
691815 return
0 commit comments