Skip to content

Commit fb35cad

Browse files
committed
Fix emf flow and add subquery unit test
1 parent c99a830 commit fb35cad

File tree

2 files changed

+33
-40
lines changed

2 files changed

+33
-40
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from django.core.exceptions import FieldDoesNotExist
44
from django.db import models
55
from django.db.models.expressions import Col
6-
from django.db.models.lookups import Transform
6+
from django.db.models.lookups import Lookup, Transform
77

88
from ..forms import EmbeddedModelArrayFormField
99
from ..query_utils import process_lhs, process_rhs
1010
from . import EmbeddedModelField
1111
from .array import ArrayField
12-
from .embedded_model import EMFExact
12+
from .embedded_model import EMFExact, EMFMixin
1313

1414

1515
class EmbeddedModelArrayField(ArrayField):
@@ -52,17 +52,8 @@ def get_transform(self, name):
5252
return KeyTransformFactory(name, self)
5353

5454

55-
class ProcessRHSMixin:
56-
def process_rhs(self, compiler, connection):
57-
if isinstance(self.lhs, KeyTransform):
58-
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
59-
else:
60-
get_db_prep_value = self.lhs.output_field.get_db_prep_value
61-
return None, [get_db_prep_value(v, connection, prepared=True) for v in self.rhs]
62-
63-
6455
@EmbeddedModelArrayField.register_lookup
65-
class EMFArrayExact(EMFExact, ProcessRHSMixin):
56+
class EMFArrayExact(EMFExact):
6657
def as_mql(self, compiler, connection):
6758
lhs_mql = process_lhs(self, compiler, connection)
6859
value = process_rhs(self, compiler, connection)
@@ -105,12 +96,29 @@ def as_mql(self, compiler, connection):
10596

10697

10798
@EmbeddedModelArrayField.register_lookup
108-
class ArrayOverlap(EMFExact, ProcessRHSMixin):
99+
class ArrayOverlap(EMFMixin, Lookup):
109100
lookup_name = "overlap"
101+
get_db_prep_lookup_value_is_iterable = True
102+
103+
def process_rhs(self, compiler, connection):
104+
values = self.rhs
105+
if self.get_db_prep_lookup_value_is_iterable:
106+
values = [values]
107+
# Compute how to serialize each value based on the query target.
108+
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
109+
# field of the subfield. Otherwise, use the base field of the array itself.
110+
if isinstance(self.lhs, KeyTransform):
111+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
112+
else:
113+
get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value
114+
return None, [get_db_prep_value(v, connection, prepared=True) for v in values]
110115

111116
def as_mql(self, compiler, connection):
112117
lhs_mql = process_lhs(self, compiler, connection)
113118
values = process_rhs(self, compiler, connection)
119+
# Querying a subfield within the array elements (via nested KeyTransform).
120+
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
121+
# `$in` on the subfield.
114122
if isinstance(self.lhs, KeyTransform):
115123
lhs_mql, inner_lhs_mql = lhs_mql
116124
return {
@@ -129,11 +137,12 @@ def as_mql(self, compiler, connection):
129137
}
130138
conditions = []
131139
inner_lhs_mql = "$$item"
140+
# Querying full embedded documents in the array.
141+
# Builds `$or` conditions and maps them over the array to match any full document.
132142
for value in values:
133-
if isinstance(value, models.Model):
134-
value, emf_data = self.model_to_dict(value)
135-
# Get conditions for any nested EmbeddedModelFields.
136-
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
143+
value, emf_data = self.model_to_dict(value)
144+
# Get conditions for any nested EmbeddedModelFields.
145+
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
137146
return {
138147
"$anyElementTrue": {
139148
"$ifNull": [

tests/model_fields_/test_embedded_model.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -300,36 +300,20 @@ def test_overlap_emf(self):
300300
[self.clouds],
301301
)
302302

303-
"""
304-
def test_overlap_charfield_including_expression(self):
305-
obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
306-
obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
307-
CharArrayModel.objects.create(field=["lower text", "text"])
308-
self.assertSequenceEqual(
309-
CharArrayModel.objects.filter(
310-
field__overlap=[
311-
Upper(Value("text")),
312-
"other",
313-
]
314-
),
315-
[obj_1, obj_2],
316-
)
317-
318303
def test_overlap_values(self):
319-
qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
304+
qs = Movie.objects.filter(title__in=["Clouds", "Frozen"])
320305
self.assertCountEqual(
321-
NullableIntegerArrayModel.objects.filter(
322-
field__overlap=qs.values_list("field"),
306+
Movie.objects.filter(
307+
reviews__overlap=qs.values_list("reviews"),
323308
),
324-
self.objs[:3],
309+
[self.clouds, self.frozen],
325310
)
326311
self.assertCountEqual(
327-
NullableIntegerArrayModel.objects.filter(
328-
field__overlap=qs.values("field"),
312+
Movie.objects.filter(
313+
reviews__overlap=qs.values("reviews"),
329314
),
330-
self.objs[:3],
315+
[self.clouds, self.frozen],
331316
)
332-
"""
333317

334318

335319
class QueryingTests(TestCase):

0 commit comments

Comments
 (0)