Skip to content

Commit c99a830

Browse files
committed
Adding support for overlap
1 parent 22ee9ec commit c99a830

File tree

3 files changed

+115
-11
lines changed

3 files changed

+115
-11
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
326326
def as_mql(self, compiler, connection):
327327
lhs_mql = process_lhs(self, compiler, connection)
328328
value = process_rhs(self, compiler, connection)
329-
return {
330-
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
331-
}
329+
return {"$and": [{"$isArray": lhs_mql}, {"$size": {"$setIntersection": [value, lhs_mql]}}]}
332330

333331

334332
@ArrayField.register_lookup

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,20 @@ def get_transform(self, name):
4949
transform = super().get_transform(name)
5050
if transform:
5151
return transform
52-
return KeyTransformFactory(name, self.base_field)
52+
return KeyTransformFactory(name, self)
53+
54+
55+
class ProcessRHSMixin:
56+
def process_rhs(self, compiler, connection):
57+
if isinstance(self.lhs, KeyTransform):
58+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
59+
else:
60+
get_db_prep_value = self.lhs.output_field.get_db_prep_value
61+
return None, [get_db_prep_value(v, connection, prepared=True) for v in self.rhs]
5362

5463

5564
@EmbeddedModelArrayField.register_lookup
56-
class EMFArrayExact(EMFExact):
65+
class EMFArrayExact(EMFExact, ProcessRHSMixin):
5766
def as_mql(self, compiler, connection):
5867
lhs_mql = process_lhs(self, compiler, connection)
5968
value = process_rhs(self, compiler, connection)
@@ -95,15 +104,61 @@ def as_mql(self, compiler, connection):
95104
}
96105

97106

107+
@EmbeddedModelArrayField.register_lookup
108+
class ArrayOverlap(EMFExact, ProcessRHSMixin):
109+
lookup_name = "overlap"
110+
111+
def as_mql(self, compiler, connection):
112+
lhs_mql = process_lhs(self, compiler, connection)
113+
values = process_rhs(self, compiler, connection)
114+
if isinstance(self.lhs, KeyTransform):
115+
lhs_mql, inner_lhs_mql = lhs_mql
116+
return {
117+
"$anyElementTrue": {
118+
"$ifNull": [
119+
{
120+
"$map": {
121+
"input": lhs_mql,
122+
"as": "item",
123+
"in": {"$in": [inner_lhs_mql, values]},
124+
}
125+
},
126+
[],
127+
]
128+
}
129+
}
130+
conditions = []
131+
inner_lhs_mql = "$$item"
132+
for value in values:
133+
if isinstance(value, models.Model):
134+
value, emf_data = self.model_to_dict(value)
135+
# Get conditions for any nested EmbeddedModelFields.
136+
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
137+
return {
138+
"$anyElementTrue": {
139+
"$ifNull": [
140+
{
141+
"$map": {
142+
"input": lhs_mql,
143+
"as": "item",
144+
"in": {"$or": conditions},
145+
}
146+
},
147+
[],
148+
]
149+
}
150+
}
151+
152+
98153
class KeyTransform(Transform):
99154
# it should be different class than EMF keytransform even most of the methods are equal.
100-
def __init__(self, key_name, base_field, *args, **kwargs):
155+
def __init__(self, key_name, array_field, *args, **kwargs):
101156
super().__init__(*args, **kwargs)
102-
self.base_field = base_field
157+
self.array_field = array_field
103158
self.key_name = key_name
104159
# The iteration items begins from the base_field, a virtual column with
105160
# base field output type is created.
106-
column_target = base_field.clone()
161+
column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
107162
column_name = f"$item.{key_name}"
108163
column_target.db_column = column_name
109164
column_target.set_attributes_from_name(column_name)
@@ -126,7 +181,7 @@ def _get_missing_field_or_lookup_exception(self, lhs, name):
126181
suggestion = "."
127182
raise FieldDoesNotExist(
128183
f"Unsupported lookup '{name}' for "
129-
f"{self.base_field.__class__.__name__} '{self.base_field.name}'"
184+
f"{self.array_field.base_field.__class__.__name__} '{self.array_field.base_field.name}'"
130185
f"{suggestion}"
131186
)
132187

@@ -139,7 +194,9 @@ def get_transform(self, name):
139194
transform = (
140195
self._lhs.get_transform(name)
141196
if isinstance(self._lhs, Transform)
142-
else self.base_field.embedded_model._meta.get_field(self.key_name).get_transform(name)
197+
else self.array_field.base_field.embedded_model._meta.get_field(
198+
self.key_name
199+
).get_transform(name)
143200
)
144201
if transform:
145202
self._sub_transform = transform
@@ -155,7 +212,7 @@ def as_mql(self, compiler, connection):
155212

156213
@property
157214
def output_field(self):
158-
return EmbeddedModelArrayField(self.base_field)
215+
return self.array_field
159216

160217

161218
class KeyTransformFactory:

tests/model_fields_/test_embedded_model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,55 @@ def test_len(self):
282282
MuseumExhibit.objects.filter(sections__1__artifacts__len=1), [self.wonders]
283283
)
284284

285+
def test_overlap_simplefield(self):
286+
self.assertSequenceEqual(
287+
MuseumExhibit.objects.filter(sections__section_number__overlap=[10]), []
288+
)
289+
self.assertSequenceEqual(
290+
MuseumExhibit.objects.filter(sections__section_number__overlap=[1]),
291+
[self.egypt, self.wonders, self.new_descoveries],
292+
)
293+
self.assertSequenceEqual(
294+
MuseumExhibit.objects.filter(sections__section_number__overlap=[2]), [self.wonders]
295+
)
296+
297+
def test_overlap_emf(self):
298+
self.assertSequenceEqual(
299+
Movie.objects.filter(reviews__overlap=[Review(title="The best", rating=10)]),
300+
[self.clouds],
301+
)
302+
303+
"""
304+
def test_overlap_charfield_including_expression(self):
305+
obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
306+
obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
307+
CharArrayModel.objects.create(field=["lower text", "text"])
308+
self.assertSequenceEqual(
309+
CharArrayModel.objects.filter(
310+
field__overlap=[
311+
Upper(Value("text")),
312+
"other",
313+
]
314+
),
315+
[obj_1, obj_2],
316+
)
317+
318+
def test_overlap_values(self):
319+
qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
320+
self.assertCountEqual(
321+
NullableIntegerArrayModel.objects.filter(
322+
field__overlap=qs.values_list("field"),
323+
),
324+
self.objs[:3],
325+
)
326+
self.assertCountEqual(
327+
NullableIntegerArrayModel.objects.filter(
328+
field__overlap=qs.values("field"),
329+
),
330+
self.objs[:3],
331+
)
332+
"""
333+
285334

286335
class QueryingTests(TestCase):
287336
@classmethod

0 commit comments

Comments
 (0)