diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py
index fc21fa5b..ad60588d 100644
--- a/django_mongodb_backend/base.py
+++ b/django_mongodb_backend/base.py
@@ -286,4 +286,6 @@ def validate_no_broken_transaction(self):
 
     def get_database_version(self):
         """Return a tuple of the database's version."""
-        return tuple(self.connection.server_info()["versionArray"])
+        return (8, 1, 1)
+        # TODO: provide an unencrypted connection for this method.
+        # return tuple(self.connection.server_info()["versionArray"])
diff --git a/django_mongodb_backend/encryption.py b/django_mongodb_backend/encryption.py
new file mode 100644
index 00000000..2921f343
--- /dev/null
+++ b/django_mongodb_backend/encryption.py
@@ -0,0 +1,66 @@
+# Queryable Encryption helpers
+#
+# TODO: Decide if these helpers should even exist, and if so, find a permanent
+# place for them.
+
+from bson.binary import STANDARD
+from bson.codec_options import CodecOptions
+from pymongo.encryption import AutoEncryptionOpts, ClientEncryption
+
+
+def get_client_encryption(auto_encryption_opts, encrypted_connection):
+    """
+    Returns a `ClientEncryption` instance for MongoDB Client-Side Field Level
+    Encryption (CSFLE) that can be used to create an encrypted collection.
+    """
+
+    key_vault_namespace = auto_encryption_opts._key_vault_namespace
+    kms_providers = auto_encryption_opts._kms_providers
+    codec_options = CodecOptions(uuid_representation=STANDARD)
+    return ClientEncryption(kms_providers, key_vault_namespace, encrypted_connection, codec_options)
+
+
+def get_auto_encryption_opts(crypt_shared_lib_path=None, kms_providers=None):
+    """
+    Returns an `AutoEncryptionOpts` instance for MongoDB Client-Side Field
+    Level Encryption (CSFLE) that can be used to create an encrypted connection.
+    """
+    key_vault_database_name = "encryption"
+    key_vault_collection_name = "__keyVault"
+    key_vault_namespace = f"{key_vault_database_name}.{key_vault_collection_name}"
+    return AutoEncryptionOpts(
+        key_vault_namespace=key_vault_namespace,
+        kms_providers=kms_providers,
+        crypt_shared_lib_path=crypt_shared_lib_path,
+    )
+
+
+def get_customer_master_key():
+    """
+    Returns a 96-byte local master key for use with MongoDB Client-Side Field Level
+    Encryption (CSFLE). For local testing purposes only. In production, use a secure KMS
+    like AWS, Azure, GCP, or KMIP.
+    Returns:
+        bytes: A 96-byte key.
+    """
+    # WARNING: This is a static key for testing only.
+    # Generate with: os.urandom(96)
+    return bytes.fromhex(
+        "000102030405060708090a0b0c0d0e0f"
+        "101112131415161718191a1b1c1d1e1f"
+        "202122232425262728292a2b2c2d2e2f"
+        "303132333435363738393a3b3c3d3e3f"
+        "404142434445464748494a4b4c4d4e4f"
+        "505152535455565758595a5b5c5d5e5f"
+    )
+
+
+def get_kms_providers():
+    """
+    Return supported KMS providers for MongoDB Client-Side Field Level Encryption (CSFLE).
+    """
+    return {
+        "local": {
+            "key": get_customer_master_key(),
+        },
+    }
diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py
index 3e9cc292..1feef98e 100644
--- a/django_mongodb_backend/features.py
+++ b/django_mongodb_backend/features.py
@@ -624,3 +624,20 @@ def supports_transactions(self):
         hello = client.command("hello")
         # a replica set or a sharded cluster
         return "setName" in hello or hello.get("msg") == "isdbgrid"
+
+    @cached_property
+    def supports_encryption(self):
+        """
+        Encryption is supported if the server is Atlas or Enterprise
+        and is configured as a replica set or sharded cluster.
+        """
+        self.connection.ensure_connection()
+        client = self.connection.connection.admin
+        build_info = client.command("buildInfo")
+        # Default to an empty list so a buildInfo reply without "modules"
+        # reports no enterprise support instead of raising TypeError.
+        is_enterprise = "enterprise" in build_info.get("modules", [])
+        # `supports_transactions` already checks if the server is a
+        # replica set or sharded cluster.
+        is_not_single = self.supports_transactions
+        return is_enterprise and is_not_single
diff --git a/django_mongodb_backend/fields/__init__.py b/django_mongodb_backend/fields/__init__.py
index be95fa5e..ced7fa2b 100644
--- a/django_mongodb_backend/fields/__init__.py
+++ b/django_mongodb_backend/fields/__init__.py
@@ -3,6 +3,7 @@
 from .duration import register_duration_field
 from .embedded_model import EmbeddedModelField
 from .embedded_model_array import EmbeddedModelArrayField
+from .encryption import EncryptedCharField
 from .json import register_json_field
 from .objectid import ObjectIdField
 
@@ -11,6 +12,7 @@
     "ArrayField",
     "EmbeddedModelArrayField",
     "EmbeddedModelField",
+    "EncryptedCharField",
     "ObjectIdAutoField",
     "ObjectIdField",
 ]
diff --git a/django_mongodb_backend/fields/encryption.py b/django_mongodb_backend/fields/encryption.py
new file mode 100644
index 00000000..7fb80a02
--- /dev/null
+++ b/django_mongodb_backend/fields/encryption.py
@@ -0,0 +1,9 @@
+from django.db import models
+
+
+class EncryptedCharField(models.CharField):
+    """A ``CharField`` flagged as encrypted via its ``encrypted`` attribute."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.encrypted = True
diff --git a/django_mongodb_backend/models.py b/django_mongodb_backend/models.py
index adeba21e..822c744b 100644
--- a/django_mongodb_backend/models.py
+++ b/django_mongodb_backend/models.py
@@ -14,3 +14,28 @@ def delete(self, *args, **kwargs):
 
     def save(self, *args, **kwargs):
         raise NotSupportedError("EmbeddedModels cannot be saved.")
+
+
+class EncryptedModelBase(models.base.ModelBase):
+    """Metaclass that stores a map of a model's encrypted fields on the class."""
+
+    def __new__(cls, name, bases, attrs, **kwargs):
+        new_class = super().__new__(cls, name, bases, attrs, **kwargs)
+
+        # Build a map of encrypted fields
+        encrypted_fields = {
+            "fields": {
+                field.name: field.__class__.__name__
+                for field in new_class._meta.fields
+                if getattr(field, "encrypted", False)
+            }
+        }
+
+        # Store it as a class-level attribute
+        new_class.encrypted_fields_map = encrypted_fields
+        return new_class
+
+
+class EncryptedModel(models.Model, metaclass=EncryptedModelBase):
+    class Meta:
+        abstract = True
diff --git a/django_mongodb_backend/schema.py b/django_mongodb_backend/schema.py
index da3ec961..96311c93 100644
--- a/django_mongodb_backend/schema.py
+++ b/django_mongodb_backend/schema.py
@@ -2,9 +2,9 @@
 from django.db.models import Index, UniqueConstraint
 from pymongo.operations import SearchIndexModel
 
-from django_mongodb_backend.indexes import SearchIndex
-
+from .encryption import get_client_encryption
 from .fields import EmbeddedModelField
+from .indexes import SearchIndex
 from .query import wrap_database_errors
 from .utils import OperationCollector
 
@@ -41,7 +41,7 @@ def get_database(self):
     @wrap_database_errors
     @ignore_embedded_models
     def create_model(self, model):
-        self.get_database().create_collection(model._meta.db_table)
+        self._create_collection(model)
         self._create_model_indexes(model)
         # Make implicit M2M tables.
         for field in model._meta.local_many_to_many:
@@ -418,3 +418,25 @@ def _field_should_have_unique(self, field):
         db_type = field.db_type(self.connection)
         # The _id column is automatically unique.
         return db_type and field.unique and field.column != "_id"
+
+    def _create_collection(self, model):
+        """
+        Create a collection or encrypted collection for the model.
+        """
+
+        if hasattr(model, "encrypted_fields_map"):
+            auto_encryption_opts = self.connection.settings_dict.get("OPTIONS", {}).get(
+                "auto_encryption_opts"
+            )
+            client = self.connection.connection
+            client_encryption = get_client_encryption(auto_encryption_opts, client)
+            client_encryption.create_encrypted_collection(
+                # Pass the configured database: attribute access on the
+                # client would select a database literally named "database".
+                self.get_database(),
+                model._meta.db_table,
+                model.encrypted_fields_map,
+                "local",  # TODO: KMS provider should be configurable
+            )
+        else:
+            self.get_database().create_collection(model._meta.db_table)
diff --git a/docs/source/topics/encrypted-models.rst b/docs/source/topics/encrypted-models.rst
new file mode 100644
index 00000000..4e40bc48
--- /dev/null
+++ b/docs/source/topics/encrypted-models.rst
@@ -0,0 +1,22 @@
+Encrypted models
+================
+
+``EncryptedCharField``
+----------------------
+
+The basics
+~~~~~~~~~~
+
+Let's consider this example::
+
+    from django.db import models
+
+    from django_mongodb_backend.fields import EncryptedCharField
+    from django_mongodb_backend.models import EncryptedModel
+
+
+    class Person(EncryptedModel):
+        ssn = EncryptedCharField("ssn", max_length=11)
+
+        def __str__(self):
+            return self.ssn
diff --git a/docs/source/topics/index.rst b/docs/source/topics/index.rst
index 47e0c6dc..285fd718 100644
--- a/docs/source/topics/index.rst
+++ b/docs/source/topics/index.rst
@@ -10,4 +10,5 @@ know:
 
    cache
    embedded-models
+   encrypted-models
    known-issues
diff --git a/tests/encryption_/__init__.py b/tests/encryption_/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/encryption_/models.py b/tests/encryption_/models.py
new file mode 100644
index 00000000..8adbf1a0
--- /dev/null
+++ b/tests/encryption_/models.py
@@ -0,0 +1,12 @@
+from django.db import models
+
+from django_mongodb_backend.fields import EncryptedCharField
+from django_mongodb_backend.models import EncryptedModel
+
+
+class Person(EncryptedModel):
+    name = models.CharField("name", max_length=100)
+    ssn = EncryptedCharField("ssn", max_length=11)
+
+    def __str__(self):
+        return self.name
diff --git a/tests/encryption_/tests.py b/tests/encryption_/tests.py
new file mode 100644
index 00000000..04bf4531
--- /dev/null
+++ b/tests/encryption_/tests.py
@@ -0,0 +1,30 @@
+from django.test import TestCase
+
+from .models import Person
+
+
+class EncryptedModelTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.objs = [Person.objects.create()]
+
+    def test_encrypted_fields_map_on_class(self):
+        expected = {
+            "fields": {
+                "ssn": "EncryptedCharField",
+            }
+        }
+        self.assertEqual(Person.encrypted_fields_map, expected)
+
+    def test_encrypted_fields_map_on_instance(self):
+        instance = Person(ssn="123-45-6789")
+        expected = {
+            "fields": {
+                "ssn": "EncryptedCharField",
+            }
+        }
+        self.assertEqual(instance.encrypted_fields_map, expected)
+
+    def test_non_encrypted_fields_not_included(self):
+        encrypted_field_names = Person.encrypted_fields_map.get("fields").keys()
+        self.assertNotIn("name", encrypted_field_names)