@@ -2,9 +2,10 @@ use std::io::Cursor;
22use std:: net:: { Ipv4Addr , Ipv6Addr , SocketAddr } ;
33use std:: sync:: Arc ;
44use std:: time:: SystemTime ;
5- use log:: { debug, info} ;
5+ use log:: { debug, info, warn } ;
66use socket2:: { Socket , Domain , Type , Protocol } ;
77use tokio:: net:: UdpSocket ;
8+ use tokio:: sync:: mpsc;
89use crate :: stats:: enums:: stats_event:: StatsEvent ;
910use crate :: tracker:: enums:: torrent_peers_type:: TorrentPeersType ;
1011use crate :: tracker:: structs:: announce_query_request:: AnnounceQueryRequest ;
@@ -24,6 +25,7 @@ use crate::udp::structs::connection_id::ConnectionId;
2425use crate :: udp:: structs:: error_response:: ErrorResponse ;
2526use crate :: udp:: structs:: number_of_downloads:: NumberOfDownloads ;
2627use crate :: udp:: structs:: number_of_peers:: NumberOfPeers ;
28+ use crate :: udp:: structs:: packet_job:: PacketJob ;
2729use crate :: udp:: structs:: port:: Port ;
2830use crate :: udp:: structs:: response_peer:: ResponsePeer ;
2931use crate :: udp:: structs:: scrape_request:: ScrapeRequest ;
@@ -35,7 +37,16 @@ use crate::udp::udp::MAX_SCRAPE_TORRENTS;
3537
3638impl UdpServer {
3739 #[ tracing:: instrument( level = "debug" ) ]
38- pub async fn new ( tracker : Arc < TorrentTracker > , bind_address : SocketAddr , threads : u64 , recv_buffer_size : usize , send_buffer_size : usize , reuse_address : bool ) -> tokio:: io:: Result < UdpServer >
40+ pub async fn new (
41+ tracker : Arc < TorrentTracker > ,
42+ bind_address : SocketAddr ,
43+ recv_buffer_size : usize ,
44+ send_buffer_size : usize ,
45+ reuse_address : bool ,
46+ receiver_threads : usize ,
47+ worker_threads : usize ,
48+ queue_size : usize
49+ ) -> tokio:: io:: Result < UdpServer >
3950 {
4051 let domain = if bind_address. is_ipv4 ( ) { Domain :: IPV4 } else { Domain :: IPV6 } ;
4152 let socket = Socket :: new ( domain, Type :: DGRAM , Some ( Protocol :: UDP ) ) ?;
@@ -52,58 +63,126 @@ impl UdpServer {
5263
5364 Ok ( UdpServer {
5465 socket : Arc :: new ( tokio_socket) ,
55- threads,
5666 tracker,
67+ receiver_threads : receiver_threads as u64 ,
68+ worker_threads : worker_threads as u64 ,
69+ queue_size : queue_size as u64
5770 } )
5871 }
5972
6073 #[ tracing:: instrument( level = "debug" ) ]
61- pub async fn start ( & self , rx : tokio:: sync:: watch:: Receiver < bool > )
62- {
63- let threads = self . threads ;
64- for _index in 0 ..=threads {
74+ pub async fn start ( & self , rx : tokio:: sync:: watch:: Receiver < bool > ) {
75+ let ( packet_tx, packet_rx) = mpsc:: channel :: < PacketJob > ( self . queue_size as usize ) ;
76+ let packet_rx = Arc :: new ( tokio:: sync:: Mutex :: new ( packet_rx) ) ;
77+
78+ let receiver_threads = self . receiver_threads as usize ;
79+ let worker_threads = self . worker_threads as usize ;
80+
81+ for thread_id in 0 ..receiver_threads {
6582 let socket_clone = self . socket . clone ( ) ;
6683 let tracker = self . tracker . clone ( ) ;
67- let mut rx = rx. clone ( ) ;
68- let mut data = [ 0 ; 1496 ] ;
84+ let mut shutdown_rx = rx. clone ( ) ;
85+ let packet_tx = packet_tx. clone ( ) ;
86+
6987 tokio:: spawn ( async move {
88+ info ! ( "Starting UDP receiver thread {}" , thread_id) ;
89+ let mut data = [ 0 ; 1496 ] ;
90+
7091 loop {
71- let udp_sock = socket_clone. local_addr ( ) . unwrap ( ) ;
7292 tokio:: select! {
73- _ = rx . changed( ) => {
74- info!( "Stopping UDP server: {udp_sock }..." ) ;
93+ _ = shutdown_rx . changed( ) => {
94+ info!( "Stopping UDP receiver thread { }..." , thread_id ) ;
7595 break ;
7696 }
7797 Ok ( ( valid_bytes, remote_addr) ) = socket_clone. recv_from( & mut data) => {
78- let payload = & data[ ..valid_bytes] ;
79-
80- debug!( "Received {} bytes from {}" , payload. len( ) , remote_addr) ;
81- debug!( "{payload:?}" ) ;
82-
83- let tracker_cloned = tracker. clone( ) ;
84- let socket_cloned = socket_clone. clone( ) ;
85- // Use payload slice instead of cloning the entire Vec
86- let payload_vec = payload. to_vec( ) ;
87- tokio:: spawn( async move {
88- let response = UdpServer :: handle_packet( remote_addr, payload_vec, tracker_cloned. clone( ) ) . await ;
89- UdpServer :: send_response( tracker_cloned. clone( ) , socket_cloned. clone( ) , remote_addr, response) . await ;
90- } ) ;
98+ let payload = data[ ..valid_bytes] . to_vec( ) ;
99+
100+ debug!( "Receiver {} got {} bytes from {}" , thread_id, payload. len( ) , remote_addr) ;
101+
102+ let job = PacketJob {
103+ data: payload,
104+ remote_addr,
105+ } ;
106+
107+ if let Err ( e) = packet_tx. try_send( job) {
108+ warn!( "Packet queue full, dropping packet: {}" , e) ;
109+ match remote_addr {
110+ SocketAddr :: V4 ( _) => tracker. update_stats( StatsEvent :: Udp4BadRequest , 1 ) ,
111+ SocketAddr :: V6 ( _) => tracker. update_stats( StatsEvent :: Udp6BadRequest , 1 ) ,
112+ } ;
113+ }
114+ }
115+ }
116+ }
117+ } ) ;
118+ }
119+
120+ for thread_id in 0 ..worker_threads {
121+ let socket_clone = self . socket . clone ( ) ;
122+ let tracker = self . tracker . clone ( ) ;
123+ let mut shutdown_rx = rx. clone ( ) ;
124+ let packet_rx = packet_rx. clone ( ) ;
125+
126+ tokio:: spawn ( async move {
127+ info ! ( "Starting UDP worker thread {}" , thread_id) ;
128+
129+ loop {
130+ tokio:: select! {
131+ _ = shutdown_rx. changed( ) => {
132+ info!( "Stopping UDP worker thread {}..." , thread_id) ;
133+ break ;
134+ }
135+ job = async {
136+ let mut rx = packet_rx. lock( ) . await ;
137+ rx. recv( ) . await
138+ } => {
139+ if let Some ( PacketJob { data, remote_addr } ) = job {
140+ debug!( "Worker {} processing packet from {}" , thread_id, remote_addr) ;
141+
142+ let response = UdpServer :: handle_packet(
143+ remote_addr,
144+ data,
145+ tracker. clone( )
146+ ) . await ;
147+
148+ UdpServer :: send_response(
149+ tracker. clone( ) ,
150+ socket_clone. clone( ) ,
151+ remote_addr,
152+ response
153+ ) . await ;
154+ }
91155 }
92156 }
93157 }
94158 } ) ;
95159 }
160+
161+ drop ( packet_tx) ;
96162 }
97163
164+ // Optimized send_response with pre-sized buffer
98165 #[ tracing:: instrument( level = "debug" ) ]
99- pub async fn send_response ( tracker : Arc < TorrentTracker > , socket : Arc < UdpSocket > , remote_addr : SocketAddr , response : Response )
100- {
166+ pub async fn send_response (
167+ tracker : Arc < TorrentTracker > ,
168+ socket : Arc < UdpSocket > ,
169+ remote_addr : SocketAddr ,
170+ response : Response
171+ ) {
101172 debug ! ( "sending response to: {:?}" , & remote_addr) ;
102173 let sentry = sentry:: TransactionContext :: new ( "udp server" , "send response" ) ;
103174 let transaction = sentry:: start_transaction ( sentry) ;
104175
105- // Pre-allocate buffer with exact capacity instead of MAX_PACKET_SIZE
106- let mut buffer = Vec :: with_capacity ( 512 ) ; // Most responses are much smaller than MAX_PACKET_SIZE
176+ // Optimize buffer allocation based on response type
177+ let estimated_size = match & response {
178+ Response :: Connect ( _) => 16 ,
179+ Response :: AnnounceIpv4 ( _) => 20 + 6 * 72 , // header + max IPv4 peers (6 bytes each)
180+ Response :: AnnounceIpv6 ( _) => 20 + 18 * 72 , // header + max IPv6 peers (18 bytes each)
181+ Response :: Scrape ( _) => 8 + 12 * 74 , // header + max torrents
182+ Response :: Error ( _) => 128 , // reasonable max for error message
183+ } ;
184+
185+ let mut buffer = Vec :: with_capacity ( estimated_size) ;
107186 let mut cursor = Cursor :: new ( & mut buffer) ;
108187
109188 match response. write ( & mut cursor) {
@@ -125,6 +204,7 @@ impl UdpServer {
125204 transaction. finish ( ) ;
126205 }
127206
207+ // Rest of the methods remain the same...
128208 #[ tracing:: instrument( level = "debug" ) ]
129209 pub async fn send_packet ( socket : Arc < UdpSocket > , remote_addr : & SocketAddr , payload : & [ u8 ] ) {
130210 let _ = socket. send_to ( payload, remote_addr) . await ;
@@ -195,7 +275,7 @@ impl UdpServer {
195275 #[ tracing:: instrument( level = "debug" ) ]
196276 pub async fn handle_udp_connect ( remote_addr : SocketAddr , request : & ConnectRequest , tracker : Arc < TorrentTracker > ) -> Result < Response , ServerError > {
197277 let connection_id = UdpServer :: get_connection_id ( & remote_addr) . await ;
198- let response = Response :: from ( ConnectResponse {
278+ let response = Response :: Connect ( ConnectResponse {
199279 transaction_id : request. transaction_id ,
200280 connection_id
201281 } ) ;
@@ -353,15 +433,15 @@ impl UdpServer {
353433
354434 // Create response
355435 let response = if remote_addr. is_ipv6 ( ) {
356- Response :: from ( AnnounceResponse {
436+ Response :: AnnounceIpv6 ( AnnounceResponse {
357437 transaction_id : request. transaction_id ,
358438 announce_interval : AnnounceInterval ( config. request_interval as i32 ) ,
359439 leechers : NumberOfPeers ( torrent. peers . len ( ) as i32 ) ,
360440 seeders : NumberOfPeers ( torrent. seeds . len ( ) as i32 ) ,
361441 peers : peers6,
362442 } )
363443 } else {
364- Response :: from ( AnnounceResponse {
444+ Response :: AnnounceIpv4 ( AnnounceResponse {
365445 transaction_id : request. transaction_id ,
366446 announce_interval : AnnounceInterval ( config. request_interval as i32 ) ,
367447 leechers : NumberOfPeers ( torrent. peers . len ( ) as i32 ) ,
@@ -408,15 +488,15 @@ impl UdpServer {
408488 } ;
409489 tracker. update_stats ( stats_event, 1 ) ;
410490
411- Ok ( Response :: from ( ScrapeResponse {
491+ Ok ( Response :: Scrape ( ScrapeResponse {
412492 transaction_id : request. transaction_id ,
413493 torrent_stats,
414494 } ) )
415495 }
416496
417497 #[ tracing:: instrument( level = "debug" ) ]
418498 pub async fn handle_udp_error ( e : ServerError , transaction_id : TransactionId ) -> Response {
419- Response :: from ( ErrorResponse {
499+ Response :: Error ( ErrorResponse {
420500 transaction_id,
421501 message : e. to_string ( ) . into ( )
422502 } )
0 commit comments