@@ -1000,7 +1000,8 @@ def _end_sessions(self, session_ids):
10001000 # Use SocketInfo.command directly to avoid implicitly creating
10011001 # another session.
10021002 with self ._socket_for_reads (
1003- ReadPreference .PRIMARY_PREFERRED ) as (sock_info , slave_ok ):
1003+ ReadPreference .PRIMARY_PREFERRED ,
1004+ None ) as (sock_info , slave_ok ):
10041005 if not sock_info .supports_sessions :
10051006 return
10061007
@@ -1105,12 +1106,40 @@ def _get_socket(self, server):
11051106 self .__reset_server (server .description .address )
11061107 raise
11071108
1108- def _socket_for_writes (self ):
1109- server = self ._get_topology ().select_server (writable_server_selector )
1109+ def _select_server (self , server_selector , session , address = None ):
1110+ """Select a server to run an operation on this client.
1111+
1112+ :Parameters:
1113+ - `server_selector`: The server selector to use if the session is
1114+ not pinned and no address is given.
1115+ - `session`: The ClientSession for the next operation, or None. May
1116+ be pinned to a mongos server address.
1117+ - `address` (optional): Address when sending a message
1118+ to a specific server, used for getMore.
1119+ """
1120+ topology = self ._get_topology ()
1121+ address = address or (session and session ._pinned_address )
1122+ if address :
1123+ # We're running a getMore or this session is pinned to a mongos.
1124+ server = topology .select_server_by_address (address )
1125+ if not server :
1126+ raise AutoReconnect ('server %s:%d no longer available'
1127+ % address )
1128+ else :
1129+ server = topology .select_server (server_selector )
1130+ # Pin this session to the selected server if it's performing a
1131+ # sharded transaction.
1132+ if server .description .mongos and (session and
1133+ session ._in_transaction ):
1134+ session ._pin_mongos (server )
1135+ return server
1136+
1137+ def _socket_for_writes (self , session ):
1138+ server = self ._select_server (writable_server_selector , session )
11101139 return self ._get_socket (server )
11111140
11121141 @contextlib .contextmanager
1113- def _socket_for_reads (self , read_preference ):
1142+ def _socket_for_reads (self , read_preference , session ):
11141143 assert read_preference is not None , "read_preference must not be None"
11151144 # Get a socket for a server matching the read preference, and yield
11161145 # sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to
@@ -1120,7 +1149,7 @@ def _socket_for_reads(self, read_preference):
11201149 # Thread safe: if the type is single it cannot change.
11211150 topology = self ._get_topology ()
11221151 single = topology .description .topology_type == TOPOLOGY_TYPE .Single
1123- server = topology . select_server (read_preference )
1152+ server = self . _select_server (read_preference , session )
11241153
11251154 with self ._get_socket (server ) as sock_info :
11261155 slave_ok = (single and not sock_info .is_mongos ) or (
@@ -1139,14 +1168,9 @@ def _send_message_with_response(self, operation, exhaust=False,
11391168 - `address` (optional): Optional address when sending a message
11401169 to a specific server, used for getMore.
11411170 """
1171+ server = self ._select_server (
1172+ operation .read_preference , operation .session , address = address )
11421173 topology = self ._get_topology ()
1143- if address :
1144- server = topology .select_server_by_address (address )
1145- if not server :
1146- raise AutoReconnect ('server %s:%d no longer available'
1147- % address )
1148- else :
1149- server = topology .select_server (operation .read_preference )
11501174
11511175 # If this is a direct connection to a mongod, *always* set the slaveOk
11521176 # bit. See bullet point 2 in server-selection.rst#topology-type-single.
@@ -1206,8 +1230,7 @@ def is_retrying():
12061230
12071231 while True :
12081232 try :
1209- server = self ._get_topology ().select_server (
1210- writable_server_selector )
1233+ server = self ._select_server (writable_server_selector , session )
12111234 supports_session = (
12121235 session is not None and
12131236 server .description .retryable_writes_supported )
@@ -1736,7 +1759,7 @@ def drop_database(self, name_or_database, session=None):
17361759 "of %s or a Database" % (string_type .__name__ ,))
17371760
17381761 self ._purge_index (name )
1739- with self ._socket_for_writes () as sock_info :
1762+ with self ._socket_for_writes (session ) as sock_info :
17401763 self [name ]._command (
17411764 sock_info ,
17421765 "dropDatabase" ,
@@ -1877,7 +1900,7 @@ def unlock(self, session=None):
18771900 Added ``session`` parameter.
18781901 """
18791902 cmd = SON ([("fsyncUnlock" , 1 )])
1880- with self ._socket_for_writes () as sock_info :
1903+ with self ._socket_for_writes (session ) as sock_info :
18811904 if sock_info .max_wire_version >= 4 :
18821905 try :
18831906 with self ._tmp_session (session ) as s :
0 commit comments