diff --git a/pymongo/message.py b/pymongo/message.py index d51c77a174..272f38463d 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -24,6 +24,7 @@ import datetime import random import struct +from collections import ChainMap from io import BytesIO as _BytesIO from typing import ( TYPE_CHECKING, @@ -1115,8 +1116,18 @@ def _check_doc_size_limits( # key and the index of its namespace within ns_info as its value. op_doc[op_type] = ns_info[namespace] # type: ignore[index] + # Since the data document itself is nested within the insert document + # it won't be automatically re-ordered by the BSON conversion. + # We use ChainMap here to make the _id field the first field instead. + doc_to_encode = op_doc + if real_op_type == "insert": + doc = op_doc["document"] + if not isinstance(doc, RawBSONDocument): + doc_to_encode = op_doc.copy() # type: ignore[attr-defined] # Shallow copy + doc_to_encode["document"] = ChainMap(doc, {"_id": doc["_id"]}) # type: ignore[index] + # Encode current operation doc and, if newly added, namespace doc. - op_doc_encoded = _dict_to_bson(op_doc, False, opts) + op_doc_encoded = _dict_to_bson(doc_to_encode, False, opts) op_length = len(op_doc_encoded) if ns_doc: ns_doc_encoded = _dict_to_bson(ns_doc, False, opts) diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 49f969fa34..a8afc5cc79 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -18,6 +18,9 @@ import os import sys +from bson import encode +from bson.raw_bson import RawBSONDocument + sys.path[0:0] = [""] from test.asynchronous import ( @@ -82,6 +85,16 @@ async def test_formats_write_error_correctly(self): self.assertEqual(write_error["idx"], 1) self.assertEqual(write_error["op"], {"insert": 0, "document": {"_id": 1}}) + @async_client_context.require_version_min(8, 0, 0, -24) + async def test_raw_bson_not_inflated(self): + doc = RawBSONDocument(encode({"a": "b" * 100})) + models = [ + InsertOne(namespace="db.coll", document=doc), + ] + await self.client.bulk_write(models=models) + + self.assertIsNone(doc._RawBSONDocument__inflated_doc) + # https://github.com/mongodb/specifications/tree/master/source/crud/tests # Note: tests 1 and 2 are in test_read_write_concern_spec.py diff --git a/test/mockupdb/test_id_ordering.py b/test/mockupdb/test_id_ordering.py new file mode 100644 index 0000000000..7e2c91d592 --- /dev/null +++ b/test/mockupdb/test_id_ordering.py @@ -0,0 +1,94 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from test import PyMongoTestCase + +import pytest + +from pymongo import InsertOne + +try: + from mockupdb import MockupDB, OpMsg, go, going + + _HAVE_MOCKUPDB = True +except ImportError: + _HAVE_MOCKUPDB = False + + +from bson.objectid import ObjectId + +pytestmark = pytest.mark.mockupdb + + +# https://github.com/mongodb/specifications/blob/master/source/crud/tests/README.md#16-generated-document-identifiers-are-the-first-field-in-their-document +class TestIdOrdering(PyMongoTestCase): + def test_16_generated_document_ids_are_first_field(self): + server = MockupDB() + server.autoresponds( + "hello", + isWritablePrimary=True, + msg="isdbgrid", + minWireVersion=0, + maxWireVersion=25, + helloOk=True, + serviceId=ObjectId(), + ) + server.run() + self.addCleanup(server.stop) + + # We also verify that the original document contains an _id field after each insert + document = {"x": 1} + + client = self.simple_client(server.uri, loadBalanced=True) + collection = client.db.coll + with going(collection.insert_one, document): + request = server.receives() + self.assertEqual("_id", next(iter(request["documents"][0]))) + request.reply({"ok": 1}) + self.assertIn("_id", document) + + document = {"x1": 1} + + with going(collection.bulk_write, [InsertOne(document)]): + request = server.receives() + self.assertEqual("_id", next(iter(request["documents"][0]))) + request.reply({"ok": 1}) + self.assertIn("_id", document) + + document = {"x2": 1} + with going(client.bulk_write, [InsertOne(namespace="db.coll", document=document)]): + request = server.receives() + self.assertEqual("_id", next(iter(request["ops"][0]["document"]))) + request.reply({"ok": 1}) + self.assertIn("_id", document) + + # Re-ordering user-supplied _id fields is not required by the spec, but PyMongo does it for performance reasons + with going(collection.insert_one, {"x": 1, "_id": 111}): + request = server.receives() + self.assertEqual("_id", next(iter(request["documents"][0]))) + request.reply({"ok": 1}) + + with going(collection.bulk_write, [InsertOne({"x1": 1, "_id": 1111})]): + request = server.receives() + self.assertEqual("_id", next(iter(request["documents"][0]))) + request.reply({"ok": 1}) + + with going( + client.bulk_write, [InsertOne(namespace="db.coll", document={"x2": 1, "_id": 11111})] + ): + request = server.receives() + self.assertEqual("_id", next(iter(request["ops"][0]["document"]))) + request.reply({"ok": 1}) diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index 0cb6845099..92e87241f3 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -18,6 +18,9 @@ import os import sys +from bson import encode +from bson.raw_bson import RawBSONDocument + sys.path[0:0] = [""] from test import ( @@ -82,6 +85,16 @@ def test_formats_write_error_correctly(self): self.assertEqual(write_error["idx"], 1) self.assertEqual(write_error["op"], {"insert": 0, "document": {"_id": 1}}) + @client_context.require_version_min(8, 0, 0, -24) + def test_raw_bson_not_inflated(self): + doc = RawBSONDocument(encode({"a": "b" * 100})) + models = [ + InsertOne(namespace="db.coll", document=doc), + ] + self.client.bulk_write(models=models) + + self.assertIsNone(doc._RawBSONDocument__inflated_doc) + # https://github.com/mongodb/specifications/tree/master/source/crud/tests # Note: tests 1 and 2 are in test_read_write_concern_spec.py