Skip to content

Commit b9baa8a

Browse files
committed
PYTHON-960 - GridFS spec compliance
1 parent 8b986a4 commit b9baa8a

File tree

10 files changed

+2721
-107
lines changed

10 files changed

+2721
-107
lines changed

gridfs/__init__.py

Lines changed: 375 additions & 39 deletions
Large diffs are not rendered by default.

gridfs/grid_file.py

Lines changed: 50 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -13,25 +13,22 @@
1313
# limitations under the License.
1414

1515
"""Tools for representing files stored in GridFS."""
16-
1716
import datetime
1817
import math
1918
import os
2019

20+
from hashlib import md5
21+
2122
from bson.binary import Binary
2223
from bson.objectid import ObjectId
2324
from bson.py3compat import text_type, StringIO
24-
from gridfs.errors import (CorruptGridFile,
25-
FileExists,
26-
NoFile)
25+
from gridfs.errors import CorruptGridFile, FileExists, NoFile
2726
from pymongo import ASCENDING
2827
from pymongo.collection import Collection
29-
from pymongo.common import UNAUTHORIZED_CODES
3028
from pymongo.cursor import Cursor
3129
from pymongo.errors import (ConfigurationError,
3230
DuplicateKeyError,
3331
OperationFailure)
34-
from pymongo.read_preferences import ReadPreference
3532

3633
try:
3734
_SEEK_SET = os.SEEK_SET
@@ -50,6 +47,9 @@
5047
# Slightly under a power of 2, to work well with server's record allocations.
5148
DEFAULT_CHUNK_SIZE = 255 * 1024
5249

50+
_C_INDEX = [("files_id", ASCENDING), ("n", ASCENDING)]
51+
_F_INDEX = [("filename", ASCENDING), ("uploadDate", ASCENDING)]
52+
5353

5454
def _grid_in_property(field_name, docstring, read_only=False,
5555
closed_only=False):
@@ -155,6 +155,7 @@ def __init__(self, root_collection, **kwargs):
155155
if "chunk_size" in kwargs:
156156
kwargs["chunkSize"] = kwargs.pop("chunk_size")
157157

158+
kwargs['md5'] = md5()
158159
# Defaults
159160
kwargs["_id"] = kwargs.get("_id", ObjectId())
160161
kwargs["chunkSize"] = kwargs.get("chunkSize", DEFAULT_CHUNK_SIZE)
@@ -167,18 +168,30 @@ def __init__(self, root_collection, **kwargs):
167168
object.__setattr__(self, "_closed", False)
168169
object.__setattr__(self, "_ensured_index", False)
169170

170-
def _ensure_index(self):
171-
if not object.__getattribute__(self, "_ensured_index"):
171+
def __create_index(self, collection, index, unique):
172+
doc = collection.find_one(projection={"_id": 1})
173+
if doc is None:
172174
try:
173-
self._coll.chunks.create_index(
174-
[("files_id", ASCENDING), ("n", ASCENDING)],
175-
unique=True)
176-
except OperationFailure as exc:
177-
if not (exc.code in UNAUTHORIZED_CODES
178-
or "authorized" in str(exc)):
179-
raise exc
175+
indexes = list(collection.list_indexes())
176+
except OperationFailure:
177+
indexes = []
178+
if index not in indexes:
179+
collection.create_index(index, unique=unique)
180+
181+
def __ensure_indexes(self):
182+
if not object.__getattribute__(self, "_ensured_index"):
183+
self.__create_index(self._coll.files, _F_INDEX, False)
184+
self.__create_index(self._coll.chunks, _C_INDEX, True)
180185
object.__setattr__(self, "_ensured_index", True)
181186

187+
def abort(self):
188+
"""Remove all chunks/files that may have been uploaded and close.
189+
"""
190+
self._coll.chunks.delete_many({"files_id": self._file['_id']})
191+
self._coll.files.delete_one({"_id": self._file['_id']})
192+
object.__setattr__(self, "_closed", True)
193+
194+
182195
@property
183196
def closed(self):
184197
"""Is this file closed?
@@ -225,7 +238,8 @@ def __flush_data(self, data):
225238
"""
226239
# Ensure the index, even if there's nothing to write, so
227240
# the filemd5 command always succeeds.
228-
self._ensure_index()
241+
self.__ensure_indexes()
242+
self._file['md5'].update(data)
229243

230244
if not data:
231245
return
@@ -255,12 +269,7 @@ def __flush(self):
255269
try:
256270
self.__flush_buffer()
257271

258-
db = self._coll.database
259-
md5 = db.command(
260-
"filemd5", self._id, root=self._coll.name,
261-
read_preference=ReadPreference.PRIMARY)["md5"]
262-
263-
self._file["md5"] = md5
272+
self._file['md5'] = self._file["md5"].hexdigest()
264273
self._file["length"] = self._position
265274
self._file["uploadDate"] = datetime.datetime.utcnow()
266275

@@ -326,10 +335,14 @@ def write(self, data):
326335
# Make sure to flush only when _buffer is complete
327336
space = self.chunk_size - self._buffer.tell()
328337
if space:
329-
to_write = read(space)
338+
try:
339+
to_write = read(space)
340+
except:
341+
self.abort()
342+
raise
330343
self._buffer.write(to_write)
331344
if len(to_write) < space:
332-
return # EOF or incomplete
345+
return # EOF or incomplete
333346
self.__flush_buffer()
334347
to_write = read(self.chunk_size)
335348
while to_write and len(to_write) == self.chunk_size:
@@ -475,6 +488,16 @@ def read(self, size=-1):
475488
received += len(chunk_data)
476489
data.write(chunk_data)
477490

491+
# Detect extra chunks.
492+
max_chunk_n = math.ceil(self.length / float(self.chunk_size))
493+
chunk = self.__chunks.find_one({"files_id": self._id,
494+
"n": {"$gte": max_chunk_n}})
495+
# According to spec, ignore extra chunks if they are empty.
496+
if chunk is not None and len(chunk['data']):
497+
raise CorruptGridFile(
498+
"Extra chunk found: expected %i chunks but found "
499+
"chunk with n=%i" % (max_chunk_n, chunk['n']))
500+
478501
self.__position -= received - size
479502

480503
# Return 'size' bytes and store the rest.
@@ -605,7 +628,7 @@ class GridOutCursor(Cursor):
605628
of an arbitrary query against the GridFS files collection.
606629
"""
607630
def __init__(self, collection, filter=None, skip=0, limit=0,
608-
no_cursor_timeout=False, sort=None):
631+
no_cursor_timeout=False, sort=None, batch_size=0):
609632
"""Create a new cursor, similar to the normal
610633
:class:`~pymongo.cursor.Cursor`.
611634
@@ -621,7 +644,8 @@ def __init__(self, collection, filter=None, skip=0, limit=0,
621644

622645
super(GridOutCursor, self).__init__(
623646
collection.files, filter, skip=skip, limit=limit,
624-
no_cursor_timeout=no_cursor_timeout, sort=sort)
647+
no_cursor_timeout=no_cursor_timeout, sort=sort,
648+
batch_size=batch_size)
625649

626650
def next(self):
627651
"""Get next GridOut object from cursor.

0 commit comments

Comments
 (0)