Skip to content

Commit e978e60

Browse files
committed
Add new lookups and rework overlap as in.
1 parent e139bd5 commit e978e60

File tree

3 files changed

+186
-80
lines changed

3 files changed

+186
-80
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 108 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,19 @@
1212

1313

1414
class EmbeddedModelArrayField(ArrayField):
15-
ALLOWED_LOOKUPS = {"exact", "len", "overlap"}
15+
ALLOWED_LOOKUPS = {
16+
"in",
17+
"exact",
18+
"iexact",
19+
"regex",
20+
"iregex",
21+
"gt",
22+
"gte",
23+
"lt",
24+
"lte",
25+
"all",
26+
"contained_by",
27+
}
1628

1729
def __init__(self, embedded_model, **kwargs):
1830
if "size" in kwargs:
@@ -63,7 +75,7 @@ def get_lookup(self, name):
6375
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
6476

6577

66-
class EMFArrayRHSMixin:
78+
class EMFArrayBuildinLookup(Lookup):
6779
def check_lhs(self):
6880
if not isinstance(self.lhs, KeyTransform):
6981
raise ValueError(
@@ -72,28 +84,35 @@ def check_lhs(self):
7284
)
7385

7486
def process_rhs(self, compiler, connection):
75-
values = self.rhs
87+
value = self.rhs
88+
if not self.get_db_prep_lookup_value_is_iterable:
89+
value = [value]
7690
# Value must be serealized based on the query target.
77-
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
91+
# If querying a subfield inside tche array (i.e., a nested KeyTransform), use the output
7892
# field of the subfield. Otherwise, use the base field of the array itself.
7993
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
80-
return None, [get_db_prep_value(values, connection, prepared=True)]
94+
return None, [
95+
v if hasattr(v, "as_mql") else get_db_prep_value(v, connection, prepared=True)
96+
for v in value
97+
]
8198

82-
83-
@EmbeddedModelArrayField.register_lookup
84-
class EMFArrayExact(EMFArrayRHSMixin, lookups.Exact):
8599
def as_mql(self, compiler, connection):
86100
self.check_lhs()
101+
# Querying a subfield within the array elements (via nested KeyTransform).
102+
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
103+
# `$in` on the subfield.
87104
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
88-
value = process_rhs(self, compiler, connection)
105+
values = process_rhs(self, compiler, connection)
89106
return {
90107
"$anyElementTrue": {
91108
"$ifNull": [
92109
{
93110
"$map": {
94111
"input": lhs_mql,
95112
"as": "item",
96-
"in": {"$eq": [inner_lhs_mql, value]},
113+
"in": connection.mongo_operators[self.lookup_name](
114+
inner_lhs_mql, values
115+
),
97116
}
98117
},
99118
[],
@@ -103,31 +122,90 @@ def as_mql(self, compiler, connection):
103122

104123

105124
@EmbeddedModelArrayField.register_lookup
106-
class ArrayOverlap(EMFArrayRHSMixin, Lookup):
107-
lookup_name = "overlap"
125+
class EMFArrayIn(EMFArrayBuildinLookup, lookups.In):
126+
pass
127+
128+
129+
@EmbeddedModelArrayField.register_lookup
130+
class EMFArrayExact(EMFArrayBuildinLookup, lookups.Exact):
131+
pass
132+
133+
134+
@EmbeddedModelArrayField.register_lookup
135+
class EMFArrayIExact(EMFArrayBuildinLookup, lookups.IExact):
136+
get_db_prep_lookup_value_is_iterable = False
137+
138+
139+
@EmbeddedModelArrayField.register_lookup
140+
class EMFArrayGreaterThan(EMFArrayBuildinLookup, lookups.GreaterThan):
141+
pass
142+
143+
144+
@EmbeddedModelArrayField.register_lookup
145+
class EMFArrayGreaterThanOrEqual(EMFArrayBuildinLookup, lookups.GreaterThanOrEqual):
146+
pass
147+
148+
149+
@EmbeddedModelArrayField.register_lookup
150+
class EMFArrayLessThan(EMFArrayBuildinLookup, lookups.LessThan):
151+
pass
152+
153+
154+
@EmbeddedModelArrayField.register_lookup
155+
class EMFArrayLessThanOrEqual(EMFArrayBuildinLookup, lookups.LessThanOrEqual):
156+
pass
157+
158+
159+
@EmbeddedModelArrayField.register_lookup
160+
class EMFArrayAll(EMFArrayBuildinLookup, Lookup):
161+
lookup_name = "all"
162+
get_db_prep_lookup_value_is_iterable = False
108163

109164
def as_mql(self, compiler, connection):
110165
self.check_lhs()
111-
# Querying a subfield within the array elements (via nested KeyTransform).
112-
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
113-
# `$in` on the subfield.
114-
lhs_mql = process_lhs(self, compiler, connection)
166+
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
115167
values = process_rhs(self, compiler, connection)
116-
lhs_mql, inner_lhs_mql = lhs_mql
117168
return {
118-
"$anyElementTrue": {
119-
"$ifNull": [
120-
{
121-
"$map": {
122-
"input": lhs_mql,
123-
"as": "item",
124-
"in": {"$in": [inner_lhs_mql, values]},
125-
}
126-
},
127-
[],
128-
]
169+
"$setIsSubset": [
170+
values,
171+
{
172+
"$ifNull": [
173+
{
174+
"$map": {
175+
"input": lhs_mql,
176+
"as": "item",
177+
"in": inner_lhs_mql,
178+
}
179+
},
180+
[],
181+
]
182+
},
183+
]
184+
}
185+
186+
187+
@EmbeddedModelArrayField.register_lookup
188+
class ArrayContainedBy(EMFArrayBuildinLookup, Lookup):
189+
lookup_name = "contained_by"
190+
get_db_prep_lookup_value_is_iterable = False
191+
192+
def as_mql(self, compiler, connection):
193+
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
194+
lhs_mql = {
195+
"$map": {
196+
"input": lhs_mql,
197+
"as": "item",
198+
"in": inner_lhs_mql,
129199
}
130200
}
201+
value = process_rhs(self, compiler, connection)
202+
return {
203+
"$and": [
204+
{"$ne": [lhs_mql, None]},
205+
{"$ne": [value, None]},
206+
{"$setIsSubset": [lhs_mql, value]},
207+
]
208+
}
131209

132210

133211
class KeyTransform(Transform):
@@ -159,6 +237,8 @@ def get_transform(self, name):
159237
# Once the sub lhs is a transform, all the filter are applied over it.
160238
# Otherwise get transform from EMF.
161239
if transform := self._lhs.get_transform(name):
240+
if isinstance(transform, KeyTransformFactory):
241+
raise ValueError("Cannot perform multiple levels of array traversal in a query.")
162242
self._sub_transform = transform
163243
return self
164244
output_field = self._lhs.output_field

tests/model_fields_/models.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,25 +172,22 @@ class RestorationRecord(EmbeddedModel):
172172
restored_by = models.CharField(max_length=255)
173173

174174

175+
# Details about a specific artifact.
175176
class ArtifactDetail(EmbeddedModel):
176-
"""Details about a specific artifact."""
177-
178177
name = models.CharField(max_length=255)
179178
metadata = models.JSONField()
180179
restorations = EmbeddedModelArrayField(RestorationRecord, null=True)
181180
last_restoration = EmbeddedModelField(RestorationRecord, null=True)
182181

183182

183+
# A section within an exhibit, containing multiple artifacts.
184184
class ExhibitSection(EmbeddedModel):
185-
"""A section within an exhibit, containing multiple artifacts."""
186-
187185
section_number = models.IntegerField()
188186
artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True)
189187

190188

189+
# An exhibit in the museum, composed of multiple sections.
191190
class MuseumExhibit(models.Model):
192-
"""An exhibit in the museum, composed of multiple sections."""
193-
194191
exhibit_name = models.CharField(max_length=255)
195192
sections = EmbeddedModelArrayField(ExhibitSection, null=True)
196193
main_section = EmbeddedModelField(ExhibitSection, null=True)

0 commit comments

Comments
 (0)