diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml index 6d9506776a..1d58c0d5fb 100644 --- a/.github/workflows/zizmor.yml +++ b/.github/workflows/zizmor.yml @@ -18,4 +18,4 @@ jobs: with: persist-credentials: false - name: Run zizmor 🌈 - uses: zizmorcore/zizmor-action@0f0557ab4a0b31211d42435e42df31cbd63fdd59 + uses: zizmorcore/zizmor-action@1c7106082dbc1753372e3924b7da1b9417011a21 diff --git a/doc/changelog.rst b/doc/changelog.rst index 933e2922db..2d1447d505 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -10,6 +10,7 @@ PyMongo 4.14 brings a number of changes including: - Added :meth:`pymongo.asynchronous.mongo_client.AsyncMongoClient.append_metadata` and :meth:`pymongo.mongo_client.MongoClient.append_metadata` to allow instantiated MongoClients to send client metadata on-demand +- Improved performance of selecting a server with the Primary selector. - Introduces a minor breaking change. When encoding :class:`bson.binary.BinaryVector`, a ``ValueError`` will be raised if the 'padding' metadata field is < 0 or > 7, or non-zero for any type other than PACKED_BIT. diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index 29293b2314..5650ca8391 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -34,7 +34,7 @@ from bson.objectid import ObjectId from pymongo import common from pymongo.errors import ConfigurationError, PyMongoError -from pymongo.read_preferences import ReadPreference, _AggWritePref, _ServerMode +from pymongo.read_preferences import Primary, ReadPreference, _AggWritePref, _ServerMode from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection from pymongo.server_type import SERVER_TYPE @@ -324,6 +324,14 @@ def apply_selector( description = self.server_descriptions().get(address) return [description] if description else [] + # Primary selection fast path. + if self.topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary and type(selector) is Primary: + for sd in self._server_descriptions.values(): + if sd.server_type == SERVER_TYPE.RSPrimary: + return [sd] + # No primary found, return an empty list. + return [] + selection = Selection.from_topology_description(self) # Ignore read preference for sharded clusters. if self.topology_type != TOPOLOGY_TYPE.Sharded: diff --git a/test/asynchronous/test_server_selection.py b/test/asynchronous/test_server_selection.py index f98a05ee91..f570662b85 100644 --- a/test/asynchronous/test_server_selection.py +++ b/test/asynchronous/test_server_selection.py @@ -130,12 +130,12 @@ async def test_selector_called(self): test_collection = mongo_client.testdb.test_collection self.addAsyncCleanup(mongo_client.drop_database, "testdb") - # Do N operations and test selector is called at least N times. + # Do N operations and test selector is called at least N-1 times due to fast path. await test_collection.insert_one({"age": 20, "name": "John"}) await test_collection.insert_one({"age": 31, "name": "Jane"}) await test_collection.update_one({"name": "Jane"}, {"$set": {"age": 21}}) await test_collection.find_one({"name": "Roe"}) - self.assertGreaterEqual(selector.call_count, 4) + self.assertGreaterEqual(selector.call_count, 3) @async_client_context.require_replica_set async def test_latency_threshold_application(self): diff --git a/test/test_server_selection.py b/test/test_server_selection.py index aec8e2e47a..4384deac2b 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -130,12 +130,12 @@ def test_selector_called(self): test_collection = mongo_client.testdb.test_collection self.addCleanup(mongo_client.drop_database, "testdb") - # Do N operations and test selector is called at least N times. + # Do N operations and test selector is called at least N-1 times due to fast path. test_collection.insert_one({"age": 20, "name": "John"}) test_collection.insert_one({"age": 31, "name": "Jane"}) test_collection.update_one({"name": "Jane"}, {"$set": {"age": 21}}) test_collection.find_one({"name": "Roe"}) - self.assertGreaterEqual(selector.call_count, 4) + self.assertGreaterEqual(selector.call_count, 3) @client_context.require_replica_set def test_latency_threshold_application(self):