Skip to content

Commit 9b208ca

Browse files
committed
PYTHON-829 Call ismaster on each new connection.
Call ismaster on each new connection and store the results on the SocketInfo instance. The upcoming Authentication Spec says: "If credentials exist, upon opening a socket, drivers MUST send an isMaster command immediately. This allows a driver to determine whether the server is an Arbiter. Calling ismaster additionally allows the driver to know if the default authentication method for each socket is MONGODB-CR or SCRAM-SHA-1, avoiding races when the driver repopulates the pool after a disconnect." In theory we could choose not to call ismaster if there are no credentials, but it's simpler always to call ismaster, and paves the way for future breaking changes to the wire protocol besides the current breaking change to authentication.
1 parent 38a6711 commit 9b208ca

File tree

7 files changed

+54
-69
lines changed

7 files changed

+54
-69
lines changed

pymongo/monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _check_once(self):
137137
138138
Returns a ServerDescription, or raises an exception.
139139
"""
140-
with self._pool.get_socket({}, 0, 0) as sock_info:
140+
with self._pool.get_socket({}) as sock_info:
141141
response, round_trip_time = self._check_with_socket(sock_info)
142142
self._avg_round_trip_time.add_sample(round_trip_time)
143143
sd = ServerDescription(

pymongo/network.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from pymongo import helpers, message
2121
from pymongo.errors import AutoReconnect
2222

23+
_UNPACK_INT = struct.Struct("<i").unpack
24+
2325

2426
def command(sock, dbname, spec):
2527
"""Execute a command over the socket, or raise socket.error.
@@ -43,15 +45,15 @@ def command(sock, dbname, spec):
4345
def receive_message(sock, operation, request_id):
4446
"""Receive a raw BSON message or raise socket.error."""
4547
header = _receive_data_on_socket(sock, 16)
46-
length = struct.unpack("<i", header[:4])[0]
48+
length = _UNPACK_INT(header[:4])[0]
4749

4850
# No request_id for exhaust cursor "getMore".
4951
if request_id is not None:
50-
response_id = struct.unpack("<i", header[8:12])[0]
52+
response_id = _UNPACK_INT(header[8:12])[0]
5153
assert request_id == response_id, "ids don't match %r %r" % (
5254
request_id, response_id)
5355

54-
assert operation == struct.unpack("<i", header[12:])[0]
56+
assert operation == _UNPACK_INT(header[12:])[0]
5557
return _receive_data_on_socket(sock, length - 16)
5658

5759

pymongo/pool.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from bson.py3compat import u, itervalues
2121
from pymongo import auth, thread_util
2222
from pymongo.errors import ConnectionFailure
23+
from pymongo.ismaster import IsMaster
2324
from pymongo.monotonic import time as _time
2425
from pymongo.network import (command,
2526
receive_message,
@@ -107,16 +108,21 @@ def socket_keepalive(self):
107108

108109

109110
class SocketInfo(object):
110-
"""Store a socket with some metadata."""
111-
def __init__(self, sock, pool, host):
111+
"""Store a socket with some metadata.
112+
113+
:Parameters:
114+
- `sock`: a raw socket object
115+
- `pool`: a Pool instance
116+
- `ismaster`: an IsMaster instance, response to ismaster call on `sock`
117+
- `host`: a string, the server's hostname (without port)
118+
"""
119+
def __init__(self, sock, pool, ismaster, host):
112120
self.sock = sock
113121
self.host = host
114122
self.authset = set()
115123
self.closed = False
116124
self.last_checkout = _time()
117-
118-
self._min_wire_version = None
119-
self._max_wire_version = None
125+
self.ismaster = ismaster
120126

121127
# The pool's pool_id changes with each reset() so we can close sockets
122128
# created before the last reset.
@@ -196,19 +202,13 @@ def close(self):
196202
except:
197203
pass
198204

199-
def set_wire_version_range(self, min_wire_version, max_wire_version):
200-
self._min_wire_version = min_wire_version
201-
self._max_wire_version = max_wire_version
202-
203205
@property
204206
def min_wire_version(self):
205-
assert self._min_wire_version is not None
206-
return self._min_wire_version
207+
return self.ismaster.min_wire_version
207208

208209
@property
209210
def max_wire_version(self):
210-
assert self._max_wire_version is not None
211-
return self._max_wire_version
211+
return self.ismaster.max_wire_version
212212

213213
def __eq__(self, other):
214214
return self.sock == other.sock
@@ -353,26 +353,24 @@ def connect(self):
353353
return_socket() when you're done with it.
354354
"""
355355
sock = _configured_socket(self.address, self.opts)
356-
return SocketInfo(sock, self, host=self.address[0])
356+
try:
357+
ismaster = IsMaster(command(sock, 'admin', {'ismaster': 1}))
358+
except:
359+
sock.close()
360+
raise
361+
362+
return SocketInfo(sock, self, ismaster, host=self.address[0])
357363

358364
@contextlib.contextmanager
359-
def get_socket(
360-
self,
361-
all_credentials,
362-
min_wire_version,
363-
max_wire_version,
364-
checkout=False):
365+
def get_socket(self, all_credentials, checkout=False):
365366
"""Get a socket from the pool. Use with a "with" statement.
366367
367368
Returns a :class:`SocketInfo` object wrapping a connected
368369
:class:`socket.socket`.
369370
370371
This method should always be used in a with-statement::
371372
372-
with pool.get_socket(credentials,
373-
min_wire_version,
374-
max_wire_version,
375-
checkout) as socket_info:
373+
with pool.get_socket(credentials, checkout) as socket_info:
376374
socket_info.send_message(msg)
377375
data = socket_info.receive_message(op_code, request_id)
378376
@@ -382,14 +380,11 @@ def get_socket(
382380
383381
:Parameters:
384382
- `all_credentials`: dict, maps auth source to MongoCredential.
385-
- `min_wire_version`: int, minimum protocol the server supports.
386-
- `max_wire_version`: int, maximum protocol the server supports.
387383
- `checkout` (optional): keep socket checked out.
388384
"""
389385
# First get a socket, then attempt authentication. Simplifies
390386
# semaphore management in the face of network errors during auth.
391387
sock_info = self._get_socket_no_auth()
392-
sock_info.set_wire_version_range(min_wire_version, max_wire_version)
393388
try:
394389
sock_info.check_auth(all_credentials)
395390
yield sock_info

pymongo/server.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,7 @@ def send_message_with_response(
103103

104104
@contextlib.contextmanager
105105
def get_socket(self, all_credentials, checkout=False):
106-
sd = self.description
107-
with self.pool.get_socket(all_credentials,
108-
sd.min_wire_version,
109-
sd.max_wire_version,
110-
checkout) as sock_info:
106+
with self.pool.get_socket(all_credentials, checkout) as sock_info:
111107
yield sock_info
112108

113109
@property

test/pymongo_mocks.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,7 @@ def __init__(self, client, pair, *args, **kwargs):
4242
PoolOptions(connect_timeout=20))
4343

4444
@contextlib.contextmanager
45-
def get_socket(
46-
self,
47-
all_credentials,
48-
min_wire_version,
49-
max_wire_version,
50-
checkout=False):
45+
def get_socket(self, all_credentials, checkout=False):
5146
client = self.client
5247
host_and_port = '%s:%s' % (self.mock_host, self.mock_port)
5348
if host_and_port in client.mock_down_hosts:
@@ -58,10 +53,7 @@ def get_socket(
5853
+ client.mock_members
5954
+ client.mock_mongoses), "bad host: %s" % host_and_port
6055

61-
with Pool.get_socket(self,
62-
all_credentials,
63-
min_wire_version,
64-
max_wire_version) as sock_info:
56+
with Pool.get_socket(self, all_credentials) as sock_info:
6557
sock_info.mock_host = self.mock_host
6658
sock_info.mock_port = self.mock_port
6759
yield sock_info

test/test_pooling.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def run_mongo_thread(self):
124124
self.state = 'get_socket'
125125

126126
# Pass 'checkout' so we can hold the socket.
127-
with self.pool.get_socket({}, 0, 0, checkout=True) as sock:
127+
with self.pool.get_socket({}, checkout=True) as sock:
128128
self.sock = sock
129129

130130
self.state = 'sock'
@@ -189,10 +189,10 @@ def test_pool_reuses_open_socket(self):
189189
# Test Pool's _check_closed() method doesn't close a healthy socket.
190190
cx_pool = self.create_pool(max_pool_size=10)
191191
cx_pool._check_interval_seconds = 0 # Always check.
192-
with cx_pool.get_socket({}, 0, 0) as sock_info:
192+
with cx_pool.get_socket({}) as sock_info:
193193
pass
194194

195-
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
195+
with cx_pool.get_socket({}) as new_sock_info:
196196
self.assertEqual(sock_info, new_sock_info)
197197

198198
self.assertEqual(1, len(cx_pool.sockets))
@@ -201,11 +201,11 @@ def test_get_socket_and_exception(self):
201201
# get_socket() returns socket after a non-network error.
202202
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
203203
with self.assertRaises(ZeroDivisionError):
204-
with cx_pool.get_socket({}, 0, 0) as sock_info:
204+
with cx_pool.get_socket({}) as sock_info:
205205
1 / 0
206206

207207
# Socket was returned, not closed.
208-
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
208+
with cx_pool.get_socket({}) as new_sock_info:
209209
self.assertEqual(sock_info, new_sock_info)
210210

211211
self.assertEqual(1, len(cx_pool.sockets))
@@ -214,7 +214,7 @@ def test_pool_removes_closed_socket(self):
214214
# Test that Pool removes explicitly closed socket.
215215
cx_pool = self.create_pool()
216216

217-
with cx_pool.get_socket({}, 0, 0) as sock_info:
217+
with cx_pool.get_socket({}) as sock_info:
218218
# Use SocketInfo's API to close the socket.
219219
sock_info.close()
220220

@@ -226,25 +226,25 @@ def test_pool_removes_dead_socket(self):
226226
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
227227
cx_pool._check_interval_seconds = 0 # Always check.
228228

229-
with cx_pool.get_socket({}, 0, 0) as sock_info:
229+
with cx_pool.get_socket({}) as sock_info:
230230
# Simulate a closed socket without telling the SocketInfo it's
231231
# closed.
232232
sock_info.sock.close()
233233
self.assertTrue(socket_closed(sock_info.sock))
234234

235-
with cx_pool.get_socket({}, 0, 0) as new_sock_info:
235+
with cx_pool.get_socket({}) as new_sock_info:
236236
self.assertEqual(0, len(cx_pool.sockets))
237237
self.assertNotEqual(sock_info, new_sock_info)
238238

239239
self.assertEqual(1, len(cx_pool.sockets))
240240

241241
# Semaphore was released.
242-
with cx_pool.get_socket({}, 0, 0):
242+
with cx_pool.get_socket({}):
243243
pass
244244

245245
def test_return_socket_after_reset(self):
246246
pool = self.create_pool()
247-
with pool.get_socket({}, 0, 0) as sock:
247+
with pool.get_socket({}) as sock:
248248
pool.reset()
249249

250250
self.assertTrue(sock.closed)
@@ -258,20 +258,20 @@ def test_pool_check(self):
258258
wait_queue_timeout=1)
259259
cx_pool._check_interval_seconds = 0 # Always check.
260260

261-
with cx_pool.get_socket({}, 0, 0) as sock_info:
261+
with cx_pool.get_socket({}) as sock_info:
262262
# Simulate a closed socket without telling the SocketInfo it's
263263
# closed.
264264
sock_info.sock.close()
265265

266266
# Swap pool's address with a bad one.
267267
address, cx_pool.address = cx_pool.address, ('foo.com', 1234)
268268
with self.assertRaises(socket.error):
269-
with cx_pool.get_socket({}, 0, 0):
269+
with cx_pool.get_socket({}):
270270
pass
271271

272272
# Back to normal, semaphore was correctly released.
273273
cx_pool.address = address
274-
with cx_pool.get_socket({}, 0, 0, checkout=True):
274+
with cx_pool.get_socket({}, checkout=True):
275275
pass
276276

277277
def test_pool_with_fork(self):
@@ -326,18 +326,18 @@ def loop(pipe):
326326
self.assertTrue(b_sock != c_sock)
327327

328328
# a_sock, created by parent process, is still in the pool
329-
with get_pool(a).get_socket({}, 0, 0) as d_sock:
329+
with get_pool(a).get_socket({}) as d_sock:
330330
self.assertEqual(a_sock, d_sock)
331331

332332
def test_wait_queue_timeout(self):
333333
wait_queue_timeout = 2 # Seconds
334334
pool = self.create_pool(
335335
max_pool_size=1, wait_queue_timeout=wait_queue_timeout)
336336

337-
with pool.get_socket({}, 0, 0) as sock_info:
337+
with pool.get_socket({}) as sock_info:
338338
start = time.time()
339339
with self.assertRaises(ConnectionFailure):
340-
with pool.get_socket({}, 0, 0):
340+
with pool.get_socket({}):
341341
pass
342342

343343
duration = time.time() - start
@@ -353,7 +353,7 @@ def test_no_wait_queue_timeout(self):
353353
pool = self.create_pool(max_pool_size=1)
354354

355355
# Reach max_size.
356-
with pool.get_socket({}, 0, 0) as s1:
356+
with pool.get_socket({}) as s1:
357357
t = SocketGetter(self.c, pool)
358358
t.start()
359359
while t.state != 'get_socket':
@@ -375,8 +375,8 @@ def test_wait_queue_multiple(self):
375375
max_pool_size=2, wait_queue_multiple=wait_queue_multiple)
376376

377377
# Reach max_size sockets.
378-
with pool.get_socket({}, 0, 0):
379-
with pool.get_socket({}, 0, 0):
378+
with pool.get_socket({}):
379+
with pool.get_socket({}):
380380

381381
# Reach max_size * wait_queue_multiple waiters.
382382
threads = []
@@ -390,7 +390,7 @@ def test_wait_queue_multiple(self):
390390
self.assertEqual(t.state, 'get_socket')
391391

392392
with self.assertRaises(ExceededMaxWaiters):
393-
with pool.get_socket({}, 0, 0):
393+
with pool.get_socket({}):
394394
pass
395395

396396
def test_no_wait_queue_multiple(self):
@@ -399,7 +399,7 @@ def test_no_wait_queue_multiple(self):
399399
socks = []
400400
for _ in range(2):
401401
# Pass 'checkout' so we can hold the socket.
402-
with pool.get_socket({}, 0, 0, checkout=True) as sock:
402+
with pool.get_socket({}, checkout=True) as sock:
403403
socks.append(sock)
404404

405405
threads = []
@@ -498,7 +498,7 @@ def test_max_pool_size_with_connection_failure(self):
498498
# socket from pool" instead of the socket.error.
499499
for i in range(2):
500500
with self.assertRaises(socket.error):
501-
with test_pool.get_socket({}, 0, 0, checkout=True):
501+
with test_pool.get_socket({}, checkout=True):
502502
pass
503503

504504

test/test_topology.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self, *args, **kwargs):
5656
self.pool_id = 0
5757
self._lock = threading.Lock()
5858

59-
def get_socket(self, all_credentials, min_wire_version, max_wire_version):
59+
def get_socket(self, all_credentials):
6060
return MockSocketInfo()
6161

6262
def return_socket(self, _):

0 commit comments

Comments
 (0)