1+ use std:: sync:: Arc ;
12use std:: time:: Duration ;
23
34use tokio:: {
45 fs,
56 net:: UdpSocket ,
7+ sync:: Notify ,
68 time:: timeout,
79} ;
810
@@ -14,12 +16,13 @@ use pasque::{
1416 } ,
1517 stream:: {
1618 filestream:: { FileStream , Files } ,
17- udptunnel:: { UdpEndpoint , UdpTunnel } , PsqStream ,
19+ udptunnel:: { UdpEndpoint , UdpTunnel } ,
20+ PsqStream ,
1821 } , test_utils:: init_logger, PsqError
1922} ;
2023
2124
22- #[ tokio:: test( flavor = "multi_thread" , worker_threads = 3 ) ]
25+ #[ tokio:: test]
2326async fn test_get_request ( ) {
2427 init_logger ( ) ;
2528 let addr = "127.0.0.1:8888" ;
@@ -79,83 +82,124 @@ async fn test_get_request() {
7982}
8083
8184
82- #[ tokio:: test( flavor = "multi_thread" , worker_threads = 3 ) ]
85+ async fn run_server ( addr : & str , shutdown : Arc < Notify > ) {
86+ let config = Config :: create_default ( ) ;
87+ let mut psqserver = PsqServer :: start ( addr, & config) . await . unwrap ( ) ;
88+ psqserver. add_endpoint (
89+ "udp" ,
90+ UdpEndpoint :: new ( ) . unwrap ( )
91+ ) . await ;
92+ loop {
93+ tokio:: select! {
94+ _ = shutdown. notified( ) => {
95+ break ;
96+ }
97+ result = psqserver. process( ) => {
98+ result. unwrap( ) ;
99+ }
100+ }
101+ }
102+ }
103+
104+
105+ async fn run_client ( mut psqclient : PsqClient , shutdown : Arc < Notify > ) {
106+ loop {
107+ tokio:: select! {
108+ _ = shutdown. notified( ) => {
109+ break ;
110+ }
111+ result = psqclient. process( ) => {
112+ result. unwrap( ) ;
113+ }
114+ }
115+ }
116+ }
117+
118+
119+ async fn run_udpserver ( udpsocket : UdpSocket , shutdown : Arc < Notify > ) {
120+ loop {
121+ let mut buf = [ 0u8 ; 2000 ] ;
122+ tokio:: select! {
123+ _ = shutdown. notified( ) => {
124+ break ;
125+ }
126+ result = udpsocket. recv_from( & mut buf) => {
127+ let ( n, addr) = result. unwrap( ) ;
128+ udpsocket. send_to( & buf[ ..n] , addr) . await . unwrap( ) ;
129+ }
130+ }
131+ }
132+ }
133+
134+
135+ #[ tokio:: test]
83136async fn test_udp_tunnel ( ) {
84137 init_logger ( ) ;
85- let addr = "127.0.0.1:9000" ;
86- let server = tokio:: spawn ( async move {
87- let config = Config :: create_default ( ) ;
88- let mut psqserver = PsqServer :: start ( addr, & config) . await . unwrap ( ) ;
89- psqserver. add_endpoint (
90- "udp" ,
91- UdpEndpoint :: new ( ) . unwrap ( )
92- ) . await ;
93- loop {
94- psqserver. process ( ) . await . unwrap ( ) ;
95- }
96- } ) ;
138+ let addr = "127.0.0.1:7000" ;
139+ let server_notify = Arc :: new ( Notify :: new ( ) ) ;
140+ let server = tokio:: spawn ( run_server ( addr, server_notify. clone ( ) ) ) ;
97141
98142 tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
99143
100144 // Run client
101- let mut psqconn = PsqClient :: connect (
145+ let mut psqclient = PsqClient :: connect (
102146 format ! ( "https://{}/" , addr) . as_str ( ) ,
103147 true ,
104148 ) . await . unwrap ( ) ;
105149
106150 // Test first with GET which should not be supported on UDP tunnel.
107151 let ret = FileStream :: get (
108- & mut psqconn ,
152+ & mut psqclient ,
109153 "udp" ,
110154 "testout" ,
111155 ) . await ;
112156 assert ! ( matches!( ret, Err ( PsqError :: HttpResponse ( 405 , _) ) ) ) ;
113157
114158 let udptunnel = UdpTunnel :: connect (
115- & mut psqconn ,
159+ & mut psqclient ,
116160 "udp" ,
117161 "127.0.0.1" ,
118- 9000 ,
162+ 9002 ,
119163 "127.0.0.1:0" . parse ( ) . unwrap ( ) ,
120164 ) . await . unwrap ( ) ;
121165 let tunneladdr = udptunnel. sockaddr ( ) . unwrap ( ) ;
122166
123- let client1 = tokio:: spawn ( async move {
124- loop {
125- psqconn. process ( ) . await . unwrap ( ) ;
126- }
127- } ) ;
167+ let client_notify = Arc :: new ( Notify :: new ( ) ) ;
168+ let client = tokio:: spawn ( run_client ( psqclient, client_notify. clone ( ) ) ) ;
128169
129170 // Start UDP server
130- let udpsocket = UdpSocket :: bind ( "127.0.0.1:9001 " ) . await . unwrap ( ) ;
171+ let udpsocket = UdpSocket :: bind ( "127.0.0.1:9002 " ) . await . unwrap ( ) ;
131172
132- let udpserver = tokio:: spawn ( async move {
133- loop {
134- let mut buf = [ 0u8 ; 2000 ] ;
135- let ( n, addr) = udpsocket. recv_from ( & mut buf) . await . unwrap ( ) ;
136- udpsocket. send_to ( & buf[ ..n] , addr) . await . unwrap ( ) ;
137- }
138- } ) ;
173+ let udpserver_notify = Arc :: new ( Notify :: new ( ) ) ;
174+ let udpserver = tokio:: spawn ( run_udpserver ( udpsocket, udpserver_notify. clone ( ) ) ) ;
139175
140176 tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
141177
142178 // Send UDP datagram to the client socket
143- let udpclient = UdpSocket :: bind ( "0.0.0.0:0" ) . await . unwrap ( ) ;
144- let mut buf = [ 0u8 ; 2000 ] ;
145- udpclient. send_to ( b"Testing" , tunneladdr) . await . unwrap ( ) ;
146- let ( n, _) = udpclient. recv_from ( & mut buf) . await . unwrap ( ) ;
147- assert_eq ! ( & buf[ ..n] , b"Testing" ) ;
148-
149- udpserver. abort ( ) ;
150- client1. abort ( ) ;
151- server. abort ( ) ;
179+ let result = timeout ( Duration :: from_millis ( 1000 ) , async {
180+ let udpclient = UdpSocket :: bind ( "0.0.0.0:0" ) . await . unwrap ( ) ;
181+ let mut buf = [ 0u8 ; 2000 ] ;
182+ udpclient. send_to ( b"Testing" , tunneladdr) . await . unwrap ( ) ;
183+ let ( n, _) = udpclient. recv_from ( & mut buf) . await . unwrap ( ) ;
184+ assert_eq ! ( & buf[ ..n] , b"Testing" ) ;
185+ } ) . await ;
186+ assert ! ( result. is_ok( ) , "Test timed out" ) ;
187+
188+ udpserver_notify. notify_one ( ) ;
189+ udpserver. await . unwrap ( ) ;
190+
191+ client_notify. notify_one ( ) ;
192+ client. await . unwrap ( ) ;
193+
194+ server_notify. notify_one ( ) ;
195+ server. await . unwrap ( ) ;
152196}
153197
154198
155- #[ tokio:: test( flavor = "multi_thread" , worker_threads = 3 ) ]
199+ #[ tokio:: test]
156200async fn tunnel_closing ( ) {
157201 init_logger ( ) ;
158- let addr = "127.0.0.1:9002 " ;
202+ let addr = "127.0.0.1:9003 " ;
159203 let server = tokio:: spawn ( async move {
160204 let config = Config :: create_default ( ) ;
161205 let mut psqserver = PsqServer :: start ( addr, & config) . await . unwrap ( ) ;
@@ -182,7 +226,7 @@ async fn tunnel_closing() {
182226 & mut psqconn,
183227 "udp" ,
184228 "127.0.0.1" ,
185- 9000 ,
229+ 9004 ,
186230 "127.0.0.1:0" . parse ( ) . unwrap ( ) ,
187231 ) . await . unwrap ( ) ;
188232 let tunneladdr = udptunnel. sockaddr ( ) . unwrap ( ) ;
0 commit comments