Skip to content

Commit 264cdd8

Browse files
committed
PYTHON-1070 - Make index cache thread safe
1 parent cb4a80a commit 264cdd8

4 files changed

Lines changed: 142 additions & 50 deletions

File tree

pymongo/collection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,12 @@ def ensure_index(self, key_or_list, cache_for=300, **kwargs):
15941594
keys = helpers._index_list(key_or_list)
15951595
name = kwargs["name"] = _gen_index_name(keys)
15961596

1597+
# Note that there is a race condition here. One thread could
1598+
# check if the index is cached and be preempted before creating
1599+
# and caching the index. This means multiple threads attempting
1600+
# to create the same index concurrently could send the same create-index request
1601+
# to the server two or more times. This has no practical impact
1602+
# other than wasted round trips.
15971603
if not self.__database.connection._cached(self.__database.name,
15981604
self.__name, name):
15991605
return self.create_index(key_or_list, cache_for, **kwargs)

pymongo/mongo_client.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def __init__(self, host=None, port=None, max_pool_size=100,
407407

408408
# cache of existing indexes used by ensure_index ops
409409
self.__index_cache = {}
410+
self.__index_cache_lock = threading.Lock()
410411
self.__auth_credentials = {}
411412

412413
super(MongoClient, self).__init__(**options)
@@ -446,28 +447,36 @@ def _cached(self, dbname, coll, index):
446447
"""
447448
cache = self.__index_cache
448449
now = datetime.datetime.utcnow()
449-
return (dbname in cache and
450-
coll in cache[dbname] and
451-
index in cache[dbname][coll] and
452-
now < cache[dbname][coll][index])
450+
self.__index_cache_lock.acquire()
451+
try:
452+
return (dbname in cache and
453+
coll in cache[dbname] and
454+
index in cache[dbname][coll] and
455+
now < cache[dbname][coll][index])
456+
finally:
457+
self.__index_cache_lock.release()
453458

454459
def _cache_index(self, database, collection, index, cache_for):
455460
"""Add an index to the index cache for ensure_index operations.
456461
"""
457462
now = datetime.datetime.utcnow()
458463
expire = datetime.timedelta(seconds=cache_for) + now
459464

460-
if database not in self.__index_cache:
461-
self.__index_cache[database] = {}
462-
self.__index_cache[database][collection] = {}
463-
self.__index_cache[database][collection][index] = expire
465+
self.__index_cache_lock.acquire()
466+
try:
467+
if database not in self.__index_cache:
468+
self.__index_cache[database] = {}
469+
self.__index_cache[database][collection] = {}
470+
self.__index_cache[database][collection][index] = expire
464471

465-
elif collection not in self.__index_cache[database]:
466-
self.__index_cache[database][collection] = {}
467-
self.__index_cache[database][collection][index] = expire
472+
elif collection not in self.__index_cache[database]:
473+
self.__index_cache[database][collection] = {}
474+
self.__index_cache[database][collection][index] = expire
468475

469-
else:
470-
self.__index_cache[database][collection][index] = expire
476+
else:
477+
self.__index_cache[database][collection][index] = expire
478+
finally:
479+
self.__index_cache_lock.release()
471480

472481
def _purge_index(self, database_name,
473482
collection_name=None, index_name=None):
@@ -477,22 +486,26 @@ def _purge_index(self, database_name,
477486
478487
If `collection_name` is None purge an entire database.
479488
"""
480-
if not database_name in self.__index_cache:
481-
return
489+
self.__index_cache_lock.acquire()
490+
try:
491+
if not database_name in self.__index_cache:
492+
return
482493

483-
if collection_name is None:
484-
del self.__index_cache[database_name]
485-
return
494+
if collection_name is None:
495+
del self.__index_cache[database_name]
496+
return
486497

487-
if not collection_name in self.__index_cache[database_name]:
488-
return
498+
if not collection_name in self.__index_cache[database_name]:
499+
return
489500

490-
if index_name is None:
491-
del self.__index_cache[database_name][collection_name]
492-
return
501+
if index_name is None:
502+
del self.__index_cache[database_name][collection_name]
503+
return
493504

494-
if index_name in self.__index_cache[database_name][collection_name]:
495-
del self.__index_cache[database_name][collection_name][index_name]
505+
if index_name in self.__index_cache[database_name][collection_name]:
506+
del self.__index_cache[database_name][collection_name][index_name]
507+
finally:
508+
self.__index_cache_lock.release()
496509

497510
def _cache_credentials(self, source, credentials, connect=True):
498511
"""Add credentials to the database authentication cache

pymongo/mongo_replica_set_client.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ def __init__(self, hosts_or_uri=None, max_pool_size=100,
584584
self.__opts = {}
585585
self.__seeds = set()
586586
self.__index_cache = {}
587+
self.__index_cache_lock = threading.Lock()
587588
self.__auth_credentials = {}
588589

589590
self.__monitor = None
@@ -780,28 +781,36 @@ def _cached(self, dbname, coll, index):
780781
"""
781782
cache = self.__index_cache
782783
now = datetime.datetime.utcnow()
783-
return (dbname in cache and
784-
coll in cache[dbname] and
785-
index in cache[dbname][coll] and
786-
now < cache[dbname][coll][index])
784+
self.__index_cache_lock.acquire()
785+
try:
786+
return (dbname in cache and
787+
coll in cache[dbname] and
788+
index in cache[dbname][coll] and
789+
now < cache[dbname][coll][index])
790+
finally:
791+
self.__index_cache_lock.release()
787792

788793
def _cache_index(self, dbase, collection, index, cache_for):
789794
"""Add an index to the index cache for ensure_index operations.
790795
"""
791796
now = datetime.datetime.utcnow()
792797
expire = datetime.timedelta(seconds=cache_for) + now
793798

794-
if dbase not in self.__index_cache:
795-
self.__index_cache[dbase] = {}
796-
self.__index_cache[dbase][collection] = {}
797-
self.__index_cache[dbase][collection][index] = expire
799+
self.__index_cache_lock.acquire()
800+
try:
801+
if dbase not in self.__index_cache:
802+
self.__index_cache[dbase] = {}
803+
self.__index_cache[dbase][collection] = {}
804+
self.__index_cache[dbase][collection][index] = expire
798805

799-
elif collection not in self.__index_cache[dbase]:
800-
self.__index_cache[dbase][collection] = {}
801-
self.__index_cache[dbase][collection][index] = expire
806+
elif collection not in self.__index_cache[dbase]:
807+
self.__index_cache[dbase][collection] = {}
808+
self.__index_cache[dbase][collection][index] = expire
802809

803-
else:
804-
self.__index_cache[dbase][collection][index] = expire
810+
else:
811+
self.__index_cache[dbase][collection][index] = expire
812+
finally:
813+
self.__index_cache_lock.release()
805814

806815
def _purge_index(self, database_name,
807816
collection_name=None, index_name=None):
@@ -811,22 +820,26 @@ def _purge_index(self, database_name,
811820
812821
If `collection_name` is None purge an entire database.
813822
"""
814-
if not database_name in self.__index_cache:
815-
return
823+
self.__index_cache_lock.acquire()
824+
try:
825+
if not database_name in self.__index_cache:
826+
return
816827

817-
if collection_name is None:
818-
del self.__index_cache[database_name]
819-
return
828+
if collection_name is None:
829+
del self.__index_cache[database_name]
830+
return
820831

821-
if not collection_name in self.__index_cache[database_name]:
822-
return
832+
if not collection_name in self.__index_cache[database_name]:
833+
return
823834

824-
if index_name is None:
825-
del self.__index_cache[database_name][collection_name]
826-
return
835+
if index_name is None:
836+
del self.__index_cache[database_name][collection_name]
837+
return
827838

828-
if index_name in self.__index_cache[database_name][collection_name]:
829-
del self.__index_cache[database_name][collection_name][index_name]
839+
if index_name in self.__index_cache[database_name][collection_name]:
840+
del self.__index_cache[database_name][collection_name][index_name]
841+
finally:
842+
self.__index_cache_lock.release()
830843

831844
def _cache_credentials(self, source, credentials, connect=True):
832845
"""Add credentials to the database authentication cache

test/test_collection.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,66 @@ def test_deprecated_ttl_index_kwarg(self):
259259
ctx.exit()
260260
self.assertEqual(None, db.test.ensure_index("goodbye"))
261261

262+
def test_ensure_index_threaded(self):
263+
coll = self.db.threaded_index_creation
264+
index_docs = []
265+
266+
class Indexer(threading.Thread):
267+
def run(self):
268+
coll.ensure_index('foo0')
269+
coll.ensure_index('foo1')
270+
coll.ensure_index('foo2')
271+
index_docs.append(coll.index_information())
272+
273+
try:
274+
threads = []
275+
for _ in range(10):
276+
t = Indexer()
277+
t.setDaemon(True)
278+
threads.append(t)
279+
280+
for thread in threads:
281+
thread.start()
282+
283+
joinall(threads)
284+
285+
first = index_docs[0]
286+
for index_doc in index_docs[1:]:
287+
self.assertEqual(index_doc, first)
288+
finally:
289+
coll.drop()
290+
291+
def test_ensure_purge_index_threaded(self):
292+
coll = self.db.threaded_index_creation
293+
294+
class Indexer(threading.Thread):
295+
def run(self):
296+
coll.ensure_index('foo')
297+
try:
298+
coll.drop_index('foo')
299+
except OperationFailure:
300+
# The index may have already been dropped.
301+
pass
302+
coll.ensure_index('foo')
303+
coll.drop_indexes()
304+
coll.ensure_index('foo')
305+
306+
try:
307+
threads = []
308+
for _ in range(10):
309+
t = Indexer()
310+
t.setDaemon(True)
311+
threads.append(t)
312+
313+
for thread in threads:
314+
thread.start()
315+
316+
joinall(threads)
317+
318+
self.assertTrue('foo_1' in coll.index_information())
319+
finally:
320+
coll.drop()
321+
262322
def test_ensure_unique_index_threaded(self):
263323
coll = self.db.test_unique_threaded
264324
coll.drop()

0 commit comments

Comments
 (0)