diff --git a/.github/workflows/test-python-atlas.yml b/.github/workflows/test-python-atlas.yml index 175dfe18..f84a9934 100644 --- a/.github/workflows/test-python-atlas.yml +++ b/.github/workflows/test-python-atlas.yml @@ -53,4 +53,4 @@ jobs: working-directory: . run: bash .github/workflows/start_local_atlas.sh mongodb/mongodb-atlas-local:7 - name: Run tests - run: python3 django_repo/tests/runtests_.py + run: python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2 diff --git a/django_mongodb_backend/__init__.py b/django_mongodb_backend/__init__.py index 00700421..1c9f88f3 100644 --- a/django_mongodb_backend/__init__.py +++ b/django_mongodb_backend/__init__.py @@ -2,7 +2,7 @@ # Check Django compatibility before other imports which may fail if the # wrong version of Django is installed. -from .utils import check_django_compatability, parse_uri +from .utils import check_django_compatability, get_auto_encryption_options, parse_uri check_django_compatability() @@ -15,7 +15,7 @@ from .lookups import register_lookups # noqa: E402 from .query import register_nodes # noqa: E402 -__all__ = ["parse_uri"] +__all__ = ["get_auto_encryption_options", "parse_uri"] register_aggregates() register_checks() diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index c6110dbb..fc21fa5b 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -2,7 +2,9 @@ import os from django.core.exceptions import ImproperlyConfigured +from django.db import DEFAULT_DB_ALIAS from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.backends.utils import debug_transaction from django.utils.asyncio import async_unsafe from django.utils.functional import cached_property from pymongo.collection import Collection @@ -32,6 +34,17 @@ def __exit__(self, exception_type, exception_value, exception_traceback): pass +def requires_transaction_support(func): + """Make a method a no-op if transactions aren't supported.""" + + def wrapper(self, *args, **kwargs): + if not self.features.supports_transactions: + return + func(self, *args, **kwargs) + + return wrapper + + class DatabaseWrapper(BaseDatabaseWrapper): data_types = { "AutoField": "int", @@ -140,6 +153,10 @@ def _isnull_operator(a, b): ops_class = DatabaseOperations validation_class = DatabaseValidation + def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): + super().__init__(settings_dict, alias=alias) + self.session = None + def get_collection(self, name, **kwargs): collection = Collection(self.database, name, **kwargs) if self.queries_logged: @@ -189,14 +206,48 @@ def _driver_info(self): return DriverInfo("django-mongodb-backend", django_mongodb_backend_version) return None + @requires_transaction_support def _commit(self): - pass + if self.session: + with debug_transaction(self, "session.commit_transaction()"): + self.session.commit_transaction() + self._end_session() + @requires_transaction_support def _rollback(self): - pass + if self.session: + with debug_transaction(self, "session.abort_transaction()"): + self.session.abort_transaction() + self._end_session() + + def _start_transaction(self): + # Private API, specific to this backend. + if self.session is None: + self.session = self.connection.start_session() + with debug_transaction(self, "session.start_transaction()"): + self.session.start_transaction() - def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): - self.autocommit = autocommit + def _end_session(self): + # Private API, specific to this backend. + self.session.end_session() + self.session = None + + @requires_transaction_support + def _start_transaction_under_autocommit(self): + # Implementing this hook (intended only for SQLite), allows + # BaseDatabaseWrapper.set_autocommit() to use it to start a transaction + # rather than set_autocommit(), bypassing set_autocommit()'s call to + # debug_transaction(self, "BEGIN") which isn't semantic for a no-SQL + # backend. + self._start_transaction() + + @requires_transaction_support + def _set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): + # Besides @transaction.atomic() (which uses + # _start_transaction_under_autocommit(), disabling autocommit is + # another way to start a transaction. + if not autocommit: + self._start_transaction() def _close(self): # Normally called by close(), this method is also called by some tests. @@ -210,6 +261,10 @@ def close(self): def close_pool(self): """Close the MongoClient.""" + # Clear commit hooks and session. + self.run_on_commit = [] + if self.session: + self._end_session() connection = self.connection if connection is None: return @@ -225,6 +280,10 @@ def close_pool(self): def cursor(self): return Cursor() + @requires_transaction_support + def validate_no_broken_transaction(self): + super().validate_no_broken_transaction() + def get_database_version(self): """Return a tuple of the database's version.""" return tuple(self.connection.server_info()["versionArray"]) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 12da13a1..b7e264f8 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -685,7 +685,10 @@ def execute_sql(self, returning_fields=None): @wrap_database_errors def insert(self, docs, returning_fields=None): """Store a list of documents using field columns as element names.""" - inserted_ids = self.collection.insert_many(docs).inserted_ids + self.connection.validate_no_broken_transaction() + inserted_ids = self.collection.insert_many( + docs, session=self.connection.session + ).inserted_ids return [(x,) for x in inserted_ids] if returning_fields else [] @cached_property @@ -768,7 +771,10 @@ def execute_sql(self, result_type): @wrap_database_errors def update(self, criteria, pipeline): - return self.collection.update_many(criteria, pipeline).matched_count + self.connection.validate_no_broken_transaction() + return self.collection.update_many( + criteria, pipeline, session=self.connection.session + ).matched_count def check_query(self): super().check_query() diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index fa73461d..487abffe 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -36,8 +36,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_temporal_subtraction = True # MongoDB stores datetimes in UTC. supports_timezones = False - # Not implemented: https://github.com/mongodb/django-mongodb-backend/issues/7 - supports_transactions = False supports_unspecified_pk = True uses_savepoints = False @@ -50,8 +48,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "aggregation.tests.AggregateTestCase.test_order_by_aggregate_transform", # 'NulledTransform' object has no attribute 'as_mql'. "lookup.tests.LookupTests.test_exact_none_transform", - # "Save with update_fields did not affect any rows." - "basic.tests.SelectOnSaveTests.test_select_on_save_lying_update", # BaseExpression.convert_value() crashes with Decimal128. "aggregation.tests.AggregateTestCase.test_combine_different_types", "annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation", @@ -96,6 +92,36 @@ class DatabaseFeatures(BaseDatabaseFeatures): "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null", "expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or", } + _django_test_expected_failures_no_transactions = { + # "Save with update_fields did not affect any rows." instead of + # "An error occurred in the current transaction. You can't execute + # queries until the end of the 'atomic' block." + "basic.tests.SelectOnSaveTests.test_select_on_save_lying_update", + } + _django_test_expected_failures_transactions = { + # When update_or_create() fails with IntegrityError, the transaction + # is no longer usable. + "get_or_create.tests.UpdateOrCreateTests.test_manual_primary_key_test", + "get_or_create.tests.UpdateOrCreateTestsWithManualPKs.test_create_with_duplicate_primary_key", + # Tests that require savepoints + "admin_views.tests.AdminViewBasicTest.test_disallowed_to_field", + "admin_views.tests.AdminViewPermissionsTest.test_add_view", + "admin_views.tests.AdminViewPermissionsTest.test_change_view", + "admin_views.tests.AdminViewPermissionsTest.test_change_view_save_as_new", + "admin_views.tests.AdminViewPermissionsTest.test_delete_view", + "auth_tests.test_views.ChangelistTests.test_view_user_password_is_readonly", + "fixtures.tests.FixtureLoadingTests.test_loaddata_app_option", + "fixtures.tests.FixtureLoadingTests.test_unmatched_identifier_loading", + "fixtures_model_package.tests.FixtureTestCase.test_loaddata", + "get_or_create.tests.GetOrCreateTests.test_get_or_create_invalid_params", + "get_or_create.tests.UpdateOrCreateTests.test_integrity", + "many_to_many.tests.ManyToManyTests.test_add", + "many_to_one.tests.ManyToOneTests.test_fk_assignment_and_related_object_cache", + "model_fields.test_booleanfield.BooleanFieldTests.test_null_default", + "model_fields.test_floatfield.TestFloatField.test_float_validates_object", + "multiple_database.tests.QueryTestCase.test_generic_key_cross_database_protection", + "multiple_database.tests.QueryTestCase.test_m2m_cross_database_protection", + } @cached_property def django_test_expected_failures(self): @@ -103,6 +129,10 @@ def django_test_expected_failures(self): expected_failures.update(self._django_test_expected_failures) if not self.is_mongodb_6_3: expected_failures.update(self._django_test_expected_failures_bitwise) + if self.supports_transactions: + expected_failures.update(self._django_test_expected_failures_transactions) + else: + expected_failures.update(self._django_test_expected_failures_no_transactions) return expected_failures django_test_skips = { @@ -485,16 +515,6 @@ def django_test_expected_failures(self): "Connection health checks not implemented.": { "backends.base.test_base.ConnectionHealthChecksTests", }, - "transaction.atomic() is not supported.": { - "backends.base.test_base.DatabaseWrapperLoggingTests", - "migrations.test_executor.ExecutorTests.test_atomic_operation_in_non_atomic_migration", - "migrations.test_operations.OperationTests.test_run_python_atomic", - }, - "transaction.rollback() is not supported.": { - "transactions.tests.AtomicMiscTests.test_mark_for_rollback_on_error_in_autocommit", - "transactions.tests.AtomicMiscTests.test_mark_for_rollback_on_error_in_transaction", - "transactions.tests.NonAutocommitTests.test_orm_query_after_error_and_rollback", - }, "migrate --fake-initial is not supported.": { "migrations.test_commands.MigrateTests.test_migrate_fake_initial", "migrations.test_commands.MigrateTests.test_migrate_fake_split_initial", @@ -533,8 +553,18 @@ def django_test_expected_failures(self): "foreign_object.test_tuple_lookups.TupleLookupsTests", }, "ColPairs is not supported.": { - # 'ColPairs' object has no attribute 'as_mql' "auth_tests.test_views.CustomUserCompositePrimaryKeyPasswordResetTest", + "composite_pk.test_aggregate.CompositePKAggregateTests", + "composite_pk.test_create.CompositePKCreateTests", + "composite_pk.test_delete.CompositePKDeleteTests", + "composite_pk.test_filter.CompositePKFilterTests", + "composite_pk.test_get.CompositePKGetTests", + "composite_pk.test_models.CompositePKModelsTests", + "composite_pk.test_order_by.CompositePKOrderByTests", + "composite_pk.test_update.CompositePKUpdateTests", + "composite_pk.test_values.CompositePKValuesTests", + "composite_pk.tests.CompositePKTests", + "composite_pk.tests.CompositePKFixturesTests", }, "Custom lookups are not supported.": { "custom_lookups.tests.BilateralTransformTests", @@ -577,3 +607,35 @@ def supports_atlas_search(self): return False else: return True + + @cached_property + def supports_select_union(self): + # Stage not supported inside of a multi-document transaction: $unionWith + return not self.supports_transactions + + @cached_property + def supports_transactions(self): + """ + Transactions are enabled if the MongoDB configuration supports it: + MongoDB must be configured as a replica set or sharded cluster, and + the store engine must be WiredTiger. + """ + self.connection.ensure_connection() + client = self.connection.connection.admin + hello_response = client.command("hello") + is_replica_set = "setName" in hello_response + is_sharded_cluster = hello_response.get("msg") == "isdbgrid" + if is_replica_set or is_sharded_cluster: + engine = client.command("serverStatus").get("storageEngine", {}) + return engine.get("name") == "wiredTiger" + return False + + @cached_property + def supports_queryable_encryption(self): + """ + Queryable Encryption is available if the server is Atlas or Enterprise. + """ + self.connection.ensure_connection() + client = self.connection.connection.admin + build_info = client.command("buildInfo") + return "enterprise" in build_info.get("modules") diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index 04977520..d59bc163 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -61,9 +61,12 @@ def __repr__(self): @wrap_database_errors def delete(self): """Execute a delete query.""" + self.compiler.connection.validate_no_broken_transaction() if self.compiler.subqueries: raise NotSupportedError("Cannot use QuerySet.delete() when a subquery is required.") - return self.compiler.collection.delete_many(self.match_mql).deleted_count + return self.compiler.collection.delete_many( + self.match_mql, session=self.compiler.connection.session + ).deleted_count @wrap_database_errors def get_cursor(self): @@ -71,7 +74,10 @@ def get_cursor(self): Return a pymongo CommandCursor that can be iterated on to give the results of the query. """ - return self.compiler.collection.aggregate(self.get_pipeline()) + self.compiler.connection.validate_no_broken_transaction() + return self.compiler.collection.aggregate( + self.get_pipeline(), session=self.compiler.connection.session + ) def get_pipeline(self): pipeline = [] diff --git a/django_mongodb_backend/queryset.py b/django_mongodb_backend/queryset.py index 4a2d884d..b02496fe 100644 --- a/django_mongodb_backend/queryset.py +++ b/django_mongodb_backend/queryset.py @@ -35,7 +35,7 @@ def __init__(self, pipeline, using, model): def _execute_query(self): connection = connections[self.using] collection = connection.get_collection(self.model._meta.db_table) - self.cursor = collection.aggregate(self.pipeline) + self.cursor = collection.aggregate(self.pipeline, session=connection.session) def __str__(self): return str(self.pipeline) diff --git a/django_mongodb_backend/utils.py b/django_mongodb_backend/utils.py index ced60bc8..15ebc103 100644 --- a/django_mongodb_backend/utils.py +++ b/django_mongodb_backend/utils.py @@ -1,5 +1,8 @@ import copy +import os import time +from pathlib import Path +from urllib.parse import urlencode import django from django.conf import settings @@ -8,6 +11,7 @@ from django.utils.functional import SimpleLazyObject from django.utils.text import format_lazy from django.utils.version import get_version_tuple +from pymongo.encryption_options import AutoEncryptionOpts from pymongo.uri_parser import parse_uri as pymongo_parse_uri @@ -28,6 +32,62 @@ def check_django_compatability(): ) +# Queryable Encryption-related functions based on helpers from Python Queryable +# Encryption Tutorial +# https://github.com/mongodb/docs/tree/master/source/includes/qe-tutorials/python/ +def _get_kms_provider_credentials(kms_provider_name): + """ + "A KMS is a remote service that securely stores and manages your encryption keys." + + Via https://www.mongodb.com/docs/manual/core/queryable-encryption/quick-start/ + + Here we check the provider name and return the appropriate credentials. + """ + # TODO: Add support for other KMS providers. + if kms_provider_name == "local": + if not Path("./customer-master-key.txt").exists: + try: + path = "customer-master-key.txt" + file_bytes = os.urandom(96) + with Path.open(path, "wb") as f: + f.write(file_bytes) + except Exception as e: + raise Exception( + "Unable to write Customer Master Key to file due to the following error: " + ) from e + + try: + path = "./customer-master-key.txt" + with Path.open(path, "rb") as f: + local_master_key = f.read() + if len(local_master_key) != 96: + raise Exception("Expected the customer master key file to be 96 bytes.") + return { + "local": {"key": local_master_key}, + } + except Exception as e: + raise Exception( + "Unable to read Customer Master Key from file due to the following error: " + ) from e + else: + raise ValueError( + "Unrecognized value for kms_provider_name encountered while retrieving KMS credentials." + ) + + +def get_auto_encryption_options(kms_provider_name): + key_vault_database_name = "encryption" + key_vault_collection_name = "__keyVault" + key_vault_namespace = f"{key_vault_database_name}.{key_vault_collection_name}" + kms_provider_credentials = _get_kms_provider_credentials(kms_provider_name) + auto_encryption_opts = AutoEncryptionOpts( + kms_provider_credentials, + key_vault_namespace, + crypt_shared_lib_path=os.environ.get("SHARED_LIB_PATH"), + ) + return urlencode(auto_encryption_opts) + + def parse_uri(uri, *, db_name=None, test=None): """ Convert the given uri into a dictionary suitable for Django's DATABASES diff --git a/docs/source/conf.py b/docs/source/conf.py index a4e54938..2f1c8675 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -45,6 +45,7 @@ "pymongo": ("https://pymongo.readthedocs.io/en/stable/", None), "python": ("https://docs.python.org/3/", None), "atlas": ("https://www.mongodb.com/docs/atlas/", None), + "manual": ("https://www.mongodb.com/docs/manual/", None), } root_doc = "contents" diff --git a/docs/source/ref/database.rst b/docs/source/ref/database.rst index 365e1c73..3b0d0930 100644 --- a/docs/source/ref/database.rst +++ b/docs/source/ref/database.rst @@ -41,3 +41,45 @@ effect. Rather, if you need to close the connection pool, use .. versionadded:: 5.2.0b0 Support for connection pooling and ``connection.close_pool()`` were added. + +.. _transactions: + +Transactions +============ + +.. versionadded:: 5.2.0b2 + +Support for :doc:`Django's transactions APIs ` +is enabled if the MongoDB configuration supports them: MongoDB must be +configured as a :doc:`replica set ` or :doc:`sharded +cluster `, and the store engine must be :doc:`WiredTiger +`. + +If transactions aren't supported, query execution uses Django and MongoDB's +default behavior of autocommit mode. Each query is immediately committed to the +database. Django's transaction management APIs, such as +:func:`~django.db.transaction.atomic`, function as no-ops. + +.. _transactions-limitations: + +Limitations +----------- + +MongoDB's transaction limitations that are applicable to Django are: + +- :meth:`QuerySet.union() ` is not + supported inside a transaction. +- If a transaction raises an exception, the transaction is no longer usable. + For example, if the update stage of :meth:`QuerySet.update_or_create() + ` fails with + :class:`~django.db.IntegrityError` due to a unique constraint violation, the + create stage won't be able to proceed. + :class:`pymongo.errors.OperationFailure` is raised, wrapped by + :class:`django.db.DatabaseError`. +- Savepoints (i.e. nested :func:`~django.db.transaction.atomic` blocks) aren't + supported. The outermost :func:`~django.db.transaction.atomic` will start + a transaction while any subsequent :func:`~django.db.transaction.atomic` + blocks will have no effect. +- Migration operations aren't :ref:`wrapped in a transaction + ` because of MongoDB restrictions such as + adding indexes to existing collections while in a transaction. diff --git a/docs/source/releases/5.2.x.rst b/docs/source/releases/5.2.x.rst index de4b6efc..e5e9b05f 100644 --- a/docs/source/releases/5.2.x.rst +++ b/docs/source/releases/5.2.x.rst @@ -2,6 +2,16 @@ Django MongoDB Backend 5.2.x ============================ +5.2.0 beta 2 +============ + +*Unreleased* + +New features +------------ + +- Added support for :ref:`database transactions `. + 5.2.0 beta 1 ============ diff --git a/docs/source/topics/known-issues.rst b/docs/source/topics/known-issues.rst index 4b9edee7..60628f81 100644 --- a/docs/source/topics/known-issues.rst +++ b/docs/source/topics/known-issues.rst @@ -80,11 +80,7 @@ Database functions Transaction management ====================== -Query execution uses Django and MongoDB's default behavior of autocommit mode. -Each query is immediately committed to the database. - -Django's :doc:`transaction management APIs ` -are not supported. +See :ref:`transactions` for details. Database introspection ====================== diff --git a/tests/backend_/test_base.py b/tests/backend_/test_base.py index 7695b6f4..9d4e7006 100644 --- a/tests/backend_/test_base.py +++ b/tests/backend_/test_base.py @@ -1,7 +1,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db import connection from django.db.backends.signals import connection_created -from django.test import SimpleTestCase, TestCase +from django.test import SimpleTestCase, TransactionTestCase from django_mongodb_backend.base import DatabaseWrapper @@ -15,7 +15,9 @@ def test_database_name_empty(self): DatabaseWrapper(settings).get_connection_params() -class DatabaseWrapperConnectionTests(TestCase): +class DatabaseWrapperConnectionTests(TransactionTestCase): + available_apps = ["backend_"] + def test_set_autocommit(self): self.assertIs(connection.get_autocommit(), True) connection.set_autocommit(False) diff --git a/tests/backend_/test_features.py b/tests/backend_/test_features.py new file mode 100644 index 00000000..3b18e64a --- /dev/null +++ b/tests/backend_/test_features.py @@ -0,0 +1,76 @@ +from unittest.mock import patch + +from django.db import connection +from django.test import TestCase + + +class SupportsTransactionsTests(TestCase): + def setUp(self): + # Clear the cached property. + del connection.features.supports_transactions + + def tearDown(self): + del connection.features.supports_transactions + + def test_replica_set(self): + """A replica set supports transactions.""" + + def mocked_command(command): + if command == "hello": + return {"setName": "foo"} + if command == "serverStatus": + return {"storageEngine": {"name": "wiredTiger"}} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, True) + + def test_replica_set_other_storage_engine(self): + """No support on a non-wiredTiger replica set.""" + + def mocked_command(command): + if command == "hello": + return {"setName": "foo"} + if command == "serverStatus": + return {"storageEngine": {"name": "other"}} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, False) + + def test_sharded_cluster(self): + """A sharded cluster with wiredTiger storage engine supports them.""" + + def mocked_command(command): + if command == "hello": + return {"msg": "isdbgrid"} + if command == "serverStatus": + return {"storageEngine": {"name": "wiredTiger"}} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, True) + + def test_sharded_cluster_other_storage_engine(self): + """No support on a non-wiredTiger shared cluster.""" + + def mocked_command(command): + if command == "hello": + return {"msg": "isdbgrid"} + if command == "serverStatus": + return {"storageEngine": {"name": "other"}} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, False) + + def test_no_support(self): + """No support on a non-replica set, non-sharded cluster.""" + + def mocked_command(command): + if command == "hello": + return {} + raise Exception("Unexpected command") + + with patch("pymongo.synchronous.database.Database.command", wraps=mocked_command): + self.assertIs(connection.features.supports_transactions, False) diff --git a/tests/backend_/utils/test_parse_uri.py b/tests/backend_/utils/test_parse_uri.py index 3198a463..c5511baf 100644 --- a/tests/backend_/utils/test_parse_uri.py +++ b/tests/backend_/utils/test_parse_uri.py @@ -1,10 +1,11 @@ from unittest.mock import patch +from urllib.parse import parse_qs import pymongo from django.core.exceptions import ImproperlyConfigured from django.test import SimpleTestCase -from django_mongodb_backend import parse_uri +from django_mongodb_backend import get_auto_encryption_options, parse_uri class ParseURITests(SimpleTestCase): @@ -94,3 +95,10 @@ def test_invalid_credentials(self): def test_no_scheme(self): with self.assertRaisesMessage(pymongo.errors.InvalidURI, "Invalid URI scheme"): parse_uri("cluster0.example.mongodb.net") + + def test_queryable_encryption_config(self): + auto_encryption_options = get_auto_encryption_options("local") + settings_dict = parse_uri( + f"mongodb://cluster0.example.mongodb.net/myDatabase{auto_encryption_options}" + ) + self.assertEqual(settings_dict["OPTIONS"], parse_qs(auto_encryption_options)) diff --git a/tests/queries_/test_objectid.py b/tests/queries_/test_objectid.py index 490d1b33..36d9561a 100644 --- a/tests/queries_/test_objectid.py +++ b/tests/queries_/test_objectid.py @@ -1,6 +1,6 @@ from bson import ObjectId from django.core.exceptions import ValidationError -from django.test import TestCase +from django.test import TestCase, skipUnlessDBFeature from .models import Order, OrderItem, Tag @@ -75,6 +75,7 @@ def test_filter_parent_by_children_values_obj(self): parent_qs = Tag.objects.filter(children__id__in=child_ids).distinct().order_by("name") self.assertSequenceEqual(parent_qs, [self.t1]) + @skipUnlessDBFeature("supports_select_union") def test_filter_group_id_union_with_str(self): """Combine queries using union with string values.""" qs_a = Tag.objects.filter(group_id=self.group_id_str_1) @@ -82,6 +83,7 @@ def test_filter_group_id_union_with_str(self): union_qs = qs_a.union(qs_b).order_by("name") self.assertSequenceEqual(union_qs, [self.t3, self.t4]) + @skipUnlessDBFeature("supports_select_union") def test_filter_group_id_union_with_obj(self): """Combine queries using union with ObjectId values.""" qs_a = Tag.objects.filter(group_id=self.group_id_obj_1) diff --git a/tests/raw_query_/test_raw_aggregate.py b/tests/raw_query_/test_raw_aggregate.py index 72cd74d0..ce87311a 100644 --- a/tests/raw_query_/test_raw_aggregate.py +++ b/tests/raw_query_/test_raw_aggregate.py @@ -182,7 +182,8 @@ def test_different_db_key_order(self): { field.name: getattr(author, field.name) for field in reversed(Author._meta.concrete_fields) - } + }, + session=connection.session, ) query = [] authors = Author.objects.all()