Skip to content

Commit d48ebc9

Browse files
committed
Add embedding
- EncryptedEmbeddedModel for encrypted objects - EmbeddedModel for models with encrypted fields
1 parent 100a4d0 commit d48ebc9

File tree

9 files changed

+46
-19
lines changed

9 files changed

+46
-19
lines changed

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
EncryptedDecimalField,
1414
EncryptedDurationField,
1515
EncryptedEmailField,
16+
EncryptedEmbeddedModelField,
1617
EncryptedFieldMixin,
1718
EncryptedFloatField,
1819
EncryptedGenericIPAddressField,
@@ -43,6 +44,7 @@
4344
"EncryptedDecimalField",
4445
"EncryptedDurationField",
4546
"EncryptedEmailField",
47+
"EncryptedEmbeddedModelField",
4648
"EncryptedFieldMixin",
4749
"EncryptedFloatField",
4850
"EncryptedGenericIPAddressField",

django_mongodb_backend/fields/encryption.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from django.db import models
22

3+
from django_mongodb_backend.fields import EmbeddedModelField
4+
5+
6+
class EncryptedEmbeddedModelField(EmbeddedModelField):
7+
encrypted = True
8+
9+
def db_type(self, connection):
10+
return "object"
11+
312

413
class EncryptedFieldMixin:
514
encrypted = True

django_mongodb_backend/management/commands/showencryptedfieldsmap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from django.core.management.base import BaseCommand
44
from django.db import DEFAULT_DB_ALIAS, connections, router
55

6-
from django_mongodb_backend.model_utils import model_has_encrypted_fields
6+
from django_mongodb_backend.utils import model_has_encrypted_fields
77

88

99
class Command(BaseCommand):

django_mongodb_backend/model_utils.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

django_mongodb_backend/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,11 @@ def delete(self, *args, **kwargs):
1414

1515
def save(self, *args, **kwargs):
1616
raise NotSupportedError("EmbeddedModels cannot be saved.")
17+
18+
19+
class EncryptedEmbeddedModel(EmbeddedModel):
20+
encrypted = True
21+
22+
class Meta:
23+
abstract = True
24+
required_db_features = {"supports_queryable_encryption"}

django_mongodb_backend/schema.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111

1212
from .fields import EmbeddedModelField
1313
from .gis.schema import GISSchemaEditor
14-
from .model_utils import model_has_encrypted_fields
1514
from .query import wrap_database_errors
16-
from .utils import OperationCollector
15+
from .utils import OperationCollector, model_has_encrypted_fields
1716

1817

1918
def ignore_embedded_models(func):
@@ -502,6 +501,9 @@ def _get_encrypted_fields(self, model, client, create_data_keys=False):
502501
db_table = model._meta.db_table
503502
field_list = []
504503
for field in fields:
504+
if isinstance(field, EmbeddedModelField):
505+
# Recursively get encrypted fields for the embedded model.
506+
self._get_encrypted_fields(field.embedded_model, client, create_data_keys)
505507
if getattr(field, "encrypted", False):
506508
key_alt_name = f"{db_table}.{field.column}"
507509
if create_data_keys:
@@ -524,7 +526,7 @@ def _get_encrypted_fields(self, model, client, create_data_keys=False):
524526
"path": field.column,
525527
"keyId": data_key,
526528
}
527-
if field.queries:
529+
if getattr(field, "queries", False):
528530
field_dict["queries"] = field.queries
529531
field_list.append(field_dict)
530532
return {"fields": field_list}

django_mongodb_backend/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,13 @@ def wrapper(self, *args, **kwargs):
186186
self.log(method, args, kwargs)
187187

188188
return wrapper
189+
190+
191+
def model_has_encrypted_fields(model):
192+
from django_mongodb_backend.models import EncryptedEmbeddedModel # noqa: PLC0415
193+
194+
for field in model._meta.fields:
195+
if getattr(field, "encrypted", False):
196+
return True
197+
198+
return bool(issubclass(model, EncryptedEmbeddedModel))

tests/encryption_/models.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from django.db import models
22

33
from django_mongodb_backend.fields import (
4-
EmbeddedModelField,
54
EncryptedBigIntegerField,
65
EncryptedBinaryField,
76
EncryptedBooleanField,
@@ -11,6 +10,7 @@
1110
EncryptedDecimalField,
1211
EncryptedDurationField,
1312
EncryptedEmailField,
13+
EncryptedEmbeddedModelField,
1414
EncryptedFloatField,
1515
EncryptedGenericIPAddressField,
1616
EncryptedIntegerField,
@@ -22,23 +22,23 @@
2222
EncryptedTimeField,
2323
EncryptedURLField,
2424
)
25-
from django_mongodb_backend.models import EmbeddedModel
25+
from django_mongodb_backend.models import EncryptedEmbeddedModel
2626

2727

28-
class Billing(EmbeddedModel):
28+
class Billing(EncryptedEmbeddedModel):
2929
cc_type = models.CharField(max_length=50)
3030
cc_number = models.CharField(max_length=20)
3131

3232

33-
class PatientRecord(EmbeddedModel):
34-
ssn = models.CharField(max_length=11)
35-
billing = EmbeddedModelField(Billing)
33+
class PatientRecord(EncryptedEmbeddedModel):
34+
ssn = EncryptedCharField(max_length=11, queries={"queryType": "equality"})
35+
billing = EncryptedEmbeddedModelField(Billing)
3636

3737

3838
class Patient(models.Model):
3939
patient_name = models.CharField(max_length=255)
4040
patient_id = models.BigIntegerField()
41-
patient_record = EmbeddedModelField(PatientRecord)
41+
patient_record = EncryptedEmbeddedModelField(PatientRecord)
4242

4343
def __str__(self):
4444
return f"{self.patient_name} ({self.patient_id})"

tests/encryption_/test_fields.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,20 @@
2929
)
3030

3131

32+
@skipUnlessDBFeature("supports_queryable_encryption")
3233
class PatientModelTests(TestCase):
34+
databases = {"default", "encrypted"}
35+
3336
def setUp(self):
3437
self.billing = Billing(cc_type="Visa", cc_number="4111111111111111")
3538
self.patient_record = PatientRecord(ssn="123-45-6789", billing=self.billing)
3639
self.patient = Patient.objects.create(
3740
patient_name="John Doe", patient_id=123456789, patient_record=self.patient_record
3841
)
3942

40-
def test_patient_record_content(self):
41-
"""Embedded patient record data should be stored and retrieved correctly."""
43+
def test_patient(self):
4244
patient = Patient.objects.get(id=self.patient.id)
4345
self.assertEqual(patient.patient_record.ssn, "123-45-6789")
44-
45-
def test_billing_information(self):
46-
"""Billing data inside the encrypted embedded model should be correct."""
47-
patient = Patient.objects.get(id=self.patient.id)
4846
self.assertEqual(patient.patient_record.billing.cc_type, "Visa")
4947
self.assertEqual(patient.patient_record.billing.cc_number, "4111111111111111")
5048

0 commit comments

Comments
 (0)