Skip to content

Commit 09fb14a

Browse files
committed
quic backward
1 parent 46d78c8 commit 09fb14a

File tree

1 file changed

+35
-31
lines changed

1 file changed

+35
-31
lines changed

pproxy/server.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ async def prepare_ciphers_and_headers(self, reader_remote, writer_remote, host,
294294
whost, wport = self.jump.destination(host, port)
295295
await self.rproto.connect(reader_remote=reader_remote, writer_remote=writer_remote, rauth=self.auth, host_name=whost, port=wport, writer_cipher_r=writer_cipher_r, myhost=self.host_name, sock=writer_remote.get_extra_info('socket'))
296296
return await self.jump.prepare_ciphers_and_headers(reader_remote, writer_remote, host, port)
297-
def start_server(self, args):
297+
def start_server(self, args, stream_handler=stream_handler):
298298
handler = functools.partial(stream_handler, **vars(self), **args)
299299
if self.unix:
300300
return asyncio.start_unix_server(handler, path=self.bind)
@@ -307,15 +307,20 @@ def __init__(self, quicserver, quicclient, **kw):
307307
self.quicserver = quicserver
308308
self.quicclient = quicclient
309309
self.handshake = None
310-
self.streams = {}
311310
def patch_writer(self, writer):
312311
async def drain():
313312
writer._transport.protocol.transmit()
314313
#print('stream_id', writer.get_extra_info("stream_id"))
315314
remote_addr = writer._transport.protocol._quic._network_paths[0].addr
316315
writer.get_extra_info = dict(peername=remote_addr, sockname=remote_addr).get
317316
writer.drain = drain
317+
closed = False
318+
def is_closing():
319+
return closed
320+
writer.is_closing = is_closing
318321
def close():
322+
nonlocal closed
323+
closed = True
319324
try:
320325
writer.write_eof()
321326
except Exception:
@@ -345,7 +350,7 @@ def quic_event_received(s, event):
345350
reader, writer = conn._create_stream(stream_id)
346351
self.patch_writer(writer)
347352
return reader, writer
348-
async def start_server(self, args):
353+
async def start_server(self, args, stream_handler=stream_handler):
349354
import aioquic.asyncio
350355
def handler(reader, writer):
351356
self.patch_writer(writer)
@@ -405,16 +410,17 @@ async def channel():
405410
return await self.jump.prepare_ciphers_and_headers(reader_remote, writer_remote, host, port)
406411

407412
class ProxyBackward(ProxySimple):
408-
def __init__(self, backward, **kw):
413+
def __init__(self, backward, backward_num, **kw):
409414
super().__init__(**kw)
410-
self.backward_num = backward
415+
self.backward = backward
416+
self.backward_num = backward_num
411417
self.closed = False
412418
self.writers = set()
413419
self.conn = asyncio.Queue()
414420
async def wait_open_connection(self, *args):
415421
while True:
416422
reader, writer = await self.conn.get()
417-
if not writer.transport.is_closing():
423+
if not writer.is_closing():
418424
return reader, writer
419425
def close(self):
420426
self.closed = True
@@ -423,24 +429,21 @@ def close(self):
423429
self.writer.close()
424430
except Exception:
425431
pass
426-
async def start_server(self, args):
432+
async def start_server(self, args, stream_handler=stream_handler):
427433
handler = functools.partial(stream_handler, **vars(self), **args)
428434
for _ in range(self.backward_num):
429435
asyncio.ensure_future(self.start_server_run(handler))
430436
return self
431437
async def start_server_run(self, handler):
432438
errwait = 0
433439
while not self.closed:
434-
if self.unix:
435-
wait = asyncio.open_unix_connection(path=self.bind)
436-
else:
437-
wait = asyncio.open_connection(host=self.host_name, port=self.port, local_addr=(self.lbind, 0) if self.lbind else None)
440+
wait = self.backward.open_connection(self.host_name, self.port, self.lbind, None)
438441
try:
439442
reader, writer = await asyncio.wait_for(wait, timeout=SOCKET_TIMEOUT)
440443
if self.closed:
441444
writer.close()
442445
break
443-
writer.write(self.auth)
446+
writer.write(self.auth or b'\x01')
444447
self.writers.add(writer)
445448
try:
446449
data = await reader.read_n(1)
@@ -462,17 +465,14 @@ async def start_server_run(self, handler):
462465
await asyncio.sleep(errwait)
463466
errwait = min(errwait*1.3 + 0.1, 30)
464467
def start_backward_client(self, args):
465-
async def handler(reader, writer):
466-
if self.auth:
467-
try:
468-
assert self.auth == (await reader.read_n(len(self.auth)))
469-
except Exception:
470-
return
468+
async def handler(reader, writer, **kw):
469+
auth = self.auth or b'\x01'
470+
try:
471+
assert auth == (await reader.read_n(len(auth)))
472+
except Exception:
473+
return
471474
await self.conn.put((reader, writer))
472-
if self.unix:
473-
return asyncio.start_unix_server(handler, path=self.bind)
474-
else:
475-
return asyncio.start_server(handler, host=self.host_name, port=self.port, reuse_port=args.get('ruport'))
475+
return self.backward.start_server(args, handler)
476476

477477

478478
def compile_rule(filename):
@@ -548,18 +548,20 @@ def proxy_by_uri(uri):
548548
else:
549549
auth = url.fragment.encode()
550550
users = [i.rstrip() for i in auth.split(b'\n')] if auth else None
551-
params = dict(protos=protos, cipher=cipher, users=users, rule=url.query, bind=loc or urlpath,
552-
host_name=host_name, port=port, unix=not loc, lbind=lbind, sslclient=sslclient, sslserver=sslserver)
553551
if 'direct' in protonames:
554552
return ProxyDirect(lbind=lbind)
555-
elif 'in' in rawprotos:
556-
return ProxyBackward(rawprotos.count('in'), **params)
557-
elif 'quic' in rawprotos:
558-
return ProxyQUIC(quicserver, quicclient, **params)
559-
elif 'ssh' in protonames:
560-
return ProxySSH(**params)
561553
else:
562-
return ProxySimple(**params)
554+
params = dict(protos=protos, cipher=cipher, users=users, rule=url.query, bind=loc or urlpath,
555+
host_name=host_name, port=port, unix=not loc, lbind=lbind, sslclient=sslclient, sslserver=sslserver)
556+
if 'quic' in rawprotos:
557+
proxy = ProxyQUIC(quicserver, quicclient, **params)
558+
elif 'ssh' in protonames:
559+
proxy = ProxySSH(**params)
560+
else:
561+
proxy = ProxySimple(**params)
562+
if 'in' in rawprotos:
563+
proxy = ProxyBackward(proxy, rawprotos.count('in'), **params)
564+
return proxy
563565

564566
async def test_url(url, rserver):
565567
url = urllib.parse.urlparse(url)
@@ -626,6 +628,8 @@ def main():
626628
for option in args.listen+args.rserver:
627629
if isinstance(option, ProxyQUIC):
628630
option.quicserver.load_cert_chain(*sslfile)
631+
if isinstance(option, ProxyBackward) and isinstance(option.backward, ProxyQUIC):
632+
option.backward.quicserver.load_cert_chain(*sslfile)
629633
elif any(map(lambda o: o.sslclient or isinstance(o, ProxyQUIC), args.listen)):
630634
print('You must specify --ssl to listen in ssl mode')
631635
return

0 commit comments

Comments
 (0)