Skip to content

Commit 8069e13

Browse files
committed
PYTHON-1721 Improve GridFS file download performance (mongodb#413)
This change uses a cursor to download all the chunks in a GridFS file instead of using individual find_one operations to read each chunk. Detect truncated/missing/extra chunks in _GridOutChunkIterator. Only detect extra chunks after reading the final chunk, not on every call to read(). Retry once after CursorNotFound for backward compatibility. (cherry picked from commit 956fd92)
1 parent 905c578 commit 8069e13

File tree

6 files changed

+186
-48
lines changed

6 files changed

+186
-48
lines changed

doc/changelog.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ Changes in Version 3.8.0.dev0
1515

1616
- :class:`~bson.objectid.ObjectId` now implements the `ObjectID specification
1717
version 0.2 <https://github.com/mongodb/specifications/blob/master/source/objectid.rst>`_.
18+
- For better performance and to better follow the GridFS spec,
19+
:class:`~gridfs.grid_file.GridOut` now uses a single cursor to read all the
20+
chunks in the file. Previously, each chunk in the file was queried
21+
individually using :meth:`~pymongo.collection.Collection.find_one`.
22+
- :meth:`gridfs.grid_file.GridOut.read` now only checks for extra chunks after
23+
reading the entire file. Previously, this method would check for extra
24+
chunks on every call.
1825

1926
- :meth:`~pymongo.database.Database.current_op` now always uses the
2027
``Database``'s :attr:`~pymongo.database.Database.codec_options`

gridfs/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -715,9 +715,9 @@ def download_to_stream(self, file_id, destination, session=None):
715715
.. versionchanged:: 3.6
716716
Added ``session`` parameter.
717717
"""
718-
gout = self.open_download_stream(file_id, session=session)
719-
for chunk in gout:
720-
destination.write(chunk)
718+
with self.open_download_stream(file_id, session=session) as gout:
719+
for chunk in gout:
720+
destination.write(chunk)
721721

722722
def delete(self, file_id, session=None):
723723
"""Given an file_id, delete this stored file's files collection document
@@ -890,10 +890,10 @@ def download_to_stream_by_name(self, filename, destination, revision=-1,
890890
.. versionchanged:: 3.6
891891
Added ``session`` parameter.
892892
"""
893-
gout = self.open_download_stream_by_name(
894-
filename, revision, session=session)
895-
for chunk in gout:
896-
destination.write(chunk)
893+
with self.open_download_stream_by_name(
894+
filename, revision, session=session) as gout:
895+
for chunk in gout:
896+
destination.write(chunk)
897897

898898
def rename(self, file_id, new_filename, session=None):
899899
"""Renames the stored file with the specified file_id.

gridfs/grid_file.py

Lines changed: 136 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pymongo.collection import Collection
2828
from pymongo.cursor import Cursor
2929
from pymongo.errors import (ConfigurationError,
30+
CursorNotFound,
3031
DuplicateKeyError,
3132
OperationFailure)
3233
from pymongo.read_preferences import ReadPreference
@@ -419,6 +420,11 @@ def __init__(self, root_collection, file_id=None, file_document=None,
419420
:class:`~pymongo.client_session.ClientSession` to use for all
420421
commands
421422
423+
.. versionchanged:: 3.8
424+
For better performance and to better follow the GridFS spec,
425+
:class:`GridOut` now uses a single cursor to read all the chunks in
426+
the file.
427+
422428
.. versionchanged:: 3.6
423429
Added ``session`` parameter.
424430
@@ -434,6 +440,7 @@ def __init__(self, root_collection, file_id=None, file_document=None,
434440
self.__files = root_collection.files
435441
self.__file_id = file_id
436442
self.__buffer = EMPTY
443+
self.__chunk_iter = None
437444
self.__position = 0
438445
self._file = file_document
439446
self._session = session
@@ -477,12 +484,11 @@ def readchunk(self):
477484
chunk_data = self.__buffer
478485
elif self.__position < int(self.length):
479486
chunk_number = int((received + self.__position) / chunk_size)
480-
chunk = self.__chunks.find_one({"files_id": self._id,
481-
"n": chunk_number},
482-
session=self._session)
483-
if not chunk:
484-
raise CorruptGridFile("no chunk #%d" % chunk_number)
487+
if self.__chunk_iter is None:
488+
self.__chunk_iter = _GridOutChunkIterator(
489+
self, self.__chunks, self._session, chunk_number)
485490

491+
chunk = self.__chunk_iter.next()
486492
chunk_data = chunk["data"][self.__position % chunk_size:]
487493

488494
if not chunk_data:
@@ -501,33 +507,34 @@ def read(self, size=-1):
501507
502508
:Parameters:
503509
- `size` (optional): the number of bytes to read
510+
511+
.. versionchanged:: 3.8
512+
This method now only checks for extra chunks after reading the
513+
entire file. Previously, this method would check for extra chunks
514+
on every call.
504515
"""
505516
self._ensure_file()
506517

507-
if size == 0:
508-
return EMPTY
509-
510518
remainder = int(self.length) - self.__position
511519
if size < 0 or size > remainder:
512520
size = remainder
513521

522+
if size == 0:
523+
return EMPTY
524+
514525
received = 0
515526
data = StringIO()
516527
while received < size:
517528
chunk_data = self.readchunk()
518529
received += len(chunk_data)
519530
data.write(chunk_data)
520531

521-
# Detect extra chunks.
522-
max_chunk_n = math.ceil(self.length / float(self.chunk_size))
523-
chunk = self.__chunks.find_one({"files_id": self._id,
524-
"n": {"$gte": max_chunk_n}},
525-
session=self._session)
526-
# According to spec, ignore extra chunks if they are empty.
527-
if chunk is not None and len(chunk['data']):
528-
raise CorruptGridFile(
529-
"Extra chunk found: expected %i chunks but found "
530-
"chunk with n=%i" % (max_chunk_n, chunk['n']))
532+
# Detect extra chunks after reading the entire file.
533+
if size == remainder and self.__chunk_iter:
534+
try:
535+
self.__chunk_iter.next()
536+
except StopIteration:
537+
pass
531538

532539
self.__position -= received - size
533540

@@ -543,13 +550,13 @@ def readline(self, size=-1):
543550
:Parameters:
544551
- `size` (optional): the maximum number of bytes to read
545552
"""
546-
if size == 0:
547-
return b''
548-
549553
remainder = int(self.length) - self.__position
550554
if size < 0 or size > remainder:
551555
size = remainder
552556

557+
if size == 0:
558+
return EMPTY
559+
553560
received = 0
554561
data = StringIO()
555562
while received < size:
@@ -600,8 +607,15 @@ def seek(self, pos, whence=_SEEK_SET):
600607
if new_pos < 0:
601608
raise IOError(22, "Invalid value for `pos` - must be positive")
602609

610+
# Optimization, continue using the same buffer and chunk iterator.
611+
if new_pos == self.__position:
612+
return
613+
603614
self.__position = new_pos
604615
self.__buffer = EMPTY
616+
if self.__chunk_iter:
617+
self.__chunk_iter.close()
618+
self.__chunk_iter = None
605619

606620
def __iter__(self):
607621
"""Return an iterator over all of this file's data.
@@ -610,12 +624,20 @@ def __iter__(self):
610624
:class:`str` (:class:`bytes` in python 3). This can be
611625
useful when serving files using a webserver that handles
612626
such an iterator efficiently.
627+
628+
.. versionchanged:: 3.8
629+
The iterator now raises :class:`CorruptGridFile` when encountering
630+
any truncated, missing, or extra chunk in a file. The previous
631+
behavior was to only raise :class:`CorruptGridFile` on a missing
632+
chunk.
613633
"""
614634
return GridOutIterator(self, self.__chunks, self._session)
615635

616636
def close(self):
617637
"""Make GridOut more generically file-like."""
618-
pass
638+
if self.__chunk_iter:
639+
self.__chunk_iter.close()
640+
self.__chunk_iter = None
619641

620642
def __enter__(self):
621643
"""Makes it possible to use :class:`GridOut` files
@@ -627,30 +649,108 @@ def __exit__(self, exc_type, exc_val, exc_tb):
627649
"""Makes it possible to use :class:`GridOut` files
628650
with the context manager protocol.
629651
"""
652+
self.close()
630653
return False
631654

632655

656+
class _GridOutChunkIterator(object):
657+
"""Iterates over a file's chunks using a single cursor.
658+
659+
Raises CorruptGridFile when encountering any truncated, missing, or extra
660+
chunk in a file.
661+
"""
662+
def __init__(self, grid_out, chunks, session, next_chunk):
663+
self._id = grid_out._id
664+
self._chunk_size = int(grid_out.chunk_size)
665+
self._length = int(grid_out.length)
666+
self._chunks = chunks
667+
self._session = session
668+
self._next_chunk = next_chunk
669+
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
670+
self._cursor = None
671+
672+
def expected_chunk_length(self, chunk_n):
673+
if chunk_n < self._num_chunks - 1:
674+
return self._chunk_size
675+
return self._length - (self._chunk_size * (self._num_chunks - 1))
676+
677+
def __iter__(self):
678+
return self
679+
680+
def _create_cursor(self):
681+
filter = {"files_id": self._id}
682+
if self._next_chunk > 0:
683+
filter["n"] = {"$gte": self._next_chunk}
684+
self._cursor = self._chunks.find(filter, sort=[("n", 1)],
685+
session=self._session)
686+
687+
def _next_with_retry(self):
688+
"""Return the next chunk and retry once on CursorNotFound.
689+
690+
We retry on CursorNotFound to maintain backwards compatibility in
691+
cases where two calls to read occur more than 10 minutes apart (the
692+
server's default cursor timeout).
693+
"""
694+
if self._cursor is None:
695+
self._create_cursor()
696+
697+
try:
698+
return self._cursor.next()
699+
except CursorNotFound:
700+
self._cursor.close()
701+
self._create_cursor()
702+
return self._cursor.next()
703+
704+
def next(self):
705+
try:
706+
chunk = self._next_with_retry()
707+
except StopIteration:
708+
if self._next_chunk >= self._num_chunks:
709+
raise
710+
raise CorruptGridFile("no chunk #%d" % self._next_chunk)
711+
712+
if chunk["n"] != self._next_chunk:
713+
self.close()
714+
raise CorruptGridFile(
715+
"Missing chunk: expected chunk #%d but found "
716+
"chunk with n=%d" % (self._next_chunk, chunk["n"]))
717+
718+
if chunk["n"] >= self._num_chunks:
719+
# According to spec, ignore extra chunks if they are empty.
720+
if len(chunk["data"]):
721+
self.close()
722+
raise CorruptGridFile(
723+
"Extra chunk found: expected %d chunks but found "
724+
"chunk with n=%d" % (self._num_chunks, chunk["n"]))
725+
726+
expected_length = self.expected_chunk_length(chunk["n"])
727+
if len(chunk["data"]) != expected_length:
728+
self.close()
729+
raise CorruptGridFile(
730+
"truncated chunk #%d: expected chunk length to be %d but "
731+
"found chunk with length %d" % (
732+
chunk["n"], expected_length, len(chunk["data"])))
733+
734+
self._next_chunk += 1
735+
return chunk
736+
737+
__next__ = next
738+
739+
def close(self):
740+
if self._cursor:
741+
self._cursor.close()
742+
self._cursor = None
743+
744+
633745
class GridOutIterator(object):
634746
def __init__(self, grid_out, chunks, session):
635-
self.__id = grid_out._id
636-
self.__chunks = chunks
637-
self.__session = session
638-
self.__current_chunk = 0
639-
self.__max_chunk = math.ceil(float(grid_out.length) /
640-
grid_out.chunk_size)
747+
self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0)
641748

642749
def __iter__(self):
643750
return self
644751

645752
def next(self):
646-
if self.__current_chunk >= self.__max_chunk:
647-
raise StopIteration
648-
chunk = self.__chunks.find_one({"files_id": self.__id,
649-
"n": self.__current_chunk},
650-
session=self.__session)
651-
if not chunk:
652-
raise CorruptGridFile("no chunk #%d" % self.__current_chunk)
653-
self.__current_chunk += 1
753+
chunk = self.__chunk_iter.next()
654754
return bytes(chunk["data"])
655755

656756
__next__ = next

test/test_grid_file.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@
3333
from gridfs.errors import NoFile
3434
from pymongo import MongoClient
3535
from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError
36+
from pymongo.message import _CursorAddress
3637
from test import (IntegrationTest,
3738
unittest,
3839
qcheck)
39-
from test.utils import rs_or_single_client
40+
from test.utils import rs_or_single_client, EventListener
4041

4142

4243
class TestGridFileNoConnect(unittest.TestCase):
@@ -616,6 +617,33 @@ def test_unacknowledged(self):
616617
with self.assertRaises(ConfigurationError):
617618
GridIn(rs_or_single_client(w=0).pymongo_test.fs)
618619

620+
def test_survive_cursor_not_found(self):
621+
# By default the find command returns 101 documents in the first batch.
622+
# Use 102 chunks to cause a single getMore.
623+
chunk_size = 1024
624+
data = b'd' * (102 * chunk_size)
625+
listener = EventListener()
626+
client = rs_or_single_client(event_listeners=[listener])
627+
db = client.pymongo_test
628+
with GridIn(db.fs, chunk_size=chunk_size) as infile:
629+
infile.write(data)
630+
631+
with GridOut(db.fs, infile._id) as outfile:
632+
self.assertEqual(len(outfile.readchunk()), chunk_size)
633+
634+
# Kill the cursor to simulate the cursor timing out on the server
635+
# when an application spends a long time between two calls to
636+
# readchunk().
637+
client._close_cursor_now(
638+
outfile._GridOut__chunk_iter._cursor.cursor_id,
639+
_CursorAddress(client.address, db.fs.chunks.full_name))
640+
641+
# Read the rest of the file without error.
642+
self.assertEqual(len(outfile.read()), len(data) - chunk_size)
643+
644+
# Paranoid, ensure that a getMore was actually sent.
645+
self.assertIn("getMore", listener.started_command_names())
646+
619647

620648
if __name__ == "__main__":
621649
unittest.main()

test/test_gridfs_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ def run_scenario(self):
163163

164164
if test['assert'].get("error", False):
165165
self.assertIsNotNone(error)
166-
self.assertTrue(isinstance(error,
167-
errors[test['assert']['error']]))
166+
self.assertIsInstance(error, errors[test['assert']['error']],
167+
test['description'])
168168
else:
169169
self.assertIsNone(error)
170170

test/test_session.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,8 +571,11 @@ def test_gridfsbucket_cursor(self):
571571
for f in files:
572572
f.read()
573573

574-
with self.assertRaisesRegex(InvalidOperation, "ended session"):
575-
files[0].read()
574+
for f in files:
575+
# Attempt to read the file again.
576+
f.seek(0)
577+
with self.assertRaisesRegex(InvalidOperation, "ended session"):
578+
f.read()
576579

577580
def test_aggregate(self):
578581
client = self.client

0 commit comments

Comments
 (0)