Skip to content

Commit 656aa1e

Browse files
ShaneHarveyajdavis
authored andcommitted
Pin transactions to a single server address
1 parent 116d2c2 commit 656aa1e

File tree

6 files changed

+72
-42
lines changed

6 files changed

+72
-42
lines changed

pymongo/bulk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def execute(self, write_concern, session):
427427

428428
client = self.collection.database.client
429429
if not write_concern.acknowledged:
430-
with client._socket_for_writes() as sock_info:
430+
with client._socket_for_writes(session) as sock_info:
431431
self.execute_no_results(sock_info, generator)
432432
else:
433433
return self.execute_command(generator, write_concern, session)

pymongo/client_session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def __init__(self, client, server_session, options, authset):
118118
self._cluster_time = None
119119
self._operation_time = None
120120
self._current_txn_read_pref = None
121+
self._current_txn_address = None
121122
if self.options.auto_start_transaction:
122123
# TODO: Get transaction options from self.options.
123124
self._current_transaction_opts = TransactionOptions()
@@ -240,6 +241,8 @@ def _finish_transaction(self, command_name):
240241
finally:
241242
self._server_session.reset_transaction()
242243
self._current_transaction_opts = None
244+
self._current_txn_address = None
245+
self._current_txn_read_pref = None
243246

244247
def _advance_cluster_time(self, cluster_time):
245248
"""Internal cluster time helper."""
@@ -295,6 +298,10 @@ def in_transaction(self):
295298
"""True if this session has an active multi-statement transaction."""
296299
return self._current_transaction_opts is not None
297300

301+
def _pin_server_address(self, address):
302+
assert self._current_txn_address is None, "Transaction already pinned"
303+
self._current_txn_address = address
304+
298305
def _apply_to(self, command, is_retryable, read_preference):
299306
self._check_ended()
300307

pymongo/collection.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,16 @@ def __init__(self, database, name, create=False, codec_options=None,
184184
unicode_decode_error_handler='replace',
185185
document_class=dict)
186186

187-
def _socket_for_reads(self):
188-
return self.__database.client._socket_for_reads(self.read_preference)
187+
def _socket_for_reads(self, session):
188+
return self.__database.client._socket_for_reads(
189+
self.read_preference, session)
189190

190-
def _socket_for_primary_reads(self):
191-
return self.__database.client._socket_for_reads(ReadPreference.PRIMARY)
191+
def _socket_for_primary_reads(self, session):
192+
return self.__database.client._socket_for_reads(
193+
ReadPreference.PRIMARY, session)
192194

193-
def _socket_for_writes(self):
194-
return self.__database.client._socket_for_writes()
195+
def _socket_for_writes(self, session):
196+
return self.__database.client._socket_for_writes(session)
195197

196198
def _command(self, sock_info, command, slave_ok=False,
197199
read_preference=None,
@@ -252,7 +254,7 @@ def __create(self, options, collation, session):
252254
if "size" in options:
253255
options["size"] = float(options["size"])
254256
cmd.update(options)
255-
with self._socket_for_writes() as sock_info:
257+
with self._socket_for_writes(session) as sock_info:
256258
self._command(
257259
sock_info, cmd, read_preference=ReadPreference.PRIMARY,
258260
write_concern=self.write_concern,
@@ -579,7 +581,7 @@ def _insert_command(session, sock_info, retryable_write):
579581
True, _insert_command, session)
580582
_check_write_command_response(result)
581583
else:
582-
with self._socket_for_writes() as sock_info:
584+
with self._socket_for_writes(session=None) as sock_info:
583585
# Legacy OP_INSERT.
584586
self._legacy_write(
585587
sock_info, 'insert', command, op_id,
@@ -1493,7 +1495,7 @@ def parallel_scan(self, num_cursors, session=None, **kwargs):
14931495
('numCursors', num_cursors)])
14941496
cmd.update(kwargs)
14951497

1496-
with self._socket_for_reads() as (sock_info, slave_ok):
1498+
with self._socket_for_reads(session) as (sock_info, slave_ok):
14971499
result = self._command(sock_info, cmd, slave_ok,
14981500
read_concern=self.read_concern,
14991501
session=session)
@@ -1509,7 +1511,7 @@ def parallel_scan(self, num_cursors, session=None, **kwargs):
15091511

15101512
def _count(self, cmd, collation=None, session=None):
15111513
"""Internal count helper."""
1512-
with self._socket_for_reads() as (sock_info, slave_ok):
1514+
with self._socket_for_reads(session) as (sock_info, slave_ok):
15131515
res = self._command(
15141516
sock_info, cmd, slave_ok,
15151517
allowable_errors=["ns missing"],
@@ -1606,7 +1608,7 @@ def create_indexes(self, indexes, session=None, **kwargs):
16061608
"""
16071609
common.validate_list('indexes', indexes)
16081610
names = []
1609-
with self._socket_for_writes() as sock_info:
1611+
with self._socket_for_writes(session) as sock_info:
16101612
supports_collations = sock_info.max_wire_version >= 5
16111613
def gen_indexes():
16121614
for index in indexes:
@@ -1647,7 +1649,7 @@ def __create_index(self, keys, index_options, session, **kwargs):
16471649
index_options.pop('collation', None))
16481650
index.update(index_options)
16491651

1650-
with self._socket_for_writes() as sock_info:
1652+
with self._socket_for_writes(session) as sock_info:
16511653
if collation is not None:
16521654
if sock_info.max_wire_version < 5:
16531655
raise ConfigurationError(
@@ -1874,7 +1876,7 @@ def drop_index(self, index_or_name, session=None, **kwargs):
18741876
self.__database.name, self.__name, name)
18751877
cmd = SON([("dropIndexes", self.__name), ("index", name)])
18761878
cmd.update(kwargs)
1877-
with self._socket_for_writes() as sock_info:
1879+
with self._socket_for_writes(session) as sock_info:
18781880
self._command(sock_info,
18791881
cmd,
18801882
read_preference=ReadPreference.PRIMARY,
@@ -1911,7 +1913,7 @@ def reindex(self, session=None, **kwargs):
19111913
"""
19121914
cmd = SON([("reIndex", self.__name)])
19131915
cmd.update(kwargs)
1914-
with self._socket_for_writes() as sock_info:
1916+
with self._socket_for_writes(session) as sock_info:
19151917
return self._command(
19161918
sock_info, cmd, read_preference=ReadPreference.PRIMARY,
19171919
parse_write_concern_error=True, session=session)
@@ -1940,7 +1942,7 @@ def list_indexes(self, session=None):
19401942
codec_options = CodecOptions(SON)
19411943
coll = self.with_options(codec_options=codec_options,
19421944
read_preference=ReadPreference.PRIMARY)
1943-
with self._socket_for_primary_reads() as (sock_info, slave_ok):
1945+
with self._socket_for_primary_reads(session) as (sock_info, slave_ok):
19441946
cmd = SON([("listIndexes", self.__name), ("cursor", {})])
19451947
if sock_info.max_wire_version > 2:
19461948
with self.__database.client._tmp_session(session, False) as s:
@@ -2061,7 +2063,7 @@ def _aggregate(self, pipeline, cursor_class, first_batch_size, session,
20612063
"batchSize", kwargs.pop("batchSize", None))
20622064
# If the server does not support the "cursor" option we
20632065
# ignore useCursor and batchSize.
2064-
with self._socket_for_reads() as (sock_info, slave_ok):
2066+
with self._socket_for_reads(session) as (sock_info, slave_ok):
20652067
dollar_out = pipeline and '$out' in pipeline[-1]
20662068
if use_cursor:
20672069
if "cursor" not in kwargs:
@@ -2350,7 +2352,7 @@ def group(self, key, condition, initial, reduce, finalize=None, **kwargs):
23502352
collation = validate_collation_or_none(kwargs.pop('collation', None))
23512353
cmd.update(kwargs)
23522354

2353-
with self._socket_for_reads() as (sock_info, slave_ok):
2355+
with self._socket_for_reads(session=None) as (sock_info, slave_ok):
23542356
return self._command(sock_info, cmd, slave_ok,
23552357
collation=collation)["retval"]
23562358

@@ -2396,7 +2398,7 @@ def rename(self, new_name, session=None, **kwargs):
23962398

23972399
new_name = "%s.%s" % (self.__database.name, new_name)
23982400
cmd = SON([("renameCollection", self.__full_name), ("to", new_name)])
2399-
with self._socket_for_writes() as sock_info:
2401+
with self._socket_for_writes(session) as sock_info:
24002402
with self.__database.client._tmp_session(session) as s:
24012403
if sock_info.max_wire_version >= 5 and self.write_concern:
24022404
cmd['writeConcern'] = self.write_concern.document
@@ -2451,7 +2453,7 @@ def distinct(self, key, filter=None, session=None, **kwargs):
24512453
kwargs["query"] = filter
24522454
collation = validate_collation_or_none(kwargs.pop('collation', None))
24532455
cmd.update(kwargs)
2454-
with self._socket_for_reads() as (sock_info, slave_ok):
2456+
with self._socket_for_reads(session) as (sock_info, slave_ok):
24552457
return self._command(sock_info, cmd, slave_ok,
24562458
read_concern=self.read_concern,
24572459
collation=collation, session=session)["values"]
@@ -2523,7 +2525,7 @@ def map_reduce(self, map, reduce, out, full_response=False, session=None,
25232525
cmd.update(kwargs)
25242526

25252527
inline = 'inline' in cmd['out']
2526-
with self._socket_for_primary_reads() as (sock_info, slave_ok):
2528+
with self._socket_for_primary_reads(session) as (sock_info, slave_ok):
25272529
if (sock_info.max_wire_version >= 5 and self.write_concern and
25282530
not inline):
25292531
cmd['writeConcern'] = self.write_concern.document
@@ -2592,7 +2594,7 @@ def inline_map_reduce(self, map, reduce, full_response=False, session=None,
25922594
("out", {"inline": 1})])
25932595
collation = validate_collation_or_none(kwargs.pop('collation', None))
25942596
cmd.update(kwargs)
2595-
with self._socket_for_reads() as (sock_info, slave_ok):
2597+
with self._socket_for_reads(session) as (sock_info, slave_ok):
25962598
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
25972599
res = self._command(sock_info, cmd, slave_ok,
25982600
read_concern=self.read_concern,

pymongo/database.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,8 @@ def command(self, command, value=1, check=True,
526526
.. mongodoc:: commands
527527
"""
528528
client = self.__client
529-
with client._socket_for_reads(read_preference) as (sock_info, slave_ok):
529+
with client._socket_for_reads(
530+
read_preference, session) as (sock_info, slave_ok):
530531
return self._command(sock_info, command, slave_ok, value,
531532
check, allowable_errors, read_preference,
532533
codec_options, session=session, **kwargs)
@@ -584,7 +585,7 @@ def list_collections(self, session=None, **kwargs):
584585
.. versionadded:: 3.6
585586
"""
586587
with self.__client._socket_for_reads(
587-
ReadPreference.PRIMARY) as (sock_info, slave_okay):
588+
ReadPreference.PRIMARY, session) as (sock_info, slave_okay):
588589
return self._list_collections(
589590
sock_info, slave_okay, session=session, **kwargs)
590591

@@ -649,7 +650,7 @@ def drop_collection(self, name_or_collection, session=None):
649650
self.__client._purge_index(self.__name, name)
650651

651652
with self.__client._socket_for_reads(
652-
ReadPreference.PRIMARY) as (sock_info, slave_ok):
653+
ReadPreference.PRIMARY, session) as (sock_info, slave_ok):
653654
return self._command(
654655
sock_info, 'drop', slave_ok, _unicode(name),
655656
allowable_errors=['ns not found'],
@@ -730,7 +731,7 @@ def current_op(self, include_all=False, session=None):
730731
Added ``session`` parameter.
731732
"""
732733
cmd = SON([("currentOp", 1), ("$all", include_all)])
733-
with self.__client._socket_for_writes() as sock_info:
734+
with self.__client._socket_for_writes(session) as sock_info:
734735
if sock_info.max_wire_version >= 4:
735736
with self.__client._tmp_session(session) as s:
736737
return sock_info.command("admin", cmd, session=s,

pymongo/mongo_client.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,8 @@ def _end_sessions(self, session_ids):
871871
# Use SocketInfo.command directly to avoid implicitly creating
872872
# another session.
873873
with self._socket_for_reads(
874-
ReadPreference.PRIMARY_PREFERRED) as (sock_info, slave_ok):
874+
ReadPreference.PRIMARY_PREFERRED,
875+
None) as (sock_info, slave_ok):
875876
if not sock_info.supports_sessions:
876877
return
877878

@@ -967,13 +968,31 @@ def _get_socket(self, server):
967968
self.__reset_server(server.description.address)
968969
raise
969970

970-
def _socket_for_writes(self):
971-
server = self._get_topology().select_server(writable_server_selector)
972-
return self._get_socket(server)
971+
def _select_server(self, read_preference, session):
972+
topology = self._get_topology()
973+
if session and session.in_transaction:
974+
if session._current_txn_address:
975+
server = topology.select_server_by_address(
976+
session._current_txn_address)
977+
if not server:
978+
raise AutoReconnect(
979+
'Pinned server %s:%d for transaction no longer'
980+
'available' % session._current_txn_address)
981+
return server
982+
983+
server = topology.select_server(read_preference)
984+
session._pin_server_address(server.description.address)
985+
return server
986+
else:
987+
return topology.select_server(read_preference)
988+
989+
def _socket_for_writes(self, session):
990+
return self._get_socket(self._select_server(
991+
ReadPreference.PRIMARY, session))
973992

974993
@contextlib.contextmanager
975-
def _socket_for_reads(self, read_preference):
976-
preference = read_preference or ReadPreference.PRIMARY
994+
def _socket_for_reads(self, read_preference, session):
995+
assert read_preference is not None, "read_preference must not be None"
977996
# Get a socket for a server matching the read preference, and yield
978997
# sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to
979998
# mongods with topology type Single. If the server type is Mongos,
@@ -982,10 +1001,11 @@ def _socket_for_reads(self, read_preference):
9821001
# Thread safe: if the type is single it cannot change.
9831002
topology = self._get_topology()
9841003
single = topology.description.topology_type == TOPOLOGY_TYPE.Single
985-
server = topology.select_server(read_preference)
1004+
server = self._select_server(read_preference, session)
1005+
9861006
with self._get_socket(server) as sock_info:
9871007
slave_ok = (single and not sock_info.is_mongos) or (
988-
preference != ReadPreference.PRIMARY)
1008+
read_preference != ReadPreference.PRIMARY)
9891009
yield sock_info, slave_ok
9901010

9911011
def _send_message_with_response(self, operation, read_preference=None,
@@ -1005,14 +1025,14 @@ def _send_message_with_response(self, operation, read_preference=None,
10051025
self._kill_cursors_executor.open()
10061026

10071027
topology = self._get_topology()
1028+
session = operation.session
10081029
if address:
10091030
server = topology.select_server_by_address(address)
10101031
if not server:
10111032
raise AutoReconnect('server %s:%d no longer available'
10121033
% address)
10131034
else:
1014-
selector = read_preference or writable_server_selector
1015-
server = topology.select_server(selector)
1035+
server = self._select_server(read_preference, session)
10161036

10171037
# A _Query's slaveOk bit is already set for queries with non-primary
10181038
# read preference. If this is a direct connection to a mongod, override
@@ -1064,8 +1084,7 @@ def is_retrying():
10641084
return bulk.retrying if bulk else retrying
10651085
while True:
10661086
try:
1067-
server = self._get_topology().select_server(
1068-
writable_server_selector)
1087+
server = self._select_server(ReadPreference.PRIMARY, session)
10691088
supports_session = (
10701089
session is not None and
10711090
server.description.retryable_writes_supported)
@@ -1539,7 +1558,7 @@ def drop_database(self, name_or_database, session=None):
15391558

15401559
self._purge_index(name)
15411560
with self._socket_for_reads(
1542-
ReadPreference.PRIMARY) as (sock_info, slave_ok):
1561+
ReadPreference.PRIMARY, None) as (sock_info, slave_ok):
15431562
self[name]._command(
15441563
sock_info,
15451564
"dropDatabase",
@@ -1681,7 +1700,7 @@ def unlock(self, session=None):
16811700
Added ``session`` parameter.
16821701
"""
16831702
cmd = SON([("fsyncUnlock", 1)])
1684-
with self._socket_for_writes() as sock_info:
1703+
with self._socket_for_writes(session=None) as sock_info:
16851704
if sock_info.max_wire_version >= 4:
16861705
try:
16871706
with self._tmp_session(session) as s:

test/test_read_preferences.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,9 @@ def __init__(self, *args, **kwargs):
315315
super(ReadPrefTester, self).__init__(*args, **client_options)
316316

317317
@contextlib.contextmanager
318-
def _socket_for_reads(self, read_preference):
319-
context = super(ReadPrefTester, self)._socket_for_reads(read_preference)
318+
def _socket_for_reads(self, read_preference, session):
319+
context = super(ReadPrefTester, self)._socket_for_reads(
320+
read_preference, session)
320321
with context as (sock_info, slave_ok):
321322
self.record_a_read(sock_info.address)
322323
yield sock_info, slave_ok

0 commit comments

Comments
 (0)