Skip to content

Commit 2bba32b

Browse files
committed
Fix race condition in ensure_index PYTHON-284
1 parent fad1c38 commit 2bba32b

5 files changed

Lines changed: 64 additions & 30 deletions

File tree

pymongo/collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -764,8 +764,8 @@ def ensure_index(self, key_or_list, deprecated_unique=None,
764764
keys = helpers._index_list(key_or_list)
765765
name = kwargs["name"] = _gen_index_name(keys)
766766

767-
if self.__database.connection._cache_index(self.__database.name,
768-
self.__name, name, ttl):
767+
if not self.__database.connection._cached(self.__database.name,
768+
self.__name, name):
769769
return self.create_index(key_or_list, deprecated_unique,
770770
ttl, **kwargs)
771771
return None

pymongo/connection.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -407,13 +407,18 @@ def paired(cls, left, right=None, **connection_args):
407407
return cls([":".join(map(str, left)), ":".join(map(str, right))],
408408
**connection_args)
409409

410+
def _cached(self, dbname, coll, index):
411+
"""Test if `index` is cached.
412+
"""
413+
cache = self.__index_cache
414+
now = datetime.datetime.utcnow()
415+
return (dbname in cache and
416+
coll in cache[dbname] and
417+
index in cache[dbname][coll] and
418+
now < cache[dbname][coll][index])
419+
410420
def _cache_index(self, database, collection, index, ttl):
411421
"""Add an index to the index cache for ensure_index operations.
412-
413-
Return ``True`` if the index has been newly cached or if the index had
414-
expired and is being re-cached.
415-
416-
Return ``False`` if the index exists and is valid.
417422
"""
418423
now = datetime.datetime.utcnow()
419424
expire = datetime.timedelta(seconds=ttl) + now
@@ -422,19 +427,13 @@ def _cache_index(self, database, collection, index, ttl):
422427
self.__index_cache[database] = {}
423428
self.__index_cache[database][collection] = {}
424429
self.__index_cache[database][collection][index] = expire
425-
return True
426430

427-
if collection not in self.__index_cache[database]:
431+
elif collection not in self.__index_cache[database]:
428432
self.__index_cache[database][collection] = {}
429433
self.__index_cache[database][collection][index] = expire
430-
return True
431434

432-
if index in self.__index_cache[database][collection]:
433-
if now < self.__index_cache[database][collection][index]:
434-
return False
435-
436-
self.__index_cache[database][collection][index] = expire
437-
return True
435+
else:
436+
self.__index_cache[database][collection][index] = expire
438437

439438
def _purge_index(self, database_name,
440439
collection_name=None, index_name=None):

pymongo/master_slave_connection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,10 @@ def __iter__(self):
270270
def next(self):
271271
raise TypeError("'MasterSlaveConnection' object is not iterable")
272272

273+
def _cached(self, database_name, collection_name, index_name):
274+
return self.__master._cached(database_name,
275+
collection_name, index_name)
276+
273277
def _cache_index(self, database_name, collection_name, index_name, ttl):
274278
return self.__master._cache_index(database_name, collection_name,
275279
index_name, ttl)

pymongo/replica_set_connection.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -219,13 +219,18 @@ def __init__(self, hosts_or_uri=None, max_pool_size=10,
219219
if not self[db_name].authenticate(username, password):
220220
raise ConfigurationError("authentication failed")
221221

222+
def _cached(self, dbname, coll, index):
223+
"""Test if `index` is cached.
224+
"""
225+
cache = self.__index_cache
226+
now = datetime.datetime.utcnow()
227+
return (dbname in cache and
228+
coll in cache[dbname] and
229+
index in cache[dbname][coll] and
230+
now < cache[dbname][coll][index])
231+
222232
def _cache_index(self, dbase, collection, index, ttl):
223233
"""Add an index to the index cache for ensure_index operations.
224-
225-
Return ``True`` if the index has been newly cached or if the index had
226-
expired and is being re-cached.
227-
228-
Return ``False`` if the index exists and is valid.
229234
"""
230235
now = datetime.datetime.utcnow()
231236
expire = datetime.timedelta(seconds=ttl) + now
@@ -234,19 +239,13 @@ def _cache_index(self, dbase, collection, index, ttl):
234239
self.__index_cache[dbase] = {}
235240
self.__index_cache[dbase][collection] = {}
236241
self.__index_cache[dbase][collection][index] = expire
237-
return True
238242

239-
if collection not in self.__index_cache[dbase]:
243+
elif collection not in self.__index_cache[dbase]:
240244
self.__index_cache[dbase][collection] = {}
241245
self.__index_cache[dbase][collection][index] = expire
242-
return True
243246

244-
if index in self.__index_cache[dbase][collection]:
245-
if now < self.__index_cache[dbase][collection][index]:
246-
return False
247-
248-
self.__index_cache[dbase][collection][index] = expire
249-
return True
247+
else:
248+
self.__index_cache[dbase][collection][index] = expire
250249

251250
def _purge_index(self, database_name,
252251
collection_name=None, index_name=None):

test/test_collection.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import itertools
2020
import re
2121
import sys
22+
import threading
2223
import time
2324
import unittest
2425
import warnings
@@ -176,9 +177,40 @@ def test_ensure_index(self):
176177
time.sleep(1.1)
177178
self.assertEqual("goodbye_1",
178179
db.test.ensure_index("goodbye"))
180+
# Make sure the expiration time is updated.
181+
self.assertEqual(None,
182+
db.test.ensure_index("goodbye"))
179183
# Clean up indexes for later tests
180184
db.test.drop_indexes()
181185

186+
def test_ensure_unique_index_threaded(self):
187+
db = self.db
188+
db.test.drop()
189+
db.test.insert({'foo': i} for i in xrange(10000))
190+
191+
class Indexer(threading.Thread):
192+
def run(self):
193+
try:
194+
db.test.ensure_index('foo', unique=True)
195+
db.test.insert({'foo': 'bar'})
196+
db.test.insert({'foo': 'bar'})
197+
except OperationFailure:
198+
pass
199+
200+
threads = []
201+
for _ in xrange(10):
202+
t = Indexer()
203+
threads.append(t)
204+
205+
for i in xrange(10):
206+
threads[i].start()
207+
208+
for i in xrange(10):
209+
threads[i].join()
210+
211+
self.assertEqual(10001, db.test.count())
212+
db.test.drop()
213+
182214
def test_index_on_binary(self):
183215
db = self.db
184216
db.drop_collection("test")

0 commit comments

Comments
 (0)