diff --git a/django_mongodb_backend/__init__.py b/django_mongodb_backend/__init__.py index 00700421..1c9f88f3 100644 --- a/django_mongodb_backend/__init__.py +++ b/django_mongodb_backend/__init__.py @@ -2,7 +2,7 @@ # Check Django compatibility before other imports which may fail if the # wrong version of Django is installed. -from .utils import check_django_compatability, parse_uri +from .utils import check_django_compatability, get_auto_encryption_options, parse_uri check_django_compatability() @@ -15,7 +15,7 @@ from .lookups import register_lookups # noqa: E402 from .query import register_nodes # noqa: E402 -__all__ = ["parse_uri"] +__all__ = ["get_auto_encryption_options", "parse_uri"] register_aggregates() register_checks() diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index fa73461d..dd5efec2 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -577,3 +577,21 @@ def supports_atlas_search(self): return False else: return True + + @cached_property + def supports_queryable_encryption(self): + """ + Queryable Encryption is supported if the server is Atlas or Enterprise + and if pymongocrypt is installed. + """ + self.connection.ensure_connection() + client = self.connection.connection.admin + build_info = client.command("buildInfo") + is_enterprise = "enterprise" in build_info.get("modules") + try: + import pymongocrypt # noqa: F401 + + has_pymongocrypt = True + except ImportError: + has_pymongocrypt = False + return is_enterprise and has_pymongocrypt diff --git a/django_mongodb_backend/models.py b/django_mongodb_backend/models.py index adeba21e..a34f5191 100644 --- a/django_mongodb_backend/models.py +++ b/django_mongodb_backend/models.py @@ -14,3 +14,8 @@ def delete(self, *args, **kwargs): def save(self, *args, **kwargs): raise NotSupportedError("EmbeddedModels cannot be saved.") + + +class EncryptedModel(models.Model): + class Meta: + abstract = True diff --git a/django_mongodb_backend/utils.py b/django_mongodb_backend/utils.py index ced60bc8..aa5da735 100644 --- a/django_mongodb_backend/utils.py +++ b/django_mongodb_backend/utils.py @@ -1,4 +1,5 @@ import copy +import os import time import django @@ -8,6 +9,7 @@ from django.utils.functional import SimpleLazyObject from django.utils.text import format_lazy from django.utils.version import get_version_tuple +from pymongo.encryption_options import AutoEncryptionOpts from pymongo.uri_parser import parse_uri as pymongo_parse_uri @@ -28,7 +30,19 @@ def check_django_compatability(): ) -def parse_uri(uri, *, db_name=None, test=None): +def get_auto_encryption_options(crypt_shared_lib_path=None): + key_vault_database_name = "encryption" + key_vault_collection_name = "__keyVault" + key_vault_namespace = f"{key_vault_database_name}.{key_vault_collection_name}" + kms_providers = {"local": {"key": os.urandom(96)}} + return AutoEncryptionOpts( + key_vault_namespace=key_vault_namespace, + kms_providers=kms_providers, + crypt_shared_lib_path=crypt_shared_lib_path, + ) + + +def parse_uri(uri, *, auto_encryption_options=None, db_name=None, test=None): """ Convert the given uri into a dictionary suitable for Django's DATABASES setting. @@ -48,6 +62,9 @@ def parse_uri(uri, *, db_name=None, test=None): db_name = db_name or uri["database"] if not db_name: raise ImproperlyConfigured("You must provide the db_name parameter.") + options = uri.get("options") + if auto_encryption_options: + options = {**uri.get("options"), "auto_encryption_options": auto_encryption_options} settings_dict = { "ENGINE": "django_mongodb_backend", "NAME": db_name, @@ -55,7 +72,7 @@ def parse_uri(uri, *, db_name=None, test=None): "PORT": port, "USER": uri.get("username"), "PASSWORD": uri.get("password"), - "OPTIONS": uri.get("options"), + "OPTIONS": options, } if "authSource" not in settings_dict["OPTIONS"] and uri["database"]: settings_dict["OPTIONS"]["authSource"] = uri["database"] diff --git a/tests/backend_/utils/test_parse_uri.py b/tests/backend_/utils/test_parse_uri.py index 3198a463..0c239cfe 100644 --- a/tests/backend_/utils/test_parse_uri.py +++ b/tests/backend_/utils/test_parse_uri.py @@ -2,9 +2,9 @@ import pymongo from django.core.exceptions import ImproperlyConfigured -from django.test import SimpleTestCase +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature -from django_mongodb_backend import parse_uri +from django_mongodb_backend import get_auto_encryption_options, parse_uri class ParseURITests(SimpleTestCase): @@ -94,3 +94,11 @@ def test_invalid_credentials(self): def test_no_scheme(self): with self.assertRaisesMessage(pymongo.errors.InvalidURI, "Invalid URI scheme"): parse_uri("cluster0.example.mongodb.net") + + +# TODO: This can be moved to `test_features` once transaction support is merged. +class ParseUriOptionsTests(TestCase): + @skipUnlessDBFeature("supports_queryable_encryption") + def test_auto_encryption_options(self): + auto_encryption_options = get_auto_encryption_options(crypt_shared_lib_path="/path/to/lib") + parse_uri("mongodb://localhost/db", auto_encryption_options=auto_encryption_options) diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 3d3a1584..4839f3db 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -8,7 +8,7 @@ EmbeddedModelField, ObjectIdField, ) -from django_mongodb_backend.models import EmbeddedModel +from django_mongodb_backend.models import EmbeddedModel, EncryptedModel # ObjectIdField @@ -136,6 +136,10 @@ class Author(EmbeddedModel): skills = ArrayField(models.CharField(max_length=100), null=True, blank=True) +class EncryptedData(EncryptedModel): + pass + + class Book(models.Model): name = models.CharField(max_length=100) author = EmbeddedModelField(Author) diff --git a/tests/model_fields_/test_encrypted_model.py b/tests/model_fields_/test_encrypted_model.py new file mode 100644 index 00000000..f5426f73 --- /dev/null +++ b/tests/model_fields_/test_encrypted_model.py @@ -0,0 +1,8 @@ +from django.test import TestCase + +from .models import EncryptedData + + +class ModelTests(TestCase): + def test_save_load(self): + EncryptedData.objects.create()