Skip to content

Commit 3716b44

Browse files
committed
PYTHON-871 - Fix encoding of defaultdict.
1 parent 5df17c2 commit 3716b44

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

bson/_cbsonmodule.c

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,17 +1424,34 @@ int write_dict(PyObject* self, buffer_t buffer,
14241424
}
14251425

14261426
/* Write _id first if this is a top level doc. */
1427-
if (top_level && PyMapping_HasKeyString(dict, "_id")) {
1428-
PyObject* _id = PyMapping_GetItemString(dict, "_id");
1429-
if (!_id) {
1430-
return 0;
1431-
}
1432-
if (!write_pair(self, buffer, "_id", 3,
1433-
_id, check_keys, options, 1)) {
1427+
if (top_level) {
1428+
/*
1429+
* If "dict" is a defaultdict we don't want to call
1430+
* PyMapping_GetItemString on it. That would **create**
1431+
* an _id where one didn't previously exist (PYTHON-871).
1432+
*/
1433+
if (PyDict_Check(dict)) {
1434+
/* PyDict_GetItemString returns a borrowed reference. */
1435+
PyObject* _id = PyDict_GetItemString(dict, "_id");
1436+
if (_id) {
1437+
if (!write_pair(self, buffer, "_id", 3,
1438+
_id, check_keys, options, 1)) {
1439+
return 0;
1440+
}
1441+
}
1442+
} else if (PyMapping_HasKeyString(dict, "_id")) {
1443+
PyObject* _id = PyMapping_GetItemString(dict, "_id");
1444+
if (!_id) {
1445+
return 0;
1446+
}
1447+
if (!write_pair(self, buffer, "_id", 3,
1448+
_id, check_keys, options, 1)) {
1449+
Py_DECREF(_id);
1450+
return 0;
1451+
}
1452+
/* PyMapping_GetItemString returns a new reference. */
14341453
Py_DECREF(_id);
1435-
return 0;
14361454
}
1437-
Py_DECREF(_id);
14381455
}
14391456

14401457
iter = PyObject_GetIter(dict);

test/test_bson.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def test_encode_then_decode(self):
143143
def test_encode_then_decode_any_mapping(self):
144144
self.check_encode_then_decode(doc_class=NotADict)
145145

146+
def test_encoding_defaultdict(self):
147+
dct = collections.defaultdict(dict, [('foo', 'bar')])
148+
BSON.encode(dct)
149+
self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')]))
150+
146151
def test_basic_validation(self):
147152
self.assertRaises(TypeError, is_valid, 100)
148153
self.assertRaises(TypeError, is_valid, u("test"))

test/test_collection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import sys
2121
import threading
2222

23+
from collections import defaultdict
24+
2325
sys.path[0:0] = [""]
2426

2527
from bson.regex import Regex
@@ -657,6 +659,13 @@ def test_delete_many(self):
657659
self.assertFalse(result.acknowledged)
658660
wait_until(lambda: 0 == db.test.count(), 'delete 2 documents')
659661

662+
def test_find_by_default_dct(self):
663+
db = self.db
664+
db.test.insert_one({'foo': 'bar'})
665+
dct = defaultdict(dict, [('foo', 'bar')])
666+
self.assertIsNotNone(db.test.find_one(dct))
667+
self.assertEqual(dct, defaultdict(dict, [('foo', 'bar')]))
668+
660669
def test_find_w_fields(self):
661670
db = self.db
662671
db.test.delete_many({})

0 commit comments

Comments
 (0)