@@ -363,6 +363,15 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
363363 * (buffer -> buffer + type_byte ) = 0x07 ;
364364 return 1 ;
365365 } else if (PyObject_IsInstance (value , DBRef )) {
366+ * (buffer -> buffer + type_byte ) = 0x03 ;
367+ int start_position = buffer -> position ;
368+
369+ // save space for length
370+ int length_location = buffer_save_bytes (buffer , 4 );
371+ if (length_location == -1 ) {
372+ return 0 ;
373+ }
374+
366375 PyObject * collection_object = PyObject_GetAttrString (value , "collection" );
367376 if (!collection_object ) {
368377 return 0 ;
@@ -382,36 +391,47 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
382391 Py_DECREF (encoded_collection );
383392 return 0 ;
384393 }
385- PyObject * id_str = PyObject_Str (id_object );
386- Py_DECREF (id_object );
387- if (!id_str ) {
388- Py_DECREF (encoded_collection );
389- return 0 ;
390- }
391- const char * id = PyString_AsString (id_str );
392- if (!id ) {
394+
395+ if (!buffer_write_bytes (buffer , "\x02$ref\x00" , 6 )) {
393396 Py_DECREF (encoded_collection );
394- Py_DECREF (id_str );
397+ Py_DECREF (id_object );
395398 return 0 ;
396399 }
397400 int collection_length = strlen (collection ) + 1 ;
398401 if (!buffer_write_bytes (buffer , (const char * )& collection_length , 4 )) {
399402 Py_DECREF (encoded_collection );
400- Py_DECREF (id_str );
403+ Py_DECREF (id_object );
401404 return 0 ;
402405 }
403406 if (!buffer_write_bytes (buffer , collection , collection_length )) {
404407 Py_DECREF (encoded_collection );
405- Py_DECREF (id_str );
408+ Py_DECREF (id_object );
406409 return 0 ;
407410 }
408411 Py_DECREF (encoded_collection );
409- if (!write_shuffled_oid (buffer , id )) {
410- Py_DECREF (id_str );
412+
413+ int type_pos = buffer_save_bytes (buffer , 1 );
414+ if (type_pos == -1 ) {
415+ Py_DECREF (id_object );
416+ return 0 ;
417+ }
418+ if (!buffer_write_bytes (buffer , "$id\x00" , 4 )) {
419+ Py_DECREF (id_object );
420+ return 0 ;
421+ }
422+ if (!write_element_to_buffer (buffer , type_pos , id_object )) {
423+ Py_DECREF (id_object );
411424 return 0 ;
412425 }
413- Py_DECREF (id_str );
414- * (buffer -> buffer + type_byte ) = 0x0C ;
426+ Py_DECREF (id_object );
427+
428+ // write null byte and fill in length
429+ char zero = 0 ;
430+ if (!buffer_write_bytes (buffer , & zero , 1 )) {
431+ return 0 ;
432+ }
433+ int length = buffer -> position - start_position ;
434+ memcpy (buffer -> buffer + length_location , & length , 4 );
415435 return 1 ;
416436 }
417437 else if (PyObject_HasAttrString (value , "pattern" ) &&
@@ -626,9 +646,29 @@ static PyObject* get_value(const char* buffer, int* position, int type) {
626646 {
627647 int size ;
628648 memcpy (& size , buffer + * position , 4 );
629- value = elements_to_dict (buffer + * position + 4 , size - 5 );
630- if (!value ) {
631- return NULL ;
649+ if (strcmp (buffer + * position + 5 , "$ref" ) == 0 ) { // DBRef
650+ int offset = * position + 14 ;
651+ int collection_length = strlen (buffer + offset );
652+ PyObject * collection = PyUnicode_DecodeUTF8 (buffer + offset , collection_length , "strict" );
653+ if (!collection ) {
654+ return NULL ;
655+ }
656+ offset += collection_length + 1 ;
657+ char id_type = buffer [offset ];
658+ offset += 5 ;
659+ PyObject * id = get_value (buffer , & offset , (int )id_type );
660+ if (!id ) {
661+ Py_DECREF (collection );
662+ return NULL ;
663+ }
664+ value = PyObject_CallFunctionObjArgs (DBRef , collection , id , NULL );
665+ Py_DECREF (collection );
666+ Py_DECREF (id );
667+ } else {
668+ value = elements_to_dict (buffer + * position + 4 , size - 5 );
669+ if (!value ) {
670+ return NULL ;
671+ }
632672 }
633673 * position += size ;
634674 break ;
0 commit comments