Skip to content

Commit 9e36afe

Browse files
committed
EmbeddedModelArrayField Querying
1 parent fe24e7a commit 9e36afe

File tree

5 files changed

+534
-8
lines changed

5 files changed

+534
-8
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ class ArrayLenTransform(Transform):
338338

339339
def as_mql(self, compiler, connection):
340340
lhs_mql = process_lhs(self, compiler, connection)
341-
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}}
341+
return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}
342342

343343

344344
@ArrayField.register_lookup

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,9 @@ def as_mql(self, compiler, connection):
186186
key_transforms.insert(0, previous.key_name)
187187
previous = previous.lhs
188188
mql = previous.as_mql(compiler, connection)
189-
transforms = ".".join(key_transforms)
190-
return f"{mql}.{transforms}"
189+
for key in key_transforms:
190+
mql = {"$getField": {"input": mql, "field": key}}
191+
return mql
191192

192193
@property
193194
def output_field(self):

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 235 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
from django.db.models import Field
1+
import difflib
2+
3+
from django.core.exceptions import FieldDoesNotExist
4+
from django.db.models import Field, lookups
5+
from django.db.models.expressions import Col
6+
from django.db.models.lookups import Lookup, Transform
27

38
from .. import forms
9+
from ..query_utils import process_lhs, process_rhs
410
from . import EmbeddedModelField
5-
from .array import ArrayField
11+
from .array import ArrayField, ArrayLenTransform
612

713

814
class EmbeddedModelArrayField(ArrayField):
@@ -44,3 +50,230 @@ def formfield(self, **kwargs):
4450
**kwargs,
4551
},
4652
)
53+
54+
def get_transform(self, name):
55+
transform = super().get_transform(name)
56+
if transform:
57+
return transform
58+
return KeyTransformFactory(name, self)
59+
60+
def _get_lookup(self, lookup_name):
61+
lookup = super()._get_lookup(lookup_name)
62+
if lookup is None or lookup is ArrayLenTransform:
63+
return lookup
64+
65+
class EmbeddedModelArrayFieldLookups(Lookup):
66+
def as_mql(self, compiler, connection):
67+
raise ValueError(
68+
"Cannot apply this lookup directly to EmbeddedModelArrayField. "
69+
"Try querying one of its embedded fields instead."
70+
)
71+
72+
return EmbeddedModelArrayFieldLookups
73+
74+
75+
class _EmbeddedModelArrayOutputField(ArrayField):
76+
"""
77+
Represents the output of an EmbeddedModelArrayField when traversed in a query path.
78+
79+
This field is not meant to be used directly in model definitions. It exists solely to
80+
support query output resolution; when an EmbeddedModelArrayField is accessed in a query,
81+
the result should behave like an array of the embedded model's target type.
82+
83+
While it mimics ArrayField's lookups behavior, the way those lookups are resolved
84+
follows the semantics of EmbeddedModelArrayField rather than native array behavior.
85+
"""
86+
87+
ALLOWED_LOOKUPS = {
88+
"in",
89+
"exact",
90+
"iexact",
91+
"gt",
92+
"gte",
93+
"lt",
94+
"lte",
95+
"all",
96+
"contained_by",
97+
}
98+
99+
def get_lookup(self, name):
100+
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
101+
102+
103+
class EmbeddedModelArrayFieldBuiltinLookup(Lookup):
104+
def process_rhs(self, compiler, connection):
105+
value = self.rhs
106+
if not self.get_db_prep_lookup_value_is_iterable:
107+
value = [value]
108+
# Value must be serialized based on the query target.
109+
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
110+
# field of the subfield. Otherwise, use the base field of the array itself.
111+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
112+
return None, [
113+
v if hasattr(v, "as_mql") else get_db_prep_value(v, connection, prepared=True)
114+
for v in value
115+
]
116+
117+
def as_mql(self, compiler, connection):
118+
# Querying a subfield within the array elements (via nested KeyTransform).
119+
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
120+
# `$in` on the subfield.
121+
lhs_mql = process_lhs(self, compiler, connection)
122+
inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"]
123+
values = process_rhs(self, compiler, connection)
124+
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name](
125+
inner_lhs_mql, values
126+
)
127+
return {"$anyElementTrue": lhs_mql}
128+
129+
130+
@_EmbeddedModelArrayOutputField.register_lookup
131+
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
132+
pass
133+
134+
135+
@_EmbeddedModelArrayOutputField.register_lookup
136+
class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact):
137+
pass
138+
139+
140+
@_EmbeddedModelArrayOutputField.register_lookup
141+
class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact):
142+
get_db_prep_lookup_value_is_iterable = False
143+
144+
145+
@_EmbeddedModelArrayOutputField.register_lookup
146+
class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan):
147+
pass
148+
149+
150+
@_EmbeddedModelArrayOutputField.register_lookup
151+
class EmbeddedModelArrayFieldGreaterThanOrEqual(
152+
EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual
153+
):
154+
pass
155+
156+
157+
@_EmbeddedModelArrayOutputField.register_lookup
158+
class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan):
159+
pass
160+
161+
162+
@_EmbeddedModelArrayOutputField.register_lookup
163+
class EmbeddedModelArrayFieldLessThanOrEqual(
164+
EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual
165+
):
166+
pass
167+
168+
169+
@_EmbeddedModelArrayOutputField.register_lookup
170+
class EmbeddedModelArrayFieldAll(EmbeddedModelArrayFieldBuiltinLookup, Lookup):
171+
lookup_name = "all"
172+
get_db_prep_lookup_value_is_iterable = False
173+
174+
def as_mql(self, compiler, connection):
175+
lhs_mql = process_lhs(self, compiler, connection)
176+
values = process_rhs(self, compiler, connection)
177+
return {
178+
"$and": [
179+
{"$ne": [lhs_mql, None]},
180+
{"$ne": [values, None]},
181+
{"$setIsSubset": [values, lhs_mql]},
182+
]
183+
}
184+
185+
186+
@_EmbeddedModelArrayOutputField.register_lookup
187+
class ArrayContainedBy(EmbeddedModelArrayFieldBuiltinLookup, Lookup):
188+
lookup_name = "contained_by"
189+
get_db_prep_lookup_value_is_iterable = False
190+
191+
def as_mql(self, compiler, connection):
192+
lhs_mql = process_lhs(self, compiler, connection)
193+
value = process_rhs(self, compiler, connection)
194+
return {
195+
"$and": [
196+
{"$ne": [lhs_mql, None]},
197+
{"$ne": [value, None]},
198+
{"$setIsSubset": [lhs_mql, value]},
199+
]
200+
}
201+
202+
203+
class KeyTransform(Transform):
204+
def __init__(self, key_name, array_field, *args, **kwargs):
205+
super().__init__(*args, **kwargs)
206+
self.array_field = array_field
207+
self.key_name = key_name
208+
# The iteration items begins from the base_field, a virtual column with
209+
# base field output type is created.
210+
column_target = array_field.embedded_model._meta.get_field(key_name).clone()
211+
column_name = f"$item.{key_name}"
212+
column_target.db_column = column_name
213+
column_target.set_attributes_from_name(column_name)
214+
self._lhs = Col(None, column_target)
215+
self._sub_transform = None
216+
217+
def __call__(self, this, *args, **kwargs):
218+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
219+
return self
220+
221+
def get_lookup(self, name):
222+
return self.output_field.get_lookup(name)
223+
224+
def get_transform(self, name):
225+
"""
226+
Validate that `name` is either a field of an embedded model or a
227+
lookup on an embedded model's field.
228+
"""
229+
# Once the sub lhs is a transform, all the filter are applied over it.
230+
# Otherwise get transform from EMF.
231+
if transform := self._lhs.get_transform(name):
232+
if isinstance(transform, KeyTransformFactory):
233+
raise ValueError("Cannot perform multiple levels of array traversal in a query.")
234+
self._sub_transform = transform
235+
return self
236+
output_field = self._lhs.output_field
237+
allowed_lookups = self.output_field.ALLOWED_LOOKUPS.intersection(
238+
set(output_field.get_lookups())
239+
)
240+
suggested_lookups = difflib.get_close_matches(name, allowed_lookups)
241+
if suggested_lookups:
242+
suggested_lookups = " or ".join(suggested_lookups)
243+
suggestion = f", perhaps you meant {suggested_lookups}?"
244+
else:
245+
suggestion = ""
246+
raise FieldDoesNotExist(
247+
f"Unsupported lookup '{name}' for "
248+
f"EmbeddedModelArrayField of '{output_field.__class__.__name__}'"
249+
f"{suggestion}"
250+
)
251+
252+
def as_mql(self, compiler, connection):
253+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
254+
lhs_mql = process_lhs(self, compiler, connection)
255+
return {
256+
"$ifNull": [
257+
{
258+
"$map": {
259+
"input": lhs_mql,
260+
"as": "item",
261+
"in": inner_lhs_mql,
262+
}
263+
},
264+
[],
265+
]
266+
}
267+
268+
@property
269+
def output_field(self):
270+
return _EmbeddedModelArrayOutputField(self._lhs.output_field)
271+
272+
273+
class KeyTransformFactory:
274+
def __init__(self, key_name, base_field):
275+
self.key_name = key_name
276+
self.base_field = base_field
277+
278+
def __call__(self, *args, **kwargs):
279+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)

tests/model_fields_/models.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,40 @@ class Movie(models.Model):
165165

166166
def __str__(self):
167167
return self.title
168+
169+
170+
class RestorationRecord(EmbeddedModel):
171+
date = models.DateField()
172+
restored_by = models.CharField(max_length=255)
173+
174+
175+
# Details about a specific artifact.
176+
class ArtifactDetail(EmbeddedModel):
177+
name = models.CharField(max_length=255)
178+
metadata = models.JSONField()
179+
restorations = EmbeddedModelArrayField(RestorationRecord, null=True)
180+
last_restoration = EmbeddedModelField(RestorationRecord, null=True)
181+
182+
183+
# A section within an exhibit, containing multiple artifacts.
184+
class ExhibitSection(EmbeddedModel):
185+
section_number = models.IntegerField()
186+
artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True)
187+
188+
189+
# An exhibit in the museum, composed of multiple sections.
190+
class MuseumExhibit(models.Model):
191+
exhibit_name = models.CharField(max_length=255)
192+
sections = EmbeddedModelArrayField(ExhibitSection, null=True)
193+
main_section = EmbeddedModelField(ExhibitSection, null=True)
194+
195+
def __str__(self):
196+
return self.exhibit_name
197+
198+
199+
class Tour(models.Model):
200+
guide = models.CharField(max_length=100)
201+
exhibit = models.ForeignKey(MuseumExhibit, on_delete=models.CASCADE)
202+
203+
def __str__(self):
204+
return f"Tour by {self.guide}"

0 commit comments

Comments
 (0)