1- use super :: socket:: { self , PySocketRef } ;
1+ use super :: socket:: { self , PySocket } ;
22use crate :: common:: lock:: { PyRwLock , PyRwLockWriteGuard } ;
33use crate :: {
44 builtins:: { pytype, weakref:: PyWeak , PyStrRef , PyTypeRef } ,
@@ -592,7 +592,7 @@ impl PySslContext {
592592 }
593593 }
594594
595- let stream = ssl:: SslStream :: new ( ssl, args. sock . clone ( ) )
595+ let stream = ssl:: SslStream :: new ( ssl, SocketStream ( args. sock . clone ( ) ) )
596596 . map_err ( |e| convert_openssl_error ( vm, e) ) ?;
597597
598598 // TODO: use this
@@ -611,7 +611,7 @@ impl PySslContext {
611611#[ derive( FromArgs ) ]
612612struct WrapSocketArgs {
613613 #[ pyarg( any) ]
614- sock : PySocketRef ,
614+ sock : PyRef < PySocket > ,
615615 #[ pyarg( any) ]
616616 server_side : bool ,
617617 #[ pyarg( any, default ) ]
@@ -647,8 +647,8 @@ struct SocketTimeout {
647647 deadline : Result < Instant , bool > ,
648648}
649649impl SocketTimeout {
650- fn get ( s : & socket :: PySocket ) -> Self {
651- let deadline = s. get_timeout ( ) . map ( |d| Instant :: now ( ) + d) ;
650+ fn get ( s : & SocketStream ) -> Self {
651+ let deadline = s. 0 . get_timeout ( ) . map ( |d| Instant :: now ( ) + d) ;
652652 Self { deadline }
653653 }
654654}
@@ -659,8 +659,8 @@ enum SelectRet {
659659 Closed ,
660660 Ok ,
661661}
662- fn ssl_select ( sock : & socket :: PySocket , needs : SslNeeds , timeout : & SocketTimeout ) -> SelectRet {
663- let sock = match sock. sock_opt ( ) {
662+ fn ssl_select ( sock : & SocketStream , needs : SslNeeds , timeout : & SocketTimeout ) -> SelectRet {
663+ let sock = match sock. 0 . sock_opt ( ) {
664664 Some ( s) => s,
665665 None => return SelectRet :: Closed ,
666666 } ;
@@ -693,7 +693,7 @@ enum SslNeeds {
693693
694694fn socket_needs (
695695 err : & ssl:: Error ,
696- sock : & socket :: PySocket ,
696+ sock : & SocketStream ,
697697 timeout : & SocketTimeout ,
698698) -> ( Option < SslNeeds > , SelectRet ) {
699699 let needs = match err. code ( ) {
@@ -716,7 +716,7 @@ fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
716716#[ derive( PyValue ) ]
717717struct PySslSocket {
718718 ctx : PyRef < PySslContext > ,
719- stream : PyRwLock < ssl:: SslStream < PySocketRef > > ,
719+ stream : PyRwLock < ssl:: SslStream < SocketStream > > ,
720720 socket_type : SslServerOrClient ,
721721 server_hostname : Option < PyStrRef > ,
722722 owner : PyRwLock < Option < PyWeak > > ,
@@ -819,7 +819,7 @@ impl PySslSocket {
819819 Ok ( ( ) ) => return Ok ( ( ) ) ,
820820 Err ( e) => e,
821821 } ;
822- let ( needs, state) = socket_needs ( & err, stream. get_ref ( ) , & timeout) ;
822+ let ( needs, state) = socket_needs ( & err, & stream. get_ref ( ) , & timeout) ;
823823 match state {
824824 SelectRet :: TimedOut => {
825825 return Err ( socket:: timeout_error_msg (
@@ -1237,3 +1237,19 @@ fn extend_module_platform_specific(module: &PyObjectRef, vm: &VirtualMachine) {
12371237
12381238#[ cfg( not( windows) ) ]
12391239fn extend_module_platform_specific ( _module : & PyObjectRef , _vm : & VirtualMachine ) { }
1240+
1241+ struct SocketStream ( PyRef < PySocket > ) ;
1242+
1243+ impl std:: io:: Read for SocketStream {
1244+ fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
1245+ <& socket2:: Socket as std:: io:: Read >:: read ( & mut & * self . 0 . sock_io ( ) ?, buf)
1246+ }
1247+ }
1248+ impl std:: io:: Write for SocketStream {
1249+ fn write ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < usize > {
1250+ <& socket2:: Socket as std:: io:: Write >:: write ( & mut & * self . 0 . sock_io ( ) ?, buf)
1251+ }
1252+ fn flush ( & mut self ) -> std:: io:: Result < ( ) > {
1253+ <& socket2:: Socket as std:: io:: Write >:: flush ( & mut & * self . 0 . sock_io ( ) ?)
1254+ }
1255+ }
0 commit comments