Skip to content

Commit 0726f21

Browse files
author
Mike Dirolf
committed
Rework connection handling, and add support for replSets
1 parent 60a2d52 commit 0726f21

File tree

6 files changed

+76
-69
lines changed

6 files changed

+76
-69
lines changed

pymongo/connection.py

Lines changed: 65 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ def __init__(self, host=None, port=None, pool_size=None,
186186
database = db or database
187187
username = u or username
188188
password = p or password
189+
if not nodes:
190+
raise ConfigurationError("need to specify at least one host")
189191
self.__nodes = nodes
190192
if database and username is None:
191193
raise InvalidURI("cannot specify database without "
@@ -206,6 +208,9 @@ def __init__(self, host=None, port=None, pool_size=None,
206208
self.__port = None
207209

208210
self.__slave_okay = slave_okay
211+
if slave_okay and len(self.__nodes) > 1:
212+
raise ConfigurationError("cannot specify slave_okay for a paired "
213+
"or replica set connection")
209214

210215
self.__cursor_manager = CursorManager(self)
211216

@@ -309,16 +314,13 @@ def paired(cls, left, right=None, **connection_args):
309314
The remaining keyword arguments are the same as those accepted
310315
by :meth:`~Connection`.
311316
"""
317+
if isinstance(left, str) or isinstance(right, str):
318+
raise TypeError("arguments to paired must be tuples")
312319
if right is None:
313320
right = (cls.HOST, cls.PORT)
314321
return cls([":".join(map(str, left)), ":".join(map(str, right))],
315322
**connection_args)
316323

317-
def __master(self, sock):
318-
"""Is this socket connected to a master server?
319-
"""
320-
return self["admin"].command("ismaster", _sock=sock)["ismaster"]
321-
322324
def _cache_index(self, database, collection, index, ttl):
323325
"""Add an index to the index cache for ensure_index operations.
324326
@@ -420,63 +422,79 @@ def tz_aware(self):
420422
"""
421423
return self.__tz_aware
422424

425+
def __add_hosts_and_get_primary(self, response):
426+
if "hosts" in response:
427+
self.__nodes.update([h.split(":") for h in response["hosts"]])
428+
return response.get("primary", False)
429+
430+
def __try_node(self, node):
431+
self.disconnect()
432+
self.__host, self.__port = node
433+
try:
434+
response = self.admin.command("isMaster")
435+
self.end_request()
436+
437+
primary = self.__add_hosts_and_get_primary(response)
438+
if response["ismaster"]:
439+
return True
440+
return primary
441+
except:
442+
self.end_request()
443+
return None
444+
423445
def __find_master(self):
424446
"""Create a new socket and use it to figure out who the master is.
425447
426448
Sets __host and __port so that :attr:`host` and :attr:`port`
427449
will return the address of the master. Also (possibly) updates
428450
any replSet information.
429451
"""
430-
self.__host = None
431-
self.__port = None
432-
sock = None
433-
sock_error = False
434-
close = True
435-
452+
# Special case the first node to try to get the primary or any
453+
# additional hosts from a replSet:
454+
first = iter(self.__nodes).next()
455+
456+
primary = self.__try_node(first)
457+
if primary is True:
458+
return first
459+
if self.__slave_okay and primary is not None: # no network error
460+
return first
461+
462+
# Wasn't the first node, but we got a primary - let's try it:
463+
tried = [first]
464+
if primary:
465+
if self.__try_node(primary) is True:
466+
return primary
467+
tried.append(primary)
468+
469+
nodes = self.__nodes - set(tried)
470+
471+
# Just scan
436472
# TODO parallelize these to minimize connect time?
437-
for (host, port) in self.__nodes:
438-
try:
439-
try:
440-
sock = socket.socket()
441-
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
442-
sock.settimeout(self.__network_timeout or _CONNECT_TIMEOUT)
443-
sock.connect((host, port))
444-
sock.settimeout(self.__network_timeout)
445-
master = self.__master(sock)
446-
if master or self.__slave_okay:
447-
self.__host = host
448-
self.__port = port
449-
self.__pool.return_unowned(sock)
450-
close = False
451-
return
452-
except socket.error, e:
453-
sock_error = True
454-
finally:
455-
if sock is not None and close:
456-
sock.close()
457-
if sock_error or self.__host is None:
458-
raise AutoReconnect("could not find master")
459-
raise ConfigurationError("No master node in %r. You must specify "
460-
"slave_okay to connect to "
461-
"slaves." % self.__nodes)
473+
for node in nodes:
474+
if self.__try_node(node) is True:
475+
return node
476+
477+
raise AutoReconnect("could not find master/primary")
462478

463479
def __connect(self):
464480
"""(Re-)connect to Mongo and return a new (connected) socket.
465481
466482
Connect to the master if this is a paired connection.
467483
"""
468-
if self.__host is None or self.__port is None:
469-
self.__find_master()
484+
host, port = (self.__host, self.__port)
485+
if host is None or port is None:
486+
host, port = self.__find_master()
470487

471488
try:
472489
sock = socket.socket()
473490
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
474491
sock.settimeout(self.__network_timeout or _CONNECT_TIMEOUT)
475-
sock.connect((self.__host, self.__port))
492+
sock.connect((host, port))
476493
sock.settimeout(self.__network_timeout)
477494
return sock
478495
except socket.error:
479-
raise AutoReconnect("could not connect to %r" % self.__nodes)
496+
self.disconnect()
497+
raise AutoReconnect("could not connect to %r" % list(self.__nodes))
480498

481499
def disconnect(self):
482500
"""Disconnect from MongoDB.
@@ -492,17 +510,8 @@ def disconnect(self):
492510
.. versionadded:: 1.3
493511
"""
494512
self.__pool = Pool(self.__connect)
495-
496-
def _reset(self):
497-
"""Reset everything and start connecting again.
498-
499-
Closes all open sockets and resets them to None. Re-finds the master.
500-
501-
This should be done in case of a connection failure or a "not master"
502-
error.
503-
"""
504-
self.disconnect()
505-
self.__find_master()
513+
self.__host = None
514+
self.__port = None
506515

507516
def set_cursor_manager(self, manager_class):
508517
"""Set this connection's cursor manager.
@@ -540,7 +549,7 @@ def __check_response_to_last_error(self, response):
540549
if error.get("err", 0) is None:
541550
return error
542551
if error["err"] == "not master":
543-
self._reset()
552+
self.disconnect()
544553

545554
if "code" in error:
546555
if error["code"] in [11000, 11001]:
@@ -580,7 +589,7 @@ def _send_message(self, message, with_last_error=False):
580589
return self.__check_response_to_last_error(response)
581590
return None
582591
except (ConnectionFailure, socket.error), e:
583-
self._reset()
592+
self.disconnect()
584593
raise AutoReconnect(str(e))
585594

586595
def __receive_data_on_socket(self, length, sock):
@@ -642,7 +651,7 @@ def _send_message_with_response(self, message, _sock=None,
642651
return self.__send_and_receive(message, _sock)
643652
except (ConnectionFailure, socket.error), e:
644653
if reset:
645-
self._reset()
654+
self.disconnect()
646655
raise AutoReconnect(str(e))
647656
finally:
648657
if "network_timeout" in kwargs:
@@ -687,12 +696,8 @@ def __cmp__(self, other):
687696
def __repr__(self):
688697
if len(self.__nodes) == 1:
689698
return "Connection(%r, %r)" % (self.__host, self.__port)
690-
elif len(self.__nodes) == 2:
691-
return ("Connection.paired((%r, %r), (%r, %r))" %
692-
(self.__nodes[0][0],
693-
self.__nodes[0][1],
694-
self.__nodes[1][0],
695-
self.__nodes[1][1]))
699+
else:
700+
return "Connection(%r)" % ["%s:%d" % n for n in self.__nodes]
696701

697702
def __getattr__(self, name):
698703
"""Get a database by name.

pymongo/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ def __send_message(self, message):
495495
response = helpers._unpack_response(response, self.__id,
496496
self.__as_class, self.__tz_aware)
497497
except AutoReconnect:
498-
db.connection._reset()
498+
db.connection.disconnect()
499499
raise
500500
self.__id = response["cursor_id"]
501501

pymongo/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def error(self):
391391
if error.get("err", 0) is None:
392392
return None
393393
if error["err"] == "not master":
394-
self.__connection._reset()
394+
self.__connection.disconnect()
395395
return error
396396

397397
def last_status(self):

test/test_connection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def test_types(self):
5656
self.assertRaises(TypeError, Connection, "localhost", 1.14)
5757
self.assertRaises(TypeError, Connection, "localhost", [])
5858

59+
self.assertRaises(ConfigurationError, Connection, [])
60+
5961
def test_constants(self):
6062
Connection.HOST = self.host
6163
Connection.PORT = self.port

test/test_paired.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import warnings
2929
sys.path[0:0] = [""]
3030

31-
from pymongo.errors import ConnectionFailure, ConfigurationError
31+
from pymongo.errors import ConnectionFailure
3232
from pymongo.connection import Connection
3333

3434
skip_tests = True
@@ -83,17 +83,17 @@ def test_connect(self):
8383
self.assertEqual(port, connection.port)
8484

8585
slave = self.left == (host, port) and self.right or self.left
86-
self.assertRaises(ConfigurationError, Connection.paired,
86+
self.assertRaises(ConnectionFailure, Connection.paired,
8787
slave, self.bad)
88-
self.assertRaises(ConfigurationError, Connection.paired,
88+
self.assertRaises(ConnectionFailure, Connection.paired,
8989
self.bad, slave)
9090

9191
def test_repr(self):
9292
self.skip()
9393
connection = Connection.paired(self.left, self.right)
9494

9595
self.assertEqual(repr(connection),
96-
"Connection.paired(('%s', %s), ('%s', %s))" %
96+
"Connection(['%s:%s', '%s:%s'])" %
9797
(self.left[0],
9898
self.left[1],
9999
self.right[0],

tools/auto_reconnect_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import threading
1818
import time
1919

20-
from pymongo.errors import ConnectionFailure
20+
from pymongo.errors import AutoReconnect
2121
from pymongo.connection import Connection
2222

2323
db = Connection.paired(("localhost", 27018)).test
@@ -28,12 +28,12 @@ def run(self):
2828
while True:
2929
time.sleep(1)
3030
try:
31-
id = db.test.save({"x": 1})
31+
id = db.test.save({"x": 1}, safe=True)
3232
assert db.test.find_one(id)["x"] == 1
3333
db.test.remove(id)
3434
db.connection.end_request()
3535
print "Y"
36-
except ConnectionFailure, e:
36+
except Exception, e:
3737
print e
3838
print "N"
3939

0 commit comments

Comments
 (0)