Skip to content

INTPYTHON-527 Add Queryable Encryption support #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion django_mongodb_backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,4 +286,6 @@ def validate_no_broken_transaction(self):

def get_database_version(self):
"""Return a tuple of the database's version."""
return tuple(self.connection.server_info()["versionArray"])
return (8, 1, 1)
# TODO: provide an unencrypted connection for this method.
# return tuple(self.connection.server_info()["versionArray"])
66 changes: 66 additions & 0 deletions django_mongodb_backend/encryption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Queryable Encryption helpers
#
# TODO: Decide if these helpers should even exist, and if so, find a permanent
# place for them.

from bson.binary import STANDARD
from bson.codec_options import CodecOptions
from pymongo.encryption import AutoEncryptionOpts, ClientEncryption


def get_client_encryption(auto_encryption_opts, encrypted_connection):
"""
Returns a `ClientEncryption` instance for MongoDB Client-Side Field Level
Encryption (CSFLE) that can be used to create an encrypted collection.
"""

key_vault_namespace = auto_encryption_opts._key_vault_namespace
kms_providers = auto_encryption_opts._kms_providers
codec_options = CodecOptions(uuid_representation=STANDARD)
return ClientEncryption(kms_providers, key_vault_namespace, encrypted_connection, codec_options)


def get_auto_encryption_opts(crypt_shared_lib_path=None, kms_providers=None):
Copy link
Collaborator Author

@aclark4life aclark4life Jun 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

crypt_shared library is in the pymongocrypt wheel, which is much easier than downloading separately and telling MongoClient where it is.

Copy link
Collaborator Author

@aclark4life aclark4life Jun 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More to this story:

  • libmongocrypt is in the pymongocrypt wheel, not crypt_shared which must always be downloaded and configured manually.
  • libmongocrypt works because mongocryptd is running on enterprise.

We should document this.

(via @ShaneHarvey, thanks!)

"""
Returns an `AutoEncryptionOpts` instance for MongoDB Client-Side Field
Level Encryption (CSFLE) that can be used to create an encrypted connection.
"""
key_vault_database_name = "encryption"
key_vault_collection_name = "__keyVault"
key_vault_namespace = f"{key_vault_database_name}.{key_vault_collection_name}"
return AutoEncryptionOpts(
key_vault_namespace=key_vault_namespace,
kms_providers=kms_providers,
crypt_shared_lib_path=crypt_shared_lib_path,
)


def get_customer_master_key():
"""
Returns a 96-byte local master key for use with MongoDB Client-Side Field Level
Encryption (CSFLE). For local testing purposes only. In production, use a secure KMS
like AWS, Azure, GCP, or KMIP.
Returns:
bytes: A 96-byte key.
"""
# WARNING: This is a static key for testing only.
# Generate with: os.urandom(96)
return bytes.fromhex(
"000102030405060708090a0b0c0d0e0f"
"101112131415161718191a1b1c1d1e1f"
"202122232425262728292a2b2c2d2e2f"
"303132333435363738393a3b3c3d3e3f"
"404142434445464748494a4b4c4d4e4f"
"505152535455565758595a5b5c5d5e5f"
)


def get_kms_providers():
"""
Return supported KMS providers for MongoDB Client-Side Field Level Encryption (CSFLE).
"""
return {
"local": {
"key": get_customer_master_key(),
},
}
15 changes: 15 additions & 0 deletions django_mongodb_backend/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,18 @@ def supports_transactions(self):
hello = client.command("hello")
# a replica set or a sharded cluster
return "setName" in hello or hello.get("msg") == "isdbgrid"

@cached_property
def supports_encryption(self):
"""
Encryption is supported if the server is Atlas or Enterprise
and is configured as a replica set or sharded cluster.
"""
self.connection.ensure_connection()
client = self.connection.connection.admin
build_info = client.command("buildInfo")
is_enterprise = "enterprise" in build_info.get("modules")
# `supports_transactions` already checks if the server is a
# replica set or sharded cluster.
is_not_single = self.supports_transactions
return is_enterprise and is_not_single
2 changes: 2 additions & 0 deletions django_mongodb_backend/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .duration import register_duration_field
from .embedded_model import EmbeddedModelField
from .embedded_model_array import EmbeddedModelArrayField
from .encryption import EncryptedCharField
from .json import register_json_field
from .objectid import ObjectIdField

Expand All @@ -11,6 +12,7 @@
"ArrayField",
"EmbeddedModelArrayField",
"EmbeddedModelField",
"EncryptedCharField",
"ObjectIdAutoField",
"ObjectIdField",
]
Expand Down
7 changes: 7 additions & 0 deletions django_mongodb_backend/fields/encryption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.db import models


class EncryptedCharField(models.CharField):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.encrypted = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd think this could be a class-level variable.

23 changes: 23 additions & 0 deletions django_mongodb_backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,26 @@ def delete(self, *args, **kwargs):

def save(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be saved.")


class EncryptedModelBase(models.base.ModelBase):
def __new__(cls, name, bases, attrs, **kwargs):
new_class = super().__new__(cls, name, bases, attrs, **kwargs)

# Build a map of encrypted fields
encrypted_fields = {
"fields": {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add query conditions

field.name: field.__class__.__name__
for field in new_class._meta.fields
if getattr(field, "encrypted", False)
}
}

# Store it as a class-level attribute
new_class.encrypted_fields_map = encrypted_fields
return new_class


class EncryptedModel(models.Model, metaclass=EncryptedModelBase):
class Meta:
abstract = True
26 changes: 23 additions & 3 deletions django_mongodb_backend/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from django.db.models import Index, UniqueConstraint
from pymongo.operations import SearchIndexModel

from django_mongodb_backend.indexes import SearchIndex

from .encryption import get_client_encryption
from .fields import EmbeddedModelField
from .indexes import SearchIndex
from .query import wrap_database_errors
from .utils import OperationCollector

Expand Down Expand Up @@ -41,7 +41,7 @@ def get_database(self):
@wrap_database_errors
@ignore_embedded_models
def create_model(self, model):
self.get_database().create_collection(model._meta.db_table)
self._create_collection(model)
self._create_model_indexes(model)
# Make implicit M2M tables.
for field in model._meta.local_many_to_many:
Expand Down Expand Up @@ -418,3 +418,23 @@ def _field_should_have_unique(self, field):
db_type = field.db_type(self.connection)
# The _id column is automatically unique.
return db_type and field.unique and field.column != "_id"

def _create_collection(self, model):
"""
Create a collection or encrypted collection for the model.
"""

if hasattr(model, "encrypted_fields_map"):
auto_encryption_opts = self.connection.settings_dict.get("OPTIONS", {}).get(
"auto_encryption_opts"
)
client = self.connection.connection
client_encryption = get_client_encryption(auto_encryption_opts, client)
client_encryption.create_encrypted_collection(
client.database,
model._meta.db_table,
model.encrypted_fields_map,
"local", # TODO: KMS provider should be configurable
)
else:
self.get_database().create_collection(model._meta.db_table)
22 changes: 22 additions & 0 deletions docs/source/topics/encrypted-models.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Encrypted models
================

``EncryptedCharField``
----------------------

The basics
~~~~~~~~~~

Let's consider this example::

from django.db import models

from django_mongodb_backend.fields import EncryptedCharField
from django_mongodb_backend.models import EncryptedModel


class Person(EncryptedModel):
ssn = EncryptedCharField("ssn", max_length=11)

def __str__(self):
return self.ssn
1 change: 1 addition & 0 deletions docs/source/topics/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ know:

cache
embedded-models
encrypted-models
known-issues
Empty file added tests/encryption_/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions tests/encryption_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from django.db import models

from django_mongodb_backend.fields import EncryptedCharField
from django_mongodb_backend.models import EncryptedModel


class Person(EncryptedModel):
name = models.CharField("name", max_length=100)
ssn = EncryptedCharField("ssn", max_length=11)

def __str__(self):
return self.name
30 changes: 30 additions & 0 deletions tests/encryption_/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from django.test import TestCase

from .models import Person


class EncryptedModelTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.objs = [Person.objects.create()]

def test_encrypted_fields_map_on_class(self):
expected = {
"fields": {
"ssn": "EncryptedCharField",
}
}
self.assertEqual(Person.encrypted_fields_map, expected)

def test_encrypted_fields_map_on_instance(self):
instance = Person(ssn="123-45-6789")
expected = {
"fields": {
"ssn": "EncryptedCharField",
}
}
self.assertEqual(instance.encrypted_fields_map, expected)

def test_non_encrypted_fields_not_included(self):
encrypted_field_names = Person.encrypted_fields_map.get("fields").keys()
self.assertNotIn("name", encrypted_field_names)
Loading