Skip to content

Commit 13efcd1

Browse files
WaVEVtimgraham
authored andcommitted
Add support for querying EmbeddedModelArrayField
1 parent 17fb974 commit 13efcd1

File tree

6 files changed

+555
-8
lines changed

6 files changed

+555
-8
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ class ArrayLenTransform(Transform):
338338

339339
def as_mql(self, compiler, connection):
340340
lhs_mql = process_lhs(self, compiler, connection)
341-
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}}
341+
return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}
342342

343343

344344
@ArrayField.register_lookup

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,9 @@ def as_mql(self, compiler, connection):
186186
key_transforms.insert(0, previous.key_name)
187187
previous = previous.lhs
188188
mql = previous.as_mql(compiler, connection)
189-
transforms = ".".join(key_transforms)
190-
return f"{mql}.{transforms}"
189+
for key in key_transforms:
190+
mql = {"$getField": {"input": mql, "field": key}}
191+
return mql
191192

192193
@property
193194
def output_field(self):

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 204 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
from django.db.models import Field
1+
import difflib
2+
3+
from django.core.exceptions import FieldDoesNotExist
4+
from django.db.models import Field, lookups
5+
from django.db.models.expressions import Col
26
from django.db.models.fields.related import lazy_related_operation
7+
from django.db.models.lookups import Lookup, Transform
38

49
from .. import forms
10+
from ..query_utils import process_lhs, process_rhs
511
from . import EmbeddedModelField
6-
from .array import ArrayField
12+
from .array import ArrayField, ArrayLenTransform
713

814

915
class EmbeddedModelArrayField(ArrayField):
@@ -56,3 +62,199 @@ def formfield(self, **kwargs):
5662
**kwargs,
5763
},
5864
)
65+
66+
def get_transform(self, name):
67+
transform = super().get_transform(name)
68+
if transform:
69+
return transform
70+
return KeyTransformFactory(name, self)
71+
72+
def _get_lookup(self, lookup_name):
73+
lookup = super()._get_lookup(lookup_name)
74+
if lookup is None or lookup is ArrayLenTransform:
75+
return lookup
76+
77+
class EmbeddedModelArrayFieldLookups(Lookup):
78+
def as_mql(self, compiler, connection):
79+
raise ValueError(
80+
"Lookups aren't supported on EmbeddedModelArrayField. "
81+
"Try querying one of its embedded fields instead."
82+
)
83+
84+
return EmbeddedModelArrayFieldLookups
85+
86+
87+
class _EmbeddedModelArrayOutputField(ArrayField):
88+
"""
89+
Represent the output of an EmbeddedModelArrayField when traversed in a
90+
query path.
91+
92+
This field is not meant to be used in model definitions. It exists solely
93+
to support query output resolution. When an EmbeddedModelArrayField is
94+
accessed in a query, the result should behave like an array of the embedded
95+
model's target type.
96+
97+
While it mimics ArrayField's lookup behavior, the way those lookups are
98+
resolved follows the semantics of EmbeddedModelArrayField rather than
99+
ArrayField.
100+
"""
101+
102+
ALLOWED_LOOKUPS = {
103+
"in",
104+
"exact",
105+
"iexact",
106+
"gt",
107+
"gte",
108+
"lt",
109+
"lte",
110+
}
111+
112+
def get_lookup(self, name):
113+
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
114+
115+
116+
class EmbeddedModelArrayFieldBuiltinLookup(Lookup):
117+
def process_rhs(self, compiler, connection):
118+
value = self.rhs
119+
if not self.get_db_prep_lookup_value_is_iterable:
120+
value = [value]
121+
# Value must be serialized based on the query target. If querying a
122+
# subfield inside the array (i.e., a nested KeyTransform), use the
123+
# output field of the subfield. Otherwise, use the base field of the
124+
# array itself.
125+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
126+
return None, [
127+
v if hasattr(v, "as_mql") else get_db_prep_value(v, connection, prepared=True)
128+
for v in value
129+
]
130+
131+
def as_mql(self, compiler, connection):
132+
# Querying a subfield within the array elements (via nested
133+
# KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over
134+
# the array and applying $in on the subfield.
135+
lhs_mql = process_lhs(self, compiler, connection)
136+
inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"]
137+
values = process_rhs(self, compiler, connection)
138+
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name](
139+
inner_lhs_mql, values
140+
)
141+
return {"$anyElementTrue": lhs_mql}
142+
143+
144+
@_EmbeddedModelArrayOutputField.register_lookup
145+
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
146+
pass
147+
148+
149+
@_EmbeddedModelArrayOutputField.register_lookup
150+
class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact):
151+
pass
152+
153+
154+
@_EmbeddedModelArrayOutputField.register_lookup
155+
class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact):
156+
get_db_prep_lookup_value_is_iterable = False
157+
158+
159+
@_EmbeddedModelArrayOutputField.register_lookup
160+
class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan):
161+
pass
162+
163+
164+
@_EmbeddedModelArrayOutputField.register_lookup
165+
class EmbeddedModelArrayFieldGreaterThanOrEqual(
166+
EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual
167+
):
168+
pass
169+
170+
171+
@_EmbeddedModelArrayOutputField.register_lookup
172+
class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan):
173+
pass
174+
175+
176+
@_EmbeddedModelArrayOutputField.register_lookup
177+
class EmbeddedModelArrayFieldLessThanOrEqual(
178+
EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual
179+
):
180+
pass
181+
182+
183+
class KeyTransform(Transform):
184+
def __init__(self, key_name, array_field, *args, **kwargs):
185+
super().__init__(*args, **kwargs)
186+
self.array_field = array_field
187+
self.key_name = key_name
188+
# Lookups iterate over the array of embedded models. A virtual column
189+
# of the queried field's type represents each element.
190+
column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
191+
column_name = f"$item.{key_name}"
192+
column_target.db_column = column_name
193+
column_target.set_attributes_from_name(column_name)
194+
self._lhs = Col(None, column_target)
195+
self._sub_transform = None
196+
197+
def __call__(self, this, *args, **kwargs):
198+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
199+
return self
200+
201+
def get_lookup(self, name):
202+
return self.output_field.get_lookup(name)
203+
204+
def get_transform(self, name):
205+
"""
206+
Validate that `name` is either a field of an embedded model or am
207+
allowed lookup on an embedded model's field.
208+
"""
209+
# Once the sub-lhs is a transform, all the filters are applied over it.
210+
# Otherwise get the transform from the nested embedded model field.
211+
if transform := self._lhs.get_transform(name):
212+
if isinstance(transform, KeyTransformFactory):
213+
raise ValueError("Cannot perform multiple levels of array traversal in a query.")
214+
self._sub_transform = transform
215+
return self
216+
output_field = self._lhs.output_field
217+
# The lookup must be allowed AND a valid lookup for the field.
218+
allowed_lookups = self.output_field.ALLOWED_LOOKUPS.intersection(
219+
set(output_field.get_lookups())
220+
)
221+
suggested_lookups = difflib.get_close_matches(name, allowed_lookups)
222+
if suggested_lookups:
223+
suggested_lookups = " or ".join(suggested_lookups)
224+
suggestion = f", perhaps you meant {suggested_lookups}?"
225+
else:
226+
suggestion = ""
227+
raise FieldDoesNotExist(
228+
f"Unsupported lookup '{name}' for "
229+
f"EmbeddedModelArrayField of '{output_field.__class__.__name__}'"
230+
f"{suggestion}"
231+
)
232+
233+
def as_mql(self, compiler, connection):
234+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
235+
lhs_mql = process_lhs(self, compiler, connection)
236+
return {
237+
"$ifNull": [
238+
{
239+
"$map": {
240+
"input": lhs_mql,
241+
"as": "item",
242+
"in": inner_lhs_mql,
243+
}
244+
},
245+
[],
246+
]
247+
}
248+
249+
@property
250+
def output_field(self):
251+
return _EmbeddedModelArrayOutputField(self._lhs.output_field)
252+
253+
254+
class KeyTransformFactory:
255+
def __init__(self, key_name, base_field):
256+
self.key_name = key_name
257+
self.base_field = base_field
258+
259+
def __call__(self, *args, **kwargs):
260+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)

docs/source/topics/embedded-models.rst

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,69 @@ Represented in BSON, the post's structure looks like this:
115115
name: 'Hello world!',
116116
tags: [ { name: 'welcome' }, { name: 'test' } ]
117117
}
118+
119+
Querying ``EmbeddedModelArrayField``
120+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
121+
122+
You can query into an embedded model array using the same double underscore
123+
syntax as relational fields. For example, to find posts that have a tag with
124+
name "test"::
125+
126+
>>> Post.objects.filter(tags__name="test")
127+
128+
There are a limited set of lookups you can chain after an embedded field:
129+
130+
* :lookup:`exact`, :lookup:`iexact`
131+
* :lookup:`in`
132+
* :lookup:`gt`, :lookup:`gte`, :lookup:`lt`, :lookup:`lte`
133+
134+
For example, to find posts that have tags with name "test", "TEST", "tEsT",
135+
etc::
136+
137+
>>> Post.objects.filter(tags__name__iexact="test")
138+
139+
.. fieldlookup:: embeddedmodelarrayfield.len
140+
141+
``len`` transform
142+
^^^^^^^^^^^^^^^^^
143+
144+
You can use the ``len`` transform to filter on the length of the array. The
145+
lookups available afterward are those available for
146+
:class:`~django.db.models.IntegerField`. For example, to match posts with one
147+
tag::
148+
149+
>>> Post.objects.filter(tags__len=1)
150+
151+
or at least one tag::
152+
153+
>>> Post.objects.filter(tags__len__gte=1)
154+
155+
Index and slice transforms
156+
^^^^^^^^^^^^^^^^^^^^^^^^^^
157+
158+
Like :class:`~django_mongodb_backend.fields.ArrayField`, you can use
159+
:lookup:`index <mongo-arrayfield.index>` and :lookup:`slice
160+
<mongo-arrayfield.slice>` transforms to filter on particular items in an array.
161+
162+
For example, to find posts where the first tag is named "test"::
163+
164+
>>> Post.objects.filter(tags__0__name="test")
165+
166+
Or to find posts where the one of the first two tags is named "test"::
167+
168+
>>> Post.objects.filter(tags__0_1__name="test")
169+
170+
These indexes use 0-based indexing.
171+
172+
Nested ``EmbeddedModelArrayField``\s
173+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
174+
175+
If your models use nested ``EmbeddedModelArrayField``\s, you can't use double
176+
underscores to query into the the second level.
177+
178+
For example, if the ``Tag`` model had an ``EmbeddedModelArrayField`` called
179+
``colors``:
180+
181+
>>> Post.objects.filter(tags__colors__name="blue")
182+
...
183+
ValueError: Cannot perform multiple levels of array traversal in a query.

tests/model_fields_/models.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,51 @@ class Review(EmbeddedModel):
165165

166166
def __str__(self):
167167
return self.title
168+
169+
170+
# An exhibit in the museum, composed of multiple sections.
171+
class Exhibit(models.Model):
172+
exhibit_name = models.CharField(max_length=255)
173+
sections = EmbeddedModelArrayField("Section", null=True)
174+
main_section = EmbeddedModelField("Section", null=True)
175+
176+
def __str__(self):
177+
return self.exhibit_name
178+
179+
180+
# A section within an exhibit, containing multiple artifacts.
181+
class Section(EmbeddedModel):
182+
section_number = models.IntegerField()
183+
artifacts = EmbeddedModelArrayField("Artifact", null=True)
184+
185+
def __str__(self):
186+
return "Section %d" % self.section_number
187+
188+
189+
# Details about a specific artifact.
190+
class Artifact(EmbeddedModel):
191+
name = models.CharField(max_length=255)
192+
metadata = models.JSONField()
193+
restorations = EmbeddedModelArrayField("Restoration", null=True)
194+
last_restoration = EmbeddedModelField("Restoration", null=True)
195+
196+
def __str__(self):
197+
return self.name
198+
199+
200+
# Details about when an artifact was restored.
201+
class Restoration(EmbeddedModel):
202+
date = models.DateField()
203+
restored_by = models.CharField(max_length=255)
204+
205+
def __str__(self):
206+
return f"Restored by {self.restored_by} on {self.date}"
207+
208+
209+
# ForeignKey to a model with EmbeddedModelArrayField.
210+
class Tour(models.Model):
211+
guide = models.CharField(max_length=100)
212+
exhibit = models.ForeignKey(Exhibit, models.CASCADE)
213+
214+
def __str__(self):
215+
return f"Tour by {self.guide}"

0 commit comments

Comments
 (0)