Skip to content

Commit fad1c38

Browse files
committed
Check keys in Collection.save PYTHON-245.
1 parent 0203182 commit fad1c38

4 files changed

Lines changed: 20 additions & 8 deletions

File tree

pymongo/_cmessagemodule.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,18 +222,19 @@ static PyObject* _cbson_update_message(PyObject* self, PyObject* args) {
222222
unsigned char multi;
223223
unsigned char upsert;
224224
unsigned char safe;
225+
unsigned char check_keys;
225226
PyObject* last_error_args;
226227
int options;
227228
buffer_t buffer;
228229
int length_location, message_length;
229230
PyObject* result;
230231

231-
if (!PyArg_ParseTuple(args, "et#bbOObO",
232+
if (!PyArg_ParseTuple(args, "et#bbOObbO",
232233
"utf-8",
233234
&collection_name,
234235
&collection_name_length,
235236
&upsert, &multi, &spec, &doc, &safe,
236-
&last_error_args)) {
237+
&check_keys, &last_error_args)) {
237238
return NULL;
238239
}
239240

@@ -282,7 +283,7 @@ static PyObject* _cbson_update_message(PyObject* self, PyObject* args) {
282283
max_size = buffer_get_position(buffer) - before;
283284

284285
before = buffer_get_position(buffer);
285-
if (!write_dict(buffer, doc, 0, 1)) {
286+
if (!write_dict(buffer, doc, check_keys, 1)) {
286287
buffer_free(buffer);
287288
PyMem_Free(collection_name);
288289
return NULL;

pymongo/collection.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def save(self, to_save, manipulate=True, safe=False, **kwargs):
218218
return self.insert(to_save, manipulate, safe, **kwargs)
219219
else:
220220
self.update({"_id": to_save["_id"]}, to_save, True,
221-
manipulate, safe, **kwargs)
221+
manipulate, safe, _check_keys=True, **kwargs)
222222
return to_save.get("_id", None)
223223

224224
def insert(self, doc_or_docs, manipulate=True,
@@ -298,7 +298,7 @@ def insert(self, doc_or_docs, manipulate=True,
298298
return return_one and ids[0] or ids
299299

300300
def update(self, spec, document, upsert=False, manipulate=False,
301-
safe=False, multi=False, **kwargs):
301+
safe=False, multi=False, _check_keys=False, **kwargs):
302302
"""Update a document(s) in this collection.
303303
304304
Raises :class:`TypeError` if either `spec` or `document` is
@@ -385,9 +385,12 @@ def update(self, spec, document, upsert=False, manipulate=False,
385385
if not kwargs:
386386
kwargs.update(self.get_lasterror_options())
387387

388+
# _check_keys is used by save() so we don't upsert pre-existing
389+
# documents after adding an invalid key like 'a.b'. It can't really
390+
# be used for any other update operations.
388391
return self.__database.connection._send_message(
389392
message.update(self.__full_name, upsert, multi,
390-
spec, document, safe, kwargs), safe)
393+
spec, document, safe, _check_keys, kwargs), safe)
391394

392395
def drop(self):
393396
"""Alias for :meth:`~pymongo.database.Database.drop_collection`.

pymongo/message.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def insert(collection_name, docs, check_keys,
8888
insert = _cmessage._insert_message
8989

9090

91-
def update(collection_name, upsert, multi, spec, doc, safe, last_error_args):
91+
def update(collection_name, upsert, multi,
92+
spec, doc, safe, check_keys, last_error_args):
9293
"""Get an **update** message.
9394
"""
9495
options = 0
@@ -101,7 +102,7 @@ def update(collection_name, upsert, multi, spec, doc, safe, last_error_args):
101102
data += bson._make_c_string(collection_name)
102103
data += struct.pack("<i", options)
103104
data += bson.BSON.encode(spec)
104-
encoded = bson.BSON.encode(doc)
105+
encoded = bson.BSON.encode(doc, check_keys)
105106
data += encoded
106107
if safe:
107108
(_, update_message) = __pack_message(2001, data)

test/test_collection.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,13 @@ def test_save(self):
504504
self.assertEqual(self.db.test.find_one()["_id"], id)
505505
self.assert_(isinstance(id, ObjectId))
506506

507+
def test_save_with_invalid_key(self):
508+
self.db.drop_collection("test")
509+
self.assert_(self.db.test.insert({"hello": "world"}))
510+
doc = self.db.test.find_one()
511+
doc['a.b'] = 'c'
512+
self.assertRaises(InvalidDocument, self.db.test.save, doc)
513+
507514
def test_unique_index(self):
508515
db = self.db
509516

0 commit comments

Comments
 (0)