Skip to content

Commit 7c45d6a

Browse files
timgrahamWaVEV
authored andcommitted
EmbeddedModelArrayField Querying
1 parent 825ffca commit 7c45d6a

File tree

6 files changed

+378
-15
lines changed

6 files changed

+378
-15
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
326326
def as_mql(self, compiler, connection):
327327
lhs_mql = process_lhs(self, compiler, connection)
328328
value = process_rhs(self, compiler, connection)
329-
return {
330-
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
331-
}
329+
return {"$and": [{"$isArray": lhs_mql}, {"$size": {"$setIntersection": [value, lhs_mql]}}]}
332330

333331

334332
@ArrayField.register_lookup
@@ -338,7 +336,7 @@ class ArrayLenTransform(Transform):
338336

339337
def as_mql(self, compiler, connection):
340338
lhs_mql = process_lhs(self, compiler, connection)
341-
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}}
339+
return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}
342340

343341

344342
@ArrayField.register_lookup

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,9 @@ def as_mql(self, compiler, connection):
186186
key_transforms.insert(0, previous.key_name)
187187
previous = previous.lhs
188188
mql = previous.as_mql(compiler, connection)
189-
transforms = ".".join(key_transforms)
190-
return f"{mql}.{transforms}"
189+
for key in key_transforms:
190+
mql = {"$getField": {"input": mql, "field": key}}
191+
return mql
191192

192193
@property
193194
def output_field(self):

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
from django.db.models import Field
1+
import difflib
2+
3+
from django.core.exceptions import FieldDoesNotExist
4+
from django.db.models import Field, lookups
5+
from django.db.models.expressions import Col
6+
from django.db.models.lookups import Lookup, Transform
27

38
from .. import forms
9+
from ..query_utils import process_lhs, process_rhs
410
from . import EmbeddedModelField
511
from .array import ArrayField
612

713

814
class EmbeddedModelArrayField(ArrayField):
15+
ALLOWED_LOOKUPS = {"exact", "len", "overlap"}
16+
917
def __init__(self, embedded_model, **kwargs):
1018
if "size" in kwargs:
1119
raise ValueError("EmbeddedModelArrayField does not support size.")
@@ -44,3 +52,142 @@ def formfield(self, **kwargs):
4452
**kwargs,
4553
},
4654
)
55+
56+
def get_transform(self, name):
57+
transform = super().get_transform(name)
58+
if transform:
59+
return transform
60+
return KeyTransformFactory(name, self)
61+
62+
def get_lookup(self, name):
63+
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
64+
65+
66+
class EMFArrayRHSMixin:
67+
def process_rhs(self, compiler, connection):
68+
values = self.rhs
69+
# Value must be serealized based on the query target.
70+
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
71+
# field of the subfield. Otherwise, use the base field of the array itself.
72+
if isinstance(self.lhs, KeyTransform):
73+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
74+
else:
75+
get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value
76+
return None, [get_db_prep_value(values, connection, prepared=True)]
77+
78+
79+
@EmbeddedModelArrayField.register_lookup
80+
class EMFArrayExact(EMFArrayRHSMixin, lookups.Exact):
81+
def as_mql(self, compiler, connection):
82+
if not isinstance(self.lhs, KeyTransform):
83+
raise ValueError("error")
84+
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
85+
value = process_rhs(self, compiler, connection)
86+
return {
87+
"$anyElementTrue": {
88+
"$ifNull": [
89+
{
90+
"$map": {
91+
"input": lhs_mql,
92+
"as": "item",
93+
"in": {"$eq": [inner_lhs_mql, value]},
94+
}
95+
},
96+
[],
97+
]
98+
}
99+
}
100+
101+
102+
@EmbeddedModelArrayField.register_lookup
103+
class ArrayOverlap(EMFArrayRHSMixin, Lookup):
104+
lookup_name = "overlap"
105+
106+
def as_mql(self, compiler, connection):
107+
# Querying a subfield within the array elements (via nested KeyTransform).
108+
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
109+
# `$in` on the subfield.
110+
if not isinstance(self.lhs, KeyTransform):
111+
raise ValueError()
112+
lhs_mql = process_lhs(self, compiler, connection)
113+
values = process_rhs(self, compiler, connection)
114+
lhs_mql, inner_lhs_mql = lhs_mql
115+
return {
116+
"$anyElementTrue": {
117+
"$ifNull": [
118+
{
119+
"$map": {
120+
"input": lhs_mql,
121+
"as": "item",
122+
"in": {"$in": [inner_lhs_mql, values]},
123+
}
124+
},
125+
[],
126+
]
127+
}
128+
}
129+
130+
131+
class KeyTransform(Transform):
132+
def __init__(self, key_name, array_field, *args, **kwargs):
133+
super().__init__(*args, **kwargs)
134+
self.array_field = array_field
135+
self.key_name = key_name
136+
# The iteration items begins from the base_field, a virtual column with
137+
# base field output type is created.
138+
column_target = array_field.embedded_model._meta.get_field(key_name).clone()
139+
column_name = f"$item.{key_name}"
140+
column_target.db_column = column_name
141+
column_target.set_attributes_from_name(column_name)
142+
self._lhs = Col(None, column_target)
143+
self._sub_transform = None
144+
145+
def __call__(self, this, *args, **kwargs):
146+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
147+
return self
148+
149+
def get_lookup(self, name):
150+
return self.output_field.get_lookup(name)
151+
152+
def _get_missing_field_or_lookup_exception(self, lhs, name):
153+
suggested_lookups = difflib.get_close_matches(name, lhs.get_lookups())
154+
if suggested_lookups:
155+
suggested_lookups = " or ".join(suggested_lookups)
156+
suggestion = f", perhaps you meant {suggested_lookups}?"
157+
else:
158+
suggestion = ""
159+
raise FieldDoesNotExist(
160+
f"Unsupported lookup '{name}' for "
161+
f"EmbeddedModelArrayField of '{lhs.__class__.__name__}'"
162+
f"{suggestion}"
163+
)
164+
165+
def get_transform(self, name):
166+
"""
167+
Validate that `name` is either a field of an embedded model or a
168+
lookup on an embedded model's field.
169+
"""
170+
# Once the sub lhs is a transform, all the filter are applied over it.
171+
# Otherwise get transform from EMF.
172+
if transform := self._lhs.get_transform(name):
173+
self._sub_transform = transform
174+
return self
175+
raise self._get_missing_field_or_lookup_exception(self._lhs.output_field, name)
176+
177+
def as_mql(self, compiler, connection):
178+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
179+
lhs_mql = process_lhs(self, compiler, connection)
180+
return lhs_mql, inner_lhs_mql
181+
182+
@property
183+
def output_field(self):
184+
return self.array_field
185+
186+
187+
class KeyTransformFactory:
188+
def __init__(self, key_name, base_field):
189+
self.key_name = key_name
190+
self.base_field = base_field
191+
192+
def __call__(self, *args, **kwargs):
193+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)

tests/model_fields_/models.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,35 @@ class Movie(models.Model):
165165

166166
def __str__(self):
167167
return self.title
168+
169+
170+
class RestorationRecord(EmbeddedModel):
171+
date = models.DateField()
172+
restored_by = models.CharField(max_length=255)
173+
174+
175+
class ArtifactDetail(EmbeddedModel):
176+
"""Details about a specific artifact."""
177+
178+
name = models.CharField(max_length=255)
179+
metadata = models.JSONField()
180+
restorations = EmbeddedModelArrayField(RestorationRecord, null=True)
181+
last_restoration = EmbeddedModelField(RestorationRecord, null=True)
182+
183+
184+
class ExhibitSection(EmbeddedModel):
185+
"""A section within an exhibit, containing multiple artifacts."""
186+
187+
section_number = models.IntegerField()
188+
artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True)
189+
190+
191+
class MuseumExhibit(models.Model):
192+
"""An exhibit in the museum, composed of multiple sections."""
193+
194+
exhibit_name = models.CharField(max_length=255)
195+
sections = EmbeddedModelArrayField(ExhibitSection, null=True)
196+
main_section = EmbeddedModelField(ExhibitSection, null=True)
197+
198+
def __str__(self):
199+
return self.exhibit_name

tests/model_fields_/test_embedded_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
Data,
2525
Holder,
2626
Library,
27+
Movie,
2728
NestedData,
29+
Review,
2830
)
2931
from .utils import truncate_ms
3032

@@ -96,6 +98,29 @@ def test_pre_save(self):
9698
self.assertGreater(obj.data.auto_now, auto_now_two)
9799

98100

101+
class EmbeddedArrayTests(TestCase):
102+
def test_save_load(self):
103+
reviews = [
104+
Review(title="The best", rating=10),
105+
Review(title="Mediocre", rating=5),
106+
Review(title="Horrible", rating=1),
107+
]
108+
Movie.objects.create(title="Lion King", reviews=reviews)
109+
movie = Movie.objects.get(title="Lion King")
110+
self.assertEqual(movie.reviews[0].title, "The best")
111+
self.assertEqual(movie.reviews[0].rating, 10)
112+
self.assertEqual(movie.reviews[1].title, "Mediocre")
113+
self.assertEqual(movie.reviews[1].rating, 5)
114+
self.assertEqual(movie.reviews[2].title, "Horrible")
115+
self.assertEqual(movie.reviews[2].rating, 1)
116+
self.assertEqual(len(movie.reviews), 3)
117+
118+
def test_save_load_null(self):
119+
movie = Movie.objects.create(title="Lion King")
120+
movie = Movie.objects.get(title="Lion King")
121+
self.assertIsNone(movie.reviews)
122+
123+
99124
class QueryingTests(TestCase):
100125
@classmethod
101126
def setUpTestData(cls):

0 commit comments

Comments
 (0)