Skip to content

Commit c38bb5c

Browse files
committed
Code review fixes (1/x)
1 parent 3296549 commit c38bb5c

File tree

8 files changed

+104
-141
lines changed

8 files changed

+104
-141
lines changed

django_mongodb_backend/fields/encryption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from django.db import models
22

33

4-
class EncryptedFieldMixin(models.Field):
4+
class EncryptedFieldMixin:
55
encrypted = True
66

77
def __init__(self, *args, queries=None, **kwargs):

django_mongodb_backend/management/commands/showencryptedfieldsmap.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
class Command(BaseCommand):
1010
help = """
11-
This command shows the mapping of encrypted fields to attributes
12-
including data type, data keys and query types. It can be used to set the
13-
``encrypted_fields_map`` in ``AutoEncryptionOpts``. Defaults to showing
14-
existing keys from the configured key vault.
11+
Shows the mapping of encrypted fields to field attributes, including data
12+
type, data keys and query types. The output can be used to set
13+
``encrypted_fields_map`` in ``AutoEncryptionOpts``.
14+
15+
Defaults to showing keys from the ``key_vault_namespace`` collection.
1516
"""
1617

1718
def add_arguments(self, parser):

docs/source/ref/models/encrypted-fields.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ supported by Queryable Encryption.
6363
``EncryptedFieldMixin``
6464
=======================
6565

66-
.. class:: EncryptedFieldMixin(models.Field)
66+
.. class:: EncryptedFieldMixin
6767

6868
.. versionadded:: 5.2.0rc1
6969

tests/encryption_/models.py

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
EncryptedFloatField,
1212
EncryptedGenericIPAddressField,
1313
EncryptedIntegerField,
14-
EncryptedPositiveSmallIntegerField,
1514
EncryptedSmallIntegerField,
1615
EncryptedTextField,
1716
EncryptedTimeField,
@@ -22,45 +21,42 @@
2221
RANGE_QUERY = {"queryType": "range"}
2322

2423

25-
class Appointment(models.Model):
26-
time = EncryptedTimeField(queries=EQUALITY_QUERY)
27-
24+
class QueryableEncryptionModelBase(models.Model):
2825
class Meta:
26+
abstract = True
2927
required_db_features = {"supports_queryable_encryption"}
3028

3129

32-
class Billing(models.Model):
33-
cc_type = EncryptedCharField(max_length=20, queries=EQUALITY_QUERY)
34-
cc_number = EncryptedIntegerField(queries=EQUALITY_QUERY)
30+
class Appointment(QueryableEncryptionModelBase):
31+
time = EncryptedTimeField(queries=EQUALITY_QUERY)
32+
33+
34+
class Billing(QueryableEncryptionModelBase):
3535
account_balance = EncryptedDecimalField(max_digits=10, decimal_places=2, queries=RANGE_QUERY)
3636

37-
class Meta:
38-
required_db_features = {"supports_queryable_encryption"}
37+
38+
class CreditCard(QueryableEncryptionModelBase):
39+
cc_type = EncryptedCharField(max_length=20, queries=EQUALITY_QUERY)
40+
cc_number = EncryptedIntegerField(queries=EQUALITY_QUERY)
3941

4042

41-
class PatientPortalUser(models.Model):
43+
class PatientPortalUser(QueryableEncryptionModelBase):
4244
ip_address = EncryptedGenericIPAddressField(queries=EQUALITY_QUERY)
4345
url = EncryptedURLField(queries=EQUALITY_QUERY)
4446

45-
class Meta:
46-
required_db_features = {"supports_queryable_encryption"}
47-
4847

49-
class PatientRecord(models.Model):
48+
class PatientRecord(QueryableEncryptionModelBase):
5049
ssn = EncryptedCharField(max_length=11, queries=EQUALITY_QUERY)
5150
birth_date = EncryptedDateField(queries=RANGE_QUERY)
5251
profile_picture = EncryptedBinaryField(queries=EQUALITY_QUERY)
5352
patient_age = EncryptedSmallIntegerField(queries={**RANGE_QUERY, "min": 0, "max": 100})
5453
weight = EncryptedFloatField(queries=RANGE_QUERY)
5554

5655
# TODO: Embed Billing model
57-
# billing =
56+
# billing = EncryptedEmbeddedField(Billing)
5857

59-
class Meta:
60-
required_db_features = {"supports_queryable_encryption"}
6158

62-
63-
class Patient(models.Model):
59+
class Patient(QueryableEncryptionModelBase):
6460
patient_id = EncryptedIntegerField(queries=EQUALITY_QUERY)
6561
patient_name = EncryptedCharField(max_length=100)
6662
patient_notes = EncryptedTextField(queries=EQUALITY_QUERY)
@@ -69,27 +65,4 @@ class Patient(models.Model):
6965
email = EncryptedEmailField(queries=EQUALITY_QUERY)
7066

7167
# TODO: Embed PatientRecord model
72-
# patient_record =
73-
74-
class Meta:
75-
required_db_features = {"supports_queryable_encryption"}
76-
77-
78-
class EncryptedNumbers(models.Model):
79-
# Not tested elsewhere
80-
pos_smallint = EncryptedPositiveSmallIntegerField(queries=EQUALITY_QUERY)
81-
smallint = EncryptedSmallIntegerField(queries=EQUALITY_QUERY)
82-
83-
class Meta:
84-
required_db_features = {"supports_queryable_encryption"}
85-
86-
87-
class SensitiveData(models.Model):
88-
# Example from documentation
89-
name = EncryptedCharField(max_length=100)
90-
email = EncryptedEmailField()
91-
phone_number = EncryptedCharField(max_length=15)
92-
93-
sensitive_text = EncryptedTextField()
94-
sensitive_integer = EncryptedIntegerField()
95-
sensitive_date = EncryptedDateField()
68+
# patient_record = EncryptedEmbeddedField(PatientRecord)

tests/encryption_/test_base.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,10 @@
1-
from datetime import datetime
2-
31
from django.test import TransactionTestCase, override_settings, skipUnlessDBFeature
42

5-
from .models import (
6-
Appointment,
7-
Billing,
8-
EncryptedNumbers,
9-
Patient,
10-
PatientPortalUser,
11-
PatientRecord,
12-
)
133
from .routers import TestEncryptedRouter
144

155

166
@skipUnlessDBFeature("supports_queryable_encryption")
177
@override_settings(DATABASE_ROUTERS=[TestEncryptedRouter()])
18-
class QueryableEncryptionTests(TransactionTestCase):
8+
class QueryableEncryptionTestCase(TransactionTestCase):
199
databases = {"default", "encrypted"}
2010
available_apps = ["encryption_"]
21-
22-
def setUp(self):
23-
"""
24-
Used in schema and field tests.
25-
"""
26-
self.appointment = Appointment.objects.create(time="8:00")
27-
self.billing = Billing.objects.create(
28-
cc_type="Visa", cc_number=1234567890123456, account_balance=100.50
29-
)
30-
self.portal_user = PatientPortalUser.objects.create(
31-
ip_address="127.0.0.1",
32-
url="https://example.com",
33-
)
34-
self.patientrecord = PatientRecord.objects.create(
35-
ssn="123-45-6789",
36-
birth_date="1970-01-01",
37-
profile_picture=b"image data",
38-
weight=175.5,
39-
patient_age=47,
40-
)
41-
self.patient = Patient.objects.create(
42-
patient_id=1,
43-
patient_name="John Doe",
44-
patient_notes="patient notes " * 25,
45-
registration_date=datetime(2023, 10, 1, 12, 0, 0),
46-
is_active=True,
47-
48-
)
49-
EncryptedNumbers.objects.create(
50-
pos_smallint=12345,
51-
smallint=-12345,
52-
)

tests/encryption_/test_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from .test_base import QueryableEncryptionTests
1+
from .test_base import QueryableEncryptionTestCase
22

33

4-
class TestConnection(QueryableEncryptionTests):
4+
class TestConnection(QueryableEncryptionTestCase):
55
def test_connection(self):
66
# raise ImproperlyConfigured(
77
# "Encrypted fields found but "

tests/encryption_/test_fields.py

Lines changed: 74 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from datetime import datetime, time
1+
from datetime import date, datetime, time
22

33
import pymongo
44
from bson.binary import Binary
@@ -8,43 +8,97 @@
88
from .models import (
99
Appointment,
1010
Billing,
11-
EncryptedNumbers,
11+
CreditCard,
1212
Patient,
1313
PatientPortalUser,
1414
PatientRecord,
1515
)
16-
from .test_base import QueryableEncryptionTests
16+
from .test_base import QueryableEncryptionTestCase
1717

1818

19-
class FieldTests(QueryableEncryptionTests):
20-
def test_appointment(self):
21-
self.assertEqual(Appointment.objects.get(time="8:00").time, time(8, 0))
19+
class FieldTests(QueryableEncryptionTestCase):
20+
def setUp(self):
21+
Patient.objects.create(
22+
patient_id=1,
23+
patient_name="John Doe",
24+
patient_notes="patient notes " * 25,
25+
registration_date=datetime(2023, 10, 1, 12, 0, 0),
26+
is_active=True,
27+
28+
)
29+
PatientRecord.objects.create(
30+
ssn="123-45-6789",
31+
birth_date="1969-01-01",
32+
profile_picture=b"image data",
33+
patient_age=50,
34+
weight=180.0,
35+
)
36+
37+
def test_binaryfield(self):
38+
self.assertEqual(
39+
PatientRecord.objects.get(profile_picture=b"image data").profile_picture, b"image data"
40+
)
41+
42+
def test_booleanfield(self):
43+
self.assertTrue(Patient.objects.get(patient_id=1).is_active)
44+
45+
def test_charfield(self):
46+
CreditCard.objects.create(cc_type="Visa", cc_number="1234567890123456")
47+
self.assertEqual(CreditCard.objects.get(cc_type="Visa").cc_type, "Visa")
48+
self.assertEqual(PatientRecord.objects.get(ssn="123-45-6789").ssn, "123-45-6789")
2249

23-
def test_billing(self):
50+
def test_datefield(self):
2451
self.assertEqual(
25-
Billing.objects.get(cc_number=1234567890123456).cc_number, 1234567890123456
52+
PatientRecord.objects.get(birth_date="1969-1-1").birth_date, date(1969, 1, 1)
2653
)
27-
self.assertEqual(Billing.objects.get(cc_type="Visa").cc_type, "Visa")
54+
55+
def test_datetimefield(self):
56+
self.assertEqual(
57+
Patient.objects.get(
58+
registration_date=datetime(2023, 10, 1, 12, 0, 0)
59+
).registration_date,
60+
datetime(2023, 10, 1, 12, 0, 0),
61+
)
62+
63+
def test_decimalfield(self):
64+
Billing.objects.create(account_balance=100.50)
2865
self.assertTrue(Billing.objects.filter(account_balance__gte=100.0).exists())
2966

30-
def test_patientportaluser(self):
67+
def test_emailfield(self):
3168
self.assertEqual(
32-
PatientPortalUser.objects.get(ip_address="127.0.0.1").ip_address, "127.0.0.1"
69+
Patient.objects.get(email="[email protected]").email, "[email protected]"
3370
)
3471

35-
def test_patientrecord(self):
36-
self.assertEqual(PatientRecord.objects.get(ssn="123-45-6789").ssn, "123-45-6789")
37-
with self.assertRaises(PatientRecord.DoesNotExist):
38-
PatientRecord.objects.get(ssn="000-00-0000")
39-
self.assertTrue(PatientRecord.objects.filter(birth_date__gte="1969-01-01").exists())
72+
def test_floatfield(self):
73+
self.assertTrue(PatientRecord.objects.filter(weight__gte=175.0).exists())
74+
75+
def test_integerfield(self):
76+
CreditCard.objects.create(cc_type="Visa", cc_number="1234567890123456")
77+
self.assertEqual(
78+
CreditCard.objects.get(cc_number=1234567890123456).cc_number, 1234567890123456
79+
)
80+
81+
def test_ipaddressfield(self):
82+
PatientPortalUser.objects.create(ip_address="127.0.0.1", url="https://example.com")
4083
self.assertEqual(
41-
PatientRecord.objects.get(ssn="123-45-6789").profile_picture, b"image data"
84+
PatientPortalUser.objects.get(ip_address="127.0.0.1").ip_address, "127.0.0.1"
4285
)
86+
87+
def test_smallintegerfield(self):
4388
self.assertTrue(PatientRecord.objects.filter(patient_age__gte=40).exists())
4489
self.assertFalse(PatientRecord.objects.filter(patient_age__gte=80).exists())
45-
self.assertTrue(PatientRecord.objects.filter(weight__gte=175.0).exists())
4690

47-
# Test encrypted patient record in unencrypted database.
91+
def test_timefield(self):
92+
Appointment.objects.create(time="8:00")
93+
self.assertEqual(Appointment.objects.get(time="8:00").time, time(8, 0))
94+
95+
def test_encrypted_patient_record_in_encrypted_database(self):
96+
patients = connections["encrypted"].database.encryption__patient.find()
97+
self.assertEqual(len(list(patients)), 1)
98+
records = connections["encrypted"].database.encryption__patientrecord.find()
99+
self.assertTrue("__safeContent__" in records[0])
100+
101+
def test_encrypted_patient_record_in_unencrypted_database(self):
48102
conn_params = connections["encrypted"].get_connection_params()
49103
db_name = settings.DATABASES["encrypted"]["NAME"]
50104
if conn_params.pop("auto_encryption_opts", False):
@@ -56,32 +110,8 @@ def test_patientrecord(self):
56110
ssn = patientrecords[0]["ssn"]
57111
self.assertTrue(isinstance(ssn, Binary))
58112

59-
def test_patient(self):
113+
def test_textfield(self):
60114
self.assertEqual(
61115
Patient.objects.get(patient_notes="patient notes " * 25).patient_notes,
62116
"patient notes " * 25,
63117
)
64-
self.assertEqual(
65-
Patient.objects.get(
66-
registration_date=datetime(2023, 10, 1, 12, 0, 0)
67-
).registration_date,
68-
datetime(2023, 10, 1, 12, 0, 0),
69-
)
70-
self.assertTrue(Patient.objects.get(patient_id=1).is_active)
71-
self.assertEqual(
72-
Patient.objects.get(email="[email protected]").email, "[email protected]"
73-
)
74-
75-
# Test decrypted patient record in encrypted database.
76-
patients = connections["encrypted"].database.encryption__patient.find()
77-
self.assertEqual(len(list(patients)), 1)
78-
records = connections["encrypted"].database.encryption__patientrecord.find()
79-
self.assertTrue("__safeContent__" in records[0])
80-
81-
def test_pos_small_int(self):
82-
obj = EncryptedNumbers.objects.get(pos_smallint=12345)
83-
self.assertEqual(obj.pos_smallint, 12345)
84-
85-
def test_small_int(self):
86-
obj = EncryptedNumbers.objects.get(smallint=-12345)
87-
self.assertEqual(obj.smallint, -12345)

tests/encryption_/test_schema.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from django.db import connections
22

3-
from .test_base import QueryableEncryptionTests
3+
from .models import Patient
4+
from .test_base import QueryableEncryptionTestCase
45

56

6-
class SchemaTests(QueryableEncryptionTests):
7+
class SchemaTests(QueryableEncryptionTestCase):
78
maxDiff = None
89

910
def test_get_encrypted_fields_map(self):
@@ -53,7 +54,7 @@ def test_get_encrypted_fields_map(self):
5354
connection = connections["encrypted"]
5455
with connection.schema_editor() as editor:
5556
client = connection.connection
56-
encrypted_fields_map = editor._get_encrypted_fields_map(self.patient, client)
57+
encrypted_fields_map = editor._get_encrypted_fields_map(Patient, client)
5758
for field in encrypted_fields_map["fields"]:
5859
# Remove data keys from the output; they are expected to differ
5960
field.pop("keyId", None)

0 commit comments

Comments
 (0)