Skip to content

Commit 08a6e5b

Browse files
committed
Handle output type as a separate field
1 parent 2de42ec commit 08a6e5b

File tree

2 files changed

+87
-80
lines changed

2 files changed

+87
-80
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 75 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,10 @@
88
from .. import forms
99
from ..query_utils import process_lhs, process_rhs
1010
from . import EmbeddedModelField
11-
from .array import ArrayField
11+
from .array import ArrayField, ArrayLenTransform
1212

1313

1414
class EmbeddedModelArrayField(ArrayField):
15-
ALLOWED_LOOKUPS = {
16-
"in",
17-
"exact",
18-
"iexact",
19-
"gt",
20-
"gte",
21-
"lt",
22-
"lte",
23-
"all",
24-
"contained_by",
25-
}
26-
2715
def __init__(self, embedded_model, **kwargs):
2816
if "size" in kwargs:
2917
raise ValueError("EmbeddedModelArrayField does not support size.")
@@ -69,18 +57,50 @@ def get_transform(self, name):
6957
return transform
7058
return KeyTransformFactory(name, self)
7159

60+
def _get_lookup(self, lookup_name):
61+
lookup = super()._get_lookup(lookup_name)
62+
if lookup is None or lookup is ArrayLenTransform:
63+
return lookup
64+
65+
class EmbeddedModelArrayFieldLookups(Lookup):
66+
def as_mql(self, compiler, connection):
67+
raise ValueError(
68+
"Cannot apply this lookup directly to EmbeddedModelArrayField. "
69+
"Try querying one of its embedded fields instead."
70+
)
71+
72+
return EmbeddedModelArrayFieldLookups
73+
74+
75+
class _EmbeddedModelArrayOutputField(ArrayField):
76+
"""
77+
Represents the output of an EmbeddedModelArrayField when traversed in a query path.
78+
79+
This field is not meant to be used directly in model definitions. It exists solely to
80+
support query output resolution; when an EmbeddedModelArrayField is accessed in a query,
81+
the result should behave like an array of the embedded model's target type.
82+
83+
While it mimics ArrayField's lookups behavior, the way those lookups are resolved
84+
follows the semantics of EmbeddedModelArrayField rather than native array behavior.
85+
"""
86+
87+
ALLOWED_LOOKUPS = {
88+
"in",
89+
"exact",
90+
"iexact",
91+
"gt",
92+
"gte",
93+
"lt",
94+
"lte",
95+
"all",
96+
"contained_by",
97+
}
98+
7299
def get_lookup(self, name):
73100
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
74101

75102

76103
class EmbeddedModelArrayFieldBuiltinLookup(Lookup):
77-
def check_lhs(self):
78-
if not isinstance(self.lhs, KeyTransform):
79-
raise ValueError(
80-
"Cannot apply this lookup directly to EmbeddedModelArrayField. "
81-
"Try querying one of its embedded fields instead."
82-
)
83-
84104
def process_rhs(self, compiler, connection):
85105
value = self.rhs
86106
if not self.get_db_prep_lookup_value_is_iterable:
@@ -95,111 +115,81 @@ def process_rhs(self, compiler, connection):
95115
]
96116

97117
def as_mql(self, compiler, connection):
98-
self.check_lhs()
99118
# Querying a subfield within the array elements (via nested KeyTransform).
100119
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
101120
# `$in` on the subfield.
102-
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
121+
lhs_mql = process_lhs(self, compiler, connection)
122+
inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"]
103123
values = process_rhs(self, compiler, connection)
104-
return {
105-
"$anyElementTrue": {
106-
"$ifNull": [
107-
{
108-
"$map": {
109-
"input": lhs_mql,
110-
"as": "item",
111-
"in": connection.mongo_operators[self.lookup_name](
112-
inner_lhs_mql, values
113-
),
114-
}
115-
},
116-
[],
117-
]
118-
}
119-
}
124+
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name](
125+
inner_lhs_mql, values
126+
)
127+
return {"$anyElementTrue": lhs_mql}
120128

121129

122-
@EmbeddedModelArrayField.register_lookup
130+
@_EmbeddedModelArrayOutputField.register_lookup
123131
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
124132
pass
125133

126134

127-
@EmbeddedModelArrayField.register_lookup
135+
@_EmbeddedModelArrayOutputField.register_lookup
128136
class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact):
129137
pass
130138

131139

132-
@EmbeddedModelArrayField.register_lookup
140+
@_EmbeddedModelArrayOutputField.register_lookup
133141
class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact):
134142
get_db_prep_lookup_value_is_iterable = False
135143

136144

137-
@EmbeddedModelArrayField.register_lookup
145+
@_EmbeddedModelArrayOutputField.register_lookup
138146
class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan):
139147
pass
140148

141149

142-
@EmbeddedModelArrayField.register_lookup
150+
@_EmbeddedModelArrayOutputField.register_lookup
143151
class EmbeddedModelArrayFieldGreaterThanOrEqual(
144152
EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual
145153
):
146154
pass
147155

148156

149-
@EmbeddedModelArrayField.register_lookup
157+
@_EmbeddedModelArrayOutputField.register_lookup
150158
class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan):
151159
pass
152160

153161

154-
@EmbeddedModelArrayField.register_lookup
162+
@_EmbeddedModelArrayOutputField.register_lookup
155163
class EmbeddedModelArrayFieldLessThanOrEqual(
156164
EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual
157165
):
158166
pass
159167

160168

161-
@EmbeddedModelArrayField.register_lookup
169+
@_EmbeddedModelArrayOutputField.register_lookup
162170
class EmbeddedModelArrayFieldAll(EmbeddedModelArrayFieldBuiltinLookup, Lookup):
163171
lookup_name = "all"
164172
get_db_prep_lookup_value_is_iterable = False
165173

166174
def as_mql(self, compiler, connection):
167-
self.check_lhs()
168-
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
175+
lhs_mql = process_lhs(self, compiler, connection)
169176
values = process_rhs(self, compiler, connection)
170177
return {
171-
"$setIsSubset": [
172-
values,
173-
{
174-
"$ifNull": [
175-
{
176-
"$map": {
177-
"input": lhs_mql,
178-
"as": "item",
179-
"in": inner_lhs_mql,
180-
}
181-
},
182-
[],
183-
]
184-
},
178+
"$and": [
179+
{"$ne": [lhs_mql, None]},
180+
{"$ne": [values, None]},
181+
{"$setIsSubset": [values, lhs_mql]},
185182
]
186183
}
187184

188185

189-
@EmbeddedModelArrayField.register_lookup
186+
@_EmbeddedModelArrayOutputField.register_lookup
190187
class ArrayContainedBy(EmbeddedModelArrayFieldBuiltinLookup, Lookup):
191188
lookup_name = "contained_by"
192189
get_db_prep_lookup_value_is_iterable = False
193190

194191
def as_mql(self, compiler, connection):
195-
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
196-
lhs_mql = {
197-
"$map": {
198-
"input": lhs_mql,
199-
"as": "item",
200-
"in": inner_lhs_mql,
201-
}
202-
}
192+
lhs_mql = process_lhs(self, compiler, connection)
203193
value = process_rhs(self, compiler, connection)
204194
return {
205195
"$and": [
@@ -244,7 +234,7 @@ def get_transform(self, name):
244234
self._sub_transform = transform
245235
return self
246236
output_field = self._lhs.output_field
247-
allowed_lookups = self.array_field.ALLOWED_LOOKUPS.intersection(
237+
allowed_lookups = self.output_field.ALLOWED_LOOKUPS.intersection(
248238
set(output_field.get_lookups())
249239
)
250240
suggested_lookups = difflib.get_close_matches(name, allowed_lookups)
@@ -262,11 +252,22 @@ def get_transform(self, name):
262252
def as_mql(self, compiler, connection):
263253
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
264254
lhs_mql = process_lhs(self, compiler, connection)
265-
return lhs_mql, inner_lhs_mql
255+
return {
256+
"$ifNull": [
257+
{
258+
"$map": {
259+
"input": lhs_mql,
260+
"as": "item",
261+
"in": inner_lhs_mql,
262+
}
263+
},
264+
[],
265+
]
266+
}
266267

267268
@property
268269
def output_field(self):
269-
return self.array_field
270+
return _EmbeddedModelArrayOutputField(self._lhs.output_field)
270271

271272

272273
class KeyTransformFactory:

tests/model_fields_/test_embedded_model_array.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_all_filter(self):
186186
def test_contained_by(self):
187187
self.assertCountEqual(
188188
MuseumExhibit.objects.filter(sections__section_number__contained_by=[1, 2, 3]),
189-
[self.egypt, self.new_descoveries, self.wonders],
189+
[self.egypt, self.new_descoveries, self.wonders, self.lost_empires],
190190
)
191191

192192
def test_len_filter(self):
@@ -258,12 +258,15 @@ def test_query_array_not_allowed(self):
258258
"Try querying one of its embedded fields instead."
259259
)
260260
with self.assertRaisesMessage(ValueError, msg):
261-
self.assertCountEqual(MuseumExhibit.objects.filter(sections=10), [])
261+
MuseumExhibit.objects.filter(sections=10).first()
262+
263+
with self.assertRaisesMessage(ValueError, msg):
264+
MuseumExhibit.objects.filter(sections__0_1=10).first()
262265

263266
def test_missing_field(self):
264267
msg = "ExhibitSection has no field named 'section'"
265268
with self.assertRaisesMessage(FieldDoesNotExist, msg):
266-
self.assertCountEqual(MuseumExhibit.objects.filter(sections__section__in=[10]), [])
269+
MuseumExhibit.objects.filter(sections__section__in=[10]).first()
267270

268271
def test_missing_lookup(self):
269272
msg = "Unsupported lookup 'return' for EmbeddedModelArrayField of 'IntegerField'"
@@ -273,9 +276,7 @@ def test_missing_lookup(self):
273276
def test_missing_operation(self):
274277
msg = "Unsupported lookup 'rage' for EmbeddedModelArrayField of 'IntegerField'"
275278
with self.assertRaisesMessage(FieldDoesNotExist, msg):
276-
self.assertCountEqual(
277-
MuseumExhibit.objects.filter(sections__section_number__rage=[10]), []
278-
)
279+
MuseumExhibit.objects.filter(sections__section_number__rage=[10])
279280

280281
def test_missing_lookup_suggestions(self):
281282
msg = (
@@ -290,6 +291,11 @@ def test_double_emfarray_transform(self):
290291
with self.assertRaisesMessage(ValueError, msg):
291292
MuseumExhibit.objects.filter(sections__artifacts__name="")
292293

294+
def test_slice(self):
295+
self.assertSequenceEqual(
296+
MuseumExhibit.objects.filter(sections__0_1__section_number=2), [self.new_descoveries]
297+
)
298+
293299

294300
@isolate_apps("model_fields_")
295301
class CheckTests(SimpleTestCase):

0 commit comments

Comments
 (0)