Skip to content

Commit 4ebd62f

Browse files
timgrahamWaVEV
authored and committed
EmbeddedModelArrayField Querying
1 parent 825ffca commit 4ebd62f

File tree

6 files changed

+388
-15
lines changed

6 files changed

+388
-15
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
326326
def as_mql(self, compiler, connection):
327327
lhs_mql = process_lhs(self, compiler, connection)
328328
value = process_rhs(self, compiler, connection)
329-
return {
330-
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
331-
}
329+
return {"$and": [{"$isArray": lhs_mql}, {"$size": {"$setIntersection": [value, lhs_mql]}}]}
332330

333331

334332
@ArrayField.register_lookup
@@ -338,7 +336,7 @@ class ArrayLenTransform(Transform):
338336

339337
def as_mql(self, compiler, connection):
340338
lhs_mql = process_lhs(self, compiler, connection)
341-
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}}
339+
return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}
342340

343341

344342
@ArrayField.register_lookup

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,9 @@ def as_mql(self, compiler, connection):
186186
key_transforms.insert(0, previous.key_name)
187187
previous = previous.lhs
188188
mql = previous.as_mql(compiler, connection)
189-
transforms = ".".join(key_transforms)
190-
return f"{mql}.{transforms}"
189+
for key in key_transforms:
190+
mql = {"$getField": {"input": mql, "field": key}}
191+
return mql
191192

192193
@property
193194
def output_field(self):

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
from django.db.models import Field
1+
import difflib
2+
3+
from django.core.exceptions import FieldDoesNotExist
4+
from django.db.models import Field, lookups
5+
from django.db.models.expressions import Col
6+
from django.db.models.lookups import Lookup, Transform
27

38
from .. import forms
9+
from ..query_utils import process_lhs, process_rhs
410
from . import EmbeddedModelField
511
from .array import ArrayField
612

713

814
class EmbeddedModelArrayField(ArrayField):
15+
ALLOWED_LOOKUPS = {"exact", "len", "overlap"}
16+
917
def __init__(self, embedded_model, **kwargs):
1018
if "size" in kwargs:
1119
raise ValueError("EmbeddedModelArrayField does not support size.")
@@ -44,3 +52,144 @@ def formfield(self, **kwargs):
4452
**kwargs,
4553
},
4654
)
55+
56+
def get_transform(self, name):
57+
transform = super().get_transform(name)
58+
if transform:
59+
return transform
60+
return KeyTransformFactory(name, self)
61+
62+
def get_lookup(self, name):
63+
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
64+
65+
66+
class EMFArrayRHSMixin:
67+
def check_lhs(self):
68+
if not isinstance(self.lhs, KeyTransform):
69+
raise ValueError(
70+
"Cannot apply this lookup directly to EmbeddedModelArrayField. "
71+
"Try querying one of its embedded fields instead."
72+
)
73+
74+
def process_rhs(self, compiler, connection):
75+
values = self.rhs
76+
# Value must be serialized based on the query target.
77+
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
78+
# field of the subfield. Otherwise, use the base field of the array itself.
79+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
80+
return None, [get_db_prep_value(values, connection, prepared=True)]
81+
82+
83+
@EmbeddedModelArrayField.register_lookup
84+
class EMFArrayExact(EMFArrayRHSMixin, lookups.Exact):
85+
def as_mql(self, compiler, connection):
86+
self.check_lhs()
87+
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
88+
value = process_rhs(self, compiler, connection)
89+
return {
90+
"$anyElementTrue": {
91+
"$ifNull": [
92+
{
93+
"$map": {
94+
"input": lhs_mql,
95+
"as": "item",
96+
"in": {"$eq": [inner_lhs_mql, value]},
97+
}
98+
},
99+
[],
100+
]
101+
}
102+
}
103+
104+
105+
@EmbeddedModelArrayField.register_lookup
106+
class ArrayOverlap(EMFArrayRHSMixin, Lookup):
107+
lookup_name = "overlap"
108+
109+
def as_mql(self, compiler, connection):
110+
self.check_lhs()
111+
# Querying a subfield within the array elements (via nested KeyTransform).
112+
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
113+
# `$in` on the subfield.
114+
lhs_mql = process_lhs(self, compiler, connection)
115+
values = process_rhs(self, compiler, connection)
116+
lhs_mql, inner_lhs_mql = lhs_mql
117+
return {
118+
"$anyElementTrue": {
119+
"$ifNull": [
120+
{
121+
"$map": {
122+
"input": lhs_mql,
123+
"as": "item",
124+
"in": {"$in": [inner_lhs_mql, values]},
125+
}
126+
},
127+
[],
128+
]
129+
}
130+
}
131+
132+
133+
class KeyTransform(Transform):
134+
def __init__(self, key_name, array_field, *args, **kwargs):
135+
super().__init__(*args, **kwargs)
136+
self.array_field = array_field
137+
self.key_name = key_name
138+
# Iteration over the items begins from the base_field, so a virtual column
139+
# with the base field's output type is created.
140+
column_target = array_field.embedded_model._meta.get_field(key_name).clone()
141+
column_name = f"$item.{key_name}"
142+
column_target.db_column = column_name
143+
column_target.set_attributes_from_name(column_name)
144+
self._lhs = Col(None, column_target)
145+
self._sub_transform = None
146+
147+
def __call__(self, this, *args, **kwargs):
148+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
149+
return self
150+
151+
def get_lookup(self, name):
152+
return self.output_field.get_lookup(name)
153+
154+
def _get_missing_field_or_lookup_exception(self, lhs, name):
155+
suggested_lookups = difflib.get_close_matches(name, lhs.get_lookups())
156+
if suggested_lookups:
157+
suggested_lookups = " or ".join(suggested_lookups)
158+
suggestion = f", perhaps you meant {suggested_lookups}?"
159+
else:
160+
suggestion = ""
161+
raise FieldDoesNotExist(
162+
f"Unsupported lookup '{name}' for "
163+
f"EmbeddedModelArrayField of '{lhs.__class__.__name__}'"
164+
f"{suggestion}"
165+
)
166+
167+
def get_transform(self, name):
168+
"""
169+
Validate that `name` is either a field of an embedded model or a
170+
lookup on an embedded model's field.
171+
"""
172+
# Once the sub lhs is a transform, all the filters are applied over it.
173+
# Otherwise get transform from EMF.
174+
if transform := self._lhs.get_transform(name):
175+
self._sub_transform = transform
176+
return self
177+
raise self._get_missing_field_or_lookup_exception(self._lhs.output_field, name)
178+
179+
def as_mql(self, compiler, connection):
180+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
181+
lhs_mql = process_lhs(self, compiler, connection)
182+
return lhs_mql, inner_lhs_mql
183+
184+
@property
185+
def output_field(self):
186+
return self.array_field
187+
188+
189+
class KeyTransformFactory:
190+
def __init__(self, key_name, base_field):
191+
self.key_name = key_name
192+
self.base_field = base_field
193+
194+
def __call__(self, *args, **kwargs):
195+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)

tests/model_fields_/models.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,35 @@ class Movie(models.Model):
165165

166166
def __str__(self):
167167
return self.title
168+
169+
170+
class RestorationRecord(EmbeddedModel):
171+
date = models.DateField()
172+
restored_by = models.CharField(max_length=255)
173+
174+
175+
class ArtifactDetail(EmbeddedModel):
176+
"""Details about a specific artifact."""
177+
178+
name = models.CharField(max_length=255)
179+
metadata = models.JSONField()
180+
restorations = EmbeddedModelArrayField(RestorationRecord, null=True)
181+
last_restoration = EmbeddedModelField(RestorationRecord, null=True)
182+
183+
184+
class ExhibitSection(EmbeddedModel):
185+
"""A section within an exhibit, containing multiple artifacts."""
186+
187+
section_number = models.IntegerField()
188+
artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True)
189+
190+
191+
class MuseumExhibit(models.Model):
192+
"""An exhibit in the museum, composed of multiple sections."""
193+
194+
exhibit_name = models.CharField(max_length=255)
195+
sections = EmbeddedModelArrayField(ExhibitSection, null=True)
196+
main_section = EmbeddedModelField(ExhibitSection, null=True)
197+
198+
def __str__(self):
199+
return self.exhibit_name

tests/model_fields_/test_embedded_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
Data,
2525
Holder,
2626
Library,
27+
Movie,
2728
NestedData,
29+
Review,
2830
)
2931
from .utils import truncate_ms
3032

@@ -96,6 +98,29 @@ def test_pre_save(self):
9698
self.assertGreater(obj.data.auto_now, auto_now_two)
9799

98100

101+
class EmbeddedArrayTests(TestCase):
102+
def test_save_load(self):
103+
reviews = [
104+
Review(title="The best", rating=10),
105+
Review(title="Mediocre", rating=5),
106+
Review(title="Horrible", rating=1),
107+
]
108+
Movie.objects.create(title="Lion King", reviews=reviews)
109+
movie = Movie.objects.get(title="Lion King")
110+
self.assertEqual(movie.reviews[0].title, "The best")
111+
self.assertEqual(movie.reviews[0].rating, 10)
112+
self.assertEqual(movie.reviews[1].title, "Mediocre")
113+
self.assertEqual(movie.reviews[1].rating, 5)
114+
self.assertEqual(movie.reviews[2].title, "Horrible")
115+
self.assertEqual(movie.reviews[2].rating, 1)
116+
self.assertEqual(len(movie.reviews), 3)
117+
118+
def test_save_load_null(self):
119+
movie = Movie.objects.create(title="Lion King")
120+
movie = Movie.objects.get(title="Lion King")
121+
self.assertIsNone(movie.reviews)
122+
123+
99124
class QueryingTests(TestCase):
100125
@classmethod
101126
def setUpTestData(cls):

0 commit comments

Comments
 (0)