Skip to content

Commit bc52c8e

Browse files
committed
INTPYTHON-527 Add Queryable Encryption support
1 parent 4ca6c90 commit bc52c8e

File tree

12 files changed

+250
-4
lines changed

12 files changed

+250
-4
lines changed

django_mongodb_backend/base.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import contextlib
2+
import copy
23
import os
34

45
from django.core.exceptions import ImproperlyConfigured
56
from django.db import DEFAULT_DB_ALIAS
6-
from django.db.backends.base.base import BaseDatabaseWrapper
7+
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
78
from django.db.backends.utils import debug_transaction
89
from django.utils.asyncio import async_unsafe
910
from django.utils.functional import cached_property
@@ -156,6 +157,9 @@ def _isnull_operator(a, b):
156157
def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
157158
super().__init__(settings_dict, alias=alias)
158159
self.session = None
160+
# Cache the `settings_dict` in case we need to check for
161+
# auto_encryption_opts later.
162+
self.__dict__["_settings_dict"] = copy.deepcopy(settings_dict)
159163

160164
def get_collection(self, name, **kwargs):
161165
collection = Collection(self.database, name, **kwargs)
@@ -287,3 +291,25 @@ def validate_no_broken_transaction(self):
287291
def get_database_version(self):
288292
"""Return a tuple of the database's version."""
289293
return tuple(self.connection.server_info()["versionArray"])
294+
295+
@contextlib.contextmanager
296+
def _nodb_cursor(self):
297+
"""
298+
Returns a cursor from an unencrypted connection for operations
299+
that do not support encryption.
300+
301+
Encryption is only supported on encrypted models.
302+
"""
303+
304+
# Remove auto_encryption_opts from OPTIONS
305+
if self.settings_dict.get("OPTIONS", {}).get("auto_encryption_opts"):
306+
self.settings_dict["OPTIONS"].pop("auto_encryption_opts")
307+
308+
# Create a new connection without OPTIONS["auto_encryption_opts": …]
309+
conn = self.__class__({**self.settings_dict}, alias=NO_DB_ALIAS)
310+
311+
try:
312+
with conn.cursor() as cursor:
313+
yield cursor
314+
finally:
315+
conn.close()

django_mongodb_backend/encryption.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Queryable Encryption helpers
2+
#
3+
# TODO: Decide if these helpers should even exist, and if so, find a permanent
4+
# place for them.
5+
6+
from bson.binary import STANDARD
7+
from bson.codec_options import CodecOptions
8+
from pymongo.encryption import AutoEncryptionOpts, ClientEncryption
9+
10+
11+
def get_encrypted_client(auto_encryption_opts, encrypted_connection):
12+
"""
13+
Returns a `ClientEncryption` instance for MongoDB Client-Side Field Level
14+
Encryption (CSFLE) that can be used to create an encrypted collection.
15+
"""
16+
17+
key_vault_namespace = auto_encryption_opts._key_vault_namespace
18+
kms_providers = auto_encryption_opts._kms_providers
19+
codec_options = CodecOptions(uuid_representation=STANDARD)
20+
return ClientEncryption(kms_providers, key_vault_namespace, encrypted_connection, codec_options)
21+
22+
23+
def get_auto_encryption_opts(crypt_shared_lib_path=None, kms_providers=None):
24+
"""
25+
Returns an `AutoEncryptionOpts` instance for MongoDB Client-Side Field
26+
Level Encryption (CSFLE) that can be used to create an encrypted connection.
27+
"""
28+
key_vault_database_name = "encryption"
29+
key_vault_collection_name = "__keyVault"
30+
key_vault_namespace = f"{key_vault_database_name}.{key_vault_collection_name}"
31+
return AutoEncryptionOpts(
32+
key_vault_namespace=key_vault_namespace,
33+
kms_providers=kms_providers,
34+
crypt_shared_lib_path=crypt_shared_lib_path,
35+
)
36+
37+
38+
def get_customer_master_key():
39+
"""
40+
Returns a 96-byte local master key for use with MongoDB Client-Side Field Level
41+
Encryption (CSFLE). For local testing purposes only. In production, use a secure KMS
42+
like AWS, Azure, GCP, or KMIP.
43+
Returns:
44+
bytes: A 96-byte key.
45+
"""
46+
# WARNING: This is a static key for testing only.
47+
# Generate with: os.urandom(96)
48+
return bytes.fromhex(
49+
"000102030405060708090a0b0c0d0e0f"
50+
"101112131415161718191a1b1c1d1e1f"
51+
"202122232425262728292a2b2c2d2e2f"
52+
"303132333435363738393a3b3c3d3e3f"
53+
"404142434445464748494a4b4c4d4e4f"
54+
"505152535455565758595a5b5c5d5e5f"
55+
)
56+
57+
58+
def get_kms_providers():
59+
"""
60+
Return supported KMS providers for MongoDB Client-Side Field Level Encryption (CSFLE).
61+
"""
62+
return {
63+
"local": {
64+
"key": get_customer_master_key(),
65+
},
66+
}

django_mongodb_backend/features.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,3 +624,18 @@ def supports_transactions(self):
624624
hello = client.command("hello")
625625
# a replica set or a sharded cluster
626626
return "setName" in hello or hello.get("msg") == "isdbgrid"
627+
628+
@cached_property
629+
def supports_encryption(self):
630+
"""
631+
Encryption is supported if the server is Atlas or Enterprise
632+
and is configured as a replica set or sharded cluster.
633+
"""
634+
self.connection.ensure_connection()
635+
client = self.connection.connection.admin
636+
build_info = client.command("buildInfo")
637+
is_enterprise = "enterprise" in build_info.get("modules")
638+
# `supports_transactions` already checks if the server is a
639+
# replica set or sharded cluster.
640+
is_not_single = self.supports_transactions
641+
return is_enterprise and is_not_single

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .duration import register_duration_field
44
from .embedded_model import EmbeddedModelField
55
from .embedded_model_array import EmbeddedModelArrayField
6+
from .encryption import EncryptedCharField
67
from .json import register_json_field
78
from .objectid import ObjectIdField
89

@@ -11,6 +12,7 @@
1112
"ArrayField",
1213
"EmbeddedModelArrayField",
1314
"EmbeddedModelField",
15+
"EncryptedCharField",
1416
"ObjectIdAutoField",
1517
"ObjectIdField",
1618
]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from django.db import models
2+
3+
4+
class EncryptedCharField(models.CharField):
5+
def __init__(self, *args, **kwargs):
6+
super().__init__(*args, **kwargs)
7+
self.encrypted = True

django_mongodb_backend/models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,26 @@ def delete(self, *args, **kwargs):
1414

1515
def save(self, *args, **kwargs):
1616
raise NotSupportedError("EmbeddedModels cannot be saved.")
17+
18+
19+
class EncryptedModelBase(models.base.ModelBase):
20+
def __new__(cls, name, bases, attrs, **kwargs):
21+
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
22+
23+
# Build a map of encrypted fields
24+
encrypted_fields = {
25+
"fields": {
26+
field.name: field.__class__.__name__
27+
for field in new_class._meta.fields
28+
if getattr(field, "encrypted", False)
29+
}
30+
}
31+
32+
# Store it as a class-level attribute
33+
new_class.encrypted_fields_map = encrypted_fields
34+
return new_class
35+
36+
37+
class EncryptedModel(models.Model, metaclass=EncryptedModelBase):
38+
class Meta:
39+
abstract = True

django_mongodb_backend/schema.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import contextlib
2+
13
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
24
from django.db.models import Index, UniqueConstraint
5+
from pymongo.encryption import EncryptedCollectionError
36
from pymongo.operations import SearchIndexModel
47

5-
from django_mongodb_backend.indexes import SearchIndex
6-
8+
from .encryption import get_encrypted_client
79
from .fields import EmbeddedModelField
10+
from .indexes import SearchIndex
811
from .query import wrap_database_errors
912
from .utils import OperationCollector
1013

@@ -41,7 +44,7 @@ def get_database(self):
4144
@wrap_database_errors
4245
@ignore_embedded_models
4346
def create_model(self, model):
44-
self.get_database().create_collection(model._meta.db_table)
47+
self._create_collection(model)
4548
self._create_model_indexes(model)
4649
# Make implicit M2M tables.
4750
for field in model._meta.local_many_to_many:
@@ -418,3 +421,45 @@ def _field_should_have_unique(self, field):
418421
db_type = field.db_type(self.connection)
419422
# The _id column is automatically unique.
420423
return db_type and field.unique and field.column != "_id"
424+
425+
def _supports_encryption(self, model):
426+
"""
427+
Check for `supports_encryption` feature and `auto_encryption_opts`
428+
and `embedded_fields_map`. If `supports_encryption` is True and
429+
`auto_encryption_opts` is in the cached connection settings and
430+
the model has an embedded_fields_map property, then encryption
431+
is supported.
432+
"""
433+
return (
434+
self.connection.features.supports_encryption
435+
and self.connection._settings_dict.get("OPTIONS", {}).get("auto_encryption_opts")
436+
and hasattr(model, "encrypted_fields_map")
437+
)
438+
439+
def _create_collection(self, model):
440+
"""
441+
Create a collection or, if encryption is supported, create
442+
an encrypted connection then use it to create an encrypted
443+
client then use that to create an encrypted collection.
444+
"""
445+
446+
if self._supports_encryption(model):
447+
auto_encryption_opts = self.connection._settings_dict.get("OPTIONS", {}).get(
448+
"auto_encryption_opts"
449+
)
450+
# Use the cached settings dict to create a new connection
451+
encrypted_connection = self.connection.get_new_connection(
452+
self.connection._settings_dict
453+
)
454+
# Use the encrypted connection and auto_encryption_opts to create an encrypted client
455+
encrypted_client = get_encrypted_client(auto_encryption_opts, encrypted_connection)
456+
457+
with contextlib.suppress(EncryptedCollectionError):
458+
encrypted_client.create_encrypted_collection(
459+
encrypted_connection[self.connection.database.name],
460+
model._meta.db_table,
461+
model.encrypted_fields_map,
462+
"local", # TODO: KMS provider should be configurable
463+
)
464+
else:
465+
self.get_database().create_collection(model._meta.db_table)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
Encrypted models
2+
================
3+
4+
``EncryptedCharField``
5+
----------------------
6+
7+
The basics
8+
~~~~~~~~~~
9+
10+
Let's consider this example::
11+
12+
from django.db import models
13+
14+
from django_mongodb_backend.fields import EncryptedCharField
15+
from django_mongodb_backend.models import EncryptedModel
16+
17+
18+
class Person(EncryptedModel):
19+
ssn = EncryptedCharField("ssn", max_length=11)
20+
21+
def __str__(self):
22+
return self.ssn

docs/source/topics/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ know:
1010

1111
cache
1212
embedded-models
13+
encrypted-models
1314
known-issues

tests/encryption_/__init__.py

Whitespace-only changes.

tests/encryption_/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from django_mongodb_backend.fields import EncryptedCharField
2+
from django_mongodb_backend.models import EncryptedModel
3+
4+
5+
class Person(EncryptedModel):
6+
ssn = EncryptedCharField("ssn", max_length=11)
7+
8+
def __str__(self):
9+
return self.ssn

tests/encryption_/tests.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from django.test import TestCase
2+
3+
from .models import Person
4+
5+
6+
class EncryptedModelTests(TestCase):
7+
@classmethod
8+
def setUpTestData(cls):
9+
cls.objs = [Person.objects.create()]
10+
11+
def test_encrypted_fields_map_on_class(self):
12+
expected = {
13+
"fields": {
14+
"ssn": "EncryptedCharField",
15+
}
16+
}
17+
self.assertEqual(Person.encrypted_fields_map, expected)
18+
19+
def test_encrypted_fields_map_on_instance(self):
20+
instance = Person(ssn="123-45-6789")
21+
expected = {
22+
"fields": {
23+
"ssn": "EncryptedCharField",
24+
}
25+
}
26+
self.assertEqual(instance.encrypted_fields_map, expected)
27+
28+
def test_non_encrypted_fields_not_included(self):
29+
encrypted_field_names = Person.encrypted_fields_map.keys()
30+
self.assertNotIn("ssn", encrypted_field_names)

0 commit comments

Comments
 (0)