@@ -1112,8 +1112,8 @@ def _get_topology(self):
11121112 return self ._topology
11131113
11141114 @contextlib .contextmanager
1115- def _get_socket (self , server ):
1116- with self ._reset_on_error (server .description .address ):
1115+ def _get_socket (self , server , session ):
1116+ with self ._reset_on_error (server .description .address , session ):
11171117 with server .get_socket (self .__all_credentials ) as sock_info :
11181118 yield sock_info
11191119
@@ -1128,26 +1128,31 @@ def _select_server(self, server_selector, session, address=None):
11281128 - `address` (optional): Address when sending a message
11291129 to a specific server, used for getMore.
11301130 """
1131- topology = self ._get_topology ()
1132- address = address or (session and session ._pinned_address )
1133- if address :
1134- # We're running a getMore or this session is pinned to a mongos.
1135- server = topology .select_server_by_address (address )
1136- if not server :
1137- raise AutoReconnect ('server %s:%d no longer available'
1138- % address )
1139- else :
1140- server = topology .select_server (server_selector )
1141- # Pin this session to the selected server if it's performing a
1142- # sharded transaction.
1143- if server .description .mongos and (session and
1144- session ._in_transaction ):
1145- session ._pin_mongos (server )
1146- return server
1131+ try :
1132+ topology = self ._get_topology ()
1133+ address = address or (session and session ._pinned_address )
1134+ if address :
1135+ # We're running a getMore or this session is pinned to a mongos.
1136+ server = topology .select_server_by_address (address )
1137+ if not server :
1138+ raise AutoReconnect ('server %s:%d no longer available'
1139+ % address )
1140+ else :
1141+ server = topology .select_server (server_selector )
1142+ # Pin this session to the selected server if it's performing a
1143+ # sharded transaction.
1144+ if server .description .mongos and (session and
1145+ session ._in_transaction ):
1146+ session ._pin_mongos (server )
1147+ return server
1148+ except PyMongoError as exc :
1149+ if session and exc .has_error_label ("TransientTransactionError" ):
1150+ session ._unpin_mongos ()
1151+ raise
11471152
11481153 def _socket_for_writes (self , session ):
11491154 server = self ._select_server (writable_server_selector , session )
1150- return self ._get_socket (server )
1155+ return self ._get_socket (server , session )
11511156
11521157 @contextlib .contextmanager
11531158 def _socket_for_reads (self , read_preference , session ):
@@ -1162,7 +1167,7 @@ def _socket_for_reads(self, read_preference, session):
11621167 single = topology .description .topology_type == TOPOLOGY_TYPE .Single
11631168 server = self ._select_server (read_preference , session )
11641169
1165- with self ._get_socket (server ) as sock_info :
1170+ with self ._get_socket (server , session ) as sock_info :
11661171 slave_ok = (single and not sock_info .is_mongos ) or (
11671172 read_preference != ReadPreference .PRIMARY )
11681173 yield sock_info , slave_ok
@@ -1194,7 +1199,8 @@ def _send_message_with_response(self, operation, exhaust=False,
11941199 and server .description .server_type != SERVER_TYPE .Mongos ) or (
11951200 operation .read_preference != ReadPreference .PRIMARY )
11961201
1197- with self ._reset_on_error (server .description .address ):
1202+ with self ._reset_on_error (server .description .address ,
1203+ operation .session ):
11981204 return server .send_message_with_response (
11991205 operation ,
12001206 set_slave_ok ,
@@ -1203,14 +1209,20 @@ def _send_message_with_response(self, operation, exhaust=False,
12031209 exhaust )
12041210
12051211 @contextlib .contextmanager
1206- def _reset_on_error (self , server_address ):
1212+ def _reset_on_error (self , server_address , session ):
12071213 """On "not master" or "node is recovering" errors reset the server
12081214 according to the SDAM spec.
12091215
12101216 Unpin the session on transient transaction errors.
12111217 """
12121218 try :
1213- yield
1219+ try :
1220+ yield
1221+ except PyMongoError as exc :
1222+ if session and exc .has_error_label (
1223+ "TransientTransactionError" ):
1224+ session ._unpin_mongos ()
1225+ raise
12141226 except NetworkTimeout :
12151227 # The socket has been closed. Don't reset the server.
12161228 # Server Discovery And Monitoring Spec: "When an application
@@ -1264,7 +1276,7 @@ def is_retrying():
12641276 supports_session = (
12651277 session is not None and
12661278 server .description .retryable_writes_supported )
1267- with self ._get_socket (server ) as sock_info :
1279+ with self ._get_socket (server , session ) as sock_info :
12681280 if retryable and not supports_session :
12691281 if is_retrying ():
12701282 # A retry is not possible because this server does
@@ -1674,12 +1686,10 @@ def _send_cluster_time(self, command, session):
16741686 if cluster_time :
16751687 command ['$clusterTime' ] = cluster_time
16761688
1677- def _receive_cluster_time (self , reply , session ):
1678- cluster_time = reply .get ('$clusterTime' )
1679- self ._topology .receive_cluster_time (cluster_time )
1689+ def _process_response (self , reply , session ):
1690+ self ._topology .receive_cluster_time (reply .get ('$clusterTime' ))
16801691 if session is not None :
1681- session ._advance_cluster_time (cluster_time )
1682- session ._advance_operation_time (reply .get ("operationTime" ))
1692+ session ._process_response (reply )
16831693
16841694 def server_info (self , session = None ):
16851695 """Get information about the MongoDB server we're connected to.
0 commit comments