Skip to content

Commit b4cb9be

Browse files
author
Mike Dirolf
committed
Add database support to DBRefs (as $db)
1 parent 9bbeea4 commit b4cb9be

File tree

10 files changed

+105
-136
lines changed

10 files changed

+105
-136
lines changed

pymongo/_cbsonmodule.c

Lines changed: 19 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -602,86 +602,16 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
602602
}
603603
return 1;
604604
} else if (PyObject_IsInstance(value, DBRef)) {
605-
int start_position,
606-
length_location,
607-
collection_length,
608-
type_pos,
609-
length;
610-
PyObject* collection_object;
611-
PyObject* encoded_collection;
612-
PyObject* id_object;
613-
char zero = 0;
614-
615-
*(buffer->buffer + type_byte) = 0x03;
616-
start_position = buffer->position;
617-
618-
/* save space for length */
619-
length_location = buffer_save_bytes(buffer, 4);
620-
if (length_location == -1) {
605+
PyObject* as_doc = PyObject_CallMethod(value, "as_doc", NULL);
606+
if (!as_doc) {
621607
return 0;
622608
}
623-
624-
collection_object = PyObject_GetAttrString(value, "collection");
625-
if (!collection_object) {
609+
if (!write_dict(buffer, as_doc, 0)) {
610+
Py_DECREF(as_doc);
626611
return 0;
627612
}
628-
encoded_collection = PyUnicode_AsUTF8String(collection_object);
629-
Py_DECREF(collection_object);
630-
if (!encoded_collection) {
631-
return 0;
632-
}
633-
{
634-
const char* collection = PyString_AsString(encoded_collection);
635-
if (!collection) {
636-
Py_DECREF(encoded_collection);
637-
return 0;
638-
}
639-
id_object = PyObject_GetAttrString(value, "id");
640-
if (!id_object) {
641-
Py_DECREF(encoded_collection);
642-
return 0;
643-
}
644-
645-
if (!buffer_write_bytes(buffer, "\x02$ref\x00", 6)) {
646-
Py_DECREF(encoded_collection);
647-
Py_DECREF(id_object);
648-
return 0;
649-
}
650-
collection_length = strlen(collection) + 1;
651-
if (!buffer_write_bytes(buffer, (const char*)&collection_length, 4)) {
652-
Py_DECREF(encoded_collection);
653-
Py_DECREF(id_object);
654-
return 0;
655-
}
656-
if (!buffer_write_bytes(buffer, collection, collection_length)) {
657-
Py_DECREF(encoded_collection);
658-
Py_DECREF(id_object);
659-
return 0;
660-
}
661-
}
662-
Py_DECREF(encoded_collection);
663-
664-
type_pos = buffer_save_bytes(buffer, 1);
665-
if (type_pos == -1) {
666-
Py_DECREF(id_object);
667-
return 0;
668-
}
669-
if (!buffer_write_bytes(buffer, "$id\x00", 4)) {
670-
Py_DECREF(id_object);
671-
return 0;
672-
}
673-
if (!write_element_to_buffer(buffer, type_pos, id_object, check_keys)) {
674-
Py_DECREF(id_object);
675-
return 0;
676-
}
677-
Py_DECREF(id_object);
678-
679-
/* write null byte and fill in length */
680-
if (!buffer_write_bytes(buffer, &zero, 1)) {
681-
return 0;
682-
}
683-
length = buffer->position - start_position;
684-
memcpy(buffer->buffer + length_location, &length, 4);
613+
Py_DECREF(as_doc);
614+
*(buffer->buffer + type_byte) = 0x03;
685615
return 1;
686616
}
687617
else if (PyObject_HasAttrString(value, "pattern") &&
@@ -1043,33 +973,22 @@ static PyObject* get_value(const char* buffer, int* position, int type) {
1043973
{
1044974
int size;
1045975
memcpy(&size, buffer + *position, 4);
976+
value = elements_to_dict(buffer + *position + 4, size - 5);
977+
if (!value) {
978+
return NULL;
979+
}
980+
981+
/* Decoding for DBRefs */
1046982
if (strcmp(buffer + *position + 5, "$ref") == 0) { /* DBRef */
1047-
char id_type;
1048-
PyObject* id;
983+
PyObject* id = PyDict_GetItemString(value, "$id");
984+
PyObject* collection = PyDict_GetItemString(value, "$ref");
985+
PyObject* database = PyDict_GetItemString(value, "$db");
1049986

1050-
int offset = *position + 14;
1051-
int collection_length = strlen(buffer + offset);
1052-
PyObject* collection = PyUnicode_DecodeUTF8(buffer + offset, collection_length, "strict");
1053-
if (!collection) {
1054-
return NULL;
1055-
}
1056-
offset += collection_length + 1;
1057-
id_type = buffer[offset];
1058-
offset += 5;
1059-
id = get_value(buffer, &offset, (int)id_type);
1060-
if (!id) {
1061-
Py_DECREF(collection);
1062-
return NULL;
1063-
}
1064-
value = PyObject_CallFunctionObjArgs(DBRef, collection, id, NULL);
1065-
Py_DECREF(collection);
1066-
Py_DECREF(id);
1067-
} else {
1068-
value = elements_to_dict(buffer + *position + 4, size - 5);
1069-
if (!value) {
1070-
return NULL;
1071-
}
987+
/* This works even if there is no $db since database will be NULL and
988+
the call will be as if there were only two arguments specified. */
989+
value = PyObject_CallFunctionObjArgs(DBRef, collection, id, database, NULL);
1072990
}
991+
1073992
*position += size;
1074993
break;
1075994
}

pymongo/bson.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _get_string(data):
236236
def _get_object(data):
237237
(object, data) = _bson_to_dict(data)
238238
if "$ref" in object:
239-
return (DBRef(object["$ref"], object["$id"]), data)
239+
return (DBRef(object["$ref"], object["$id"], object.get("$db", None)), data)
240240
return (object, data)
241241

242242

@@ -456,10 +456,8 @@ def _element_to_bson(key, value, check_keys):
456456
flags += "x"
457457
return "\x0B" + name + _make_c_string(pattern) + _make_c_string(flags)
458458
if isinstance(value, DBRef):
459-
return _element_to_bson(key,
460-
SON([("$ref", value.collection),
461-
("$id", value.id)]),
462-
False)
459+
return _element_to_bson(key, value.as_doc(), False)
460+
463461
raise InvalidDocument("cannot convert value of type %s to bson" %
464462
type(value))
465463

pymongo/database.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,14 +377,20 @@ def logout(self):
377377
def dereference(self, dbref):
378378
"""Dereference a DBRef, getting the SON object it points to.
379379
380-
Raises TypeError if dbref is not an instance of DBRef. Returns a SON
381-
object or None if the reference does not point to a valid object.
380+
Raises TypeError if `dbref` is not an instance of DBRef. Returns a SON
381+
object or None if the reference does not point to a valid object. Raises
382+
ValueError if `dbref` has a database specified that is different from
383+
the current database.
382384
383385
:Parameters:
384386
- `dbref`: the reference
385387
"""
386388
if not isinstance(dbref, DBRef):
387389
raise TypeError("cannot dereference a %s" % type(dbref))
390+
if dbref.database is not None and dbref.database != self.__name:
391+
raise ValueError("trying to dereference a DBRef that points to "
392+
"another database (%r not %r)" % (dbref.database,
393+
self.__name))
388394
return self[dbref.collection].find_one({"_id": dbref.id})
389395

390396
def eval(self, code, *args):

pymongo/dbref.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,40 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Tools for manipulating DBRefs (references to Mongo objects)."""
15+
"""Tools for manipulating DBRefs (references to MongoDB documents)."""
1616

1717
import types
1818

19+
from son import SON
20+
1921

2022
class DBRef(object):
21-
"""A reference to an object stored in a Mongo database.
23+
"""A reference to a document stored in a Mongo database.
2224
"""
2325

24-
def __init__(self, collection, id):
26+
def __init__(self, collection, id, database=None):
2527
"""Initialize a new DBRef.
2628
27-
Raises TypeError if collection is not an instance of (str, unicode).
29+
Raises TypeError if collection or database is not an instance of
30+
(str, unicode). `database` is optional and allows references to
31+
documents to work across databases.
2832
2933
:Parameters:
30-
- `collection`: the collection the object is stored in
31-
- `id`: the value of the object's _id field
34+
- `collection`: name of the collection the document is stored in
35+
- `id`: the value of the document's _id field
36+
- `database` (optional): name of the database to reference
3237
"""
3338
if not isinstance(collection, types.StringTypes):
3439
raise TypeError("collection must be an instance of (str, unicode)")
35-
36-
if isinstance(collection, types.StringType):
37-
collection = unicode(collection, "utf-8")
40+
if not isinstance(database, (types.StringTypes, types.NoneType)):
41+
raise TypeError("database must be an instance of (str, unicode)")
3842

3943
self.__collection = collection
4044
self.__id = id
45+
self.__database = database
4146

4247
def collection(self):
43-
"""Get this DBRef's collection as unicode.
48+
"""Get the name of this DBRef's collection as unicode.
4449
"""
4550
return self.__collection
4651
collection = property(collection)
@@ -51,14 +56,35 @@ def id(self):
5156
return self.__id
5257
id = property(id)
5358

59+
def database(self):
60+
"""Get the name of this DBRef's database.
61+
62+
Returns None if this DBRef doesn't specify a database.
63+
"""
64+
return self.__database
65+
database = property(database)
66+
67+
def as_doc(self):
68+
"""Get the SON document representation of this DBRef.
69+
70+
Generally not needed by application developers
71+
"""
72+
doc = SON([("$ref", self.collection),
73+
("$id", self.id)])
74+
if self.database is not None:
75+
doc["$db"] = self.database
76+
return doc
77+
5478
def __repr__(self):
55-
return "DBRef(" + repr(self.collection) + ", " + repr(self.id) + ")"
79+
if self.database is None:
80+
return "DBRef(%r, %r)" % (self.collection, self.id)
81+
return "DBRef(%r, %r, %r)" % (self.collection, self.id, self.database)
5682

5783
def __cmp__(self, other):
5884
if isinstance(other, DBRef):
59-
return cmp([self.__collection, self.__id],
60-
[other.__collection, other.__id])
85+
return cmp([self.__database, self.__collection, self.__id],
86+
[other.__database, other.__collection, other.__id])
6187
return NotImplemented
62-
88+
6389
def __hash__(self):
64-
return hash((self.__collection, self.__id))
90+
return hash((self.__collection, self.__id, self.__database))

pymongo/son.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,6 @@
2424
import base64
2525
import types
2626

27-
try:
28-
import xml.etree.ElementTree as ET
29-
except ImportError:
30-
import elementtree.ElementTree as ET
31-
32-
from code import Code
33-
from binary import Binary
34-
from objectid import ObjectId
35-
from dbref import DBRef
36-
from errors import UnsupportedTag
37-
3827

3928
class SON(dict):
4029
"""SON data.
@@ -223,7 +212,19 @@ def transform_value(value):
223212

224213
def from_xml(cls, xml):
225214
"""Create an instance of SON from an xml document.
215+
216+
This is really only used for testing, and is probably unnecessary.
226217
"""
218+
try:
219+
import xml.etree.ElementTree as ET
220+
except ImportError:
221+
import elementtree.ElementTree as ET
222+
223+
from code import Code
224+
from binary import Binary
225+
from objectid import ObjectId
226+
from dbref import DBRef
227+
from errors import UnsupportedTag
227228

228229
def pad(list, index):
229230
while index >= len(list):

pymongo/son_manipulator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ class AutoReference(SONManipulator):
117117
only be auto-referenced if they have an *_ns* field.
118118
119119
NOTE: this will behave poorly if you have a circular reference.
120+
121+
TODO: this only works for documents that are in the same database. To fix
122+
this we'll need to add a DatabaseInjector that adds *_db* and then make
123+
use of the optional *database* support for DBRefs.
120124
"""
121125

122126
def __init__(self, db):

test/test_bson.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def helper(dict):
166166
helper({"another binary": Binary("test")})
167167
helper(SON([(u'test dst', datetime.datetime(1993, 4, 4, 2))]))
168168
helper({"big float": float(10000000000)})
169+
helper({"ref": DBRef("coll", 5)})
170+
helper({"ref": DBRef("coll", 5, "foo")})
169171

170172
def from_then_to_dict(dict):
171173
return dict == (BSON.from_dict(dict)).to_dict()

test/test_database.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ def test_deref(self):
271271
obj = {"x": True}
272272
key = db.test.save(obj)
273273
self.assertEqual(obj, db.dereference(DBRef("test", key)))
274+
self.assertEqual(obj, db.dereference(DBRef("test", key, "pymongo_test")))
275+
self.assertRaises(ValueError, db.dereference, DBRef("test", key, "foo"))
274276

275277
self.assertEqual(None, db.dereference(DBRef("test", 4)))
276278
obj = {"_id": 4}

test/test_dbref.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ def test_creation(self):
3636
self.assertRaises(TypeError, DBRef, 1.5, a)
3737
self.assertRaises(TypeError, DBRef, a, a)
3838
self.assertRaises(TypeError, DBRef, None, a)
39+
self.assertRaises(TypeError, DBRef, "coll", a, 5)
3940
self.assert_(DBRef("coll", a))
4041
self.assert_(DBRef(u"coll", a))
4142
self.assert_(DBRef(u"coll", 5))
43+
self.assert_(DBRef(u"coll", 5, "database"))
4244

4345
def test_read_only(self):
4446
a = DBRef("coll", ObjectId())
@@ -49,25 +51,35 @@ def foo():
4951
def bar():
5052
a.id = "aoeu"
5153

52-
a.collection
54+
self.assertEqual("coll", a.collection)
5355
a.id
56+
self.assertEqual(None, a.database)
5457
self.assertRaises(AttributeError, foo)
5558
self.assertRaises(AttributeError, bar)
5659

5760
def test_repr(self):
5861
self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"))),
59-
"DBRef(u'coll', ObjectId('1234567890abcdef12345678'))")
62+
"DBRef('coll', ObjectId('1234567890abcdef12345678'))")
6063
self.assertEqual(repr(DBRef(u"coll", ObjectId("1234567890abcdef12345678"))),
6164
"DBRef(u'coll', ObjectId('1234567890abcdef12345678'))")
65+
self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo")),
66+
"DBRef('coll', ObjectId('1234567890abcdef12345678'), 'foo')")
6267

6368
def test_cmp(self):
6469
self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
6570
DBRef(u"coll", ObjectId("1234567890abcdef12345678")))
71+
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
72+
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "foo"))
6673
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
6774
DBRef("col", ObjectId("1234567890abcdef12345678")))
6875
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
6976
DBRef("coll", ObjectId("123456789011")))
7077
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")), 4)
78+
self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo"),
79+
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "foo"))
80+
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo"),
81+
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "bar"))
82+
7183

7284
if __name__ == "__main__":
7385
unittest.main()

0 commit comments

Comments
 (0)