Skip to content

Commit 783aa8a

Browse files
WaVEVtimgraham
authored andcommitted
Add subquery support for EmbeddedModelArrayField
1 parent 00e8810 commit 783aa8a

File tree

4 files changed

+137
-3
lines changed

4 files changed

+137
-3
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,50 @@ def as_mql(self, compiler, connection):
143143

144144
@_EmbeddedModelArrayOutputField.register_lookup
145145
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
146-
pass
146+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
147+
# This pipeline is adapted from that of ArrayField, because the
148+
# structure of EmbeddedModelArrayField on the RHS behaves similar to
149+
# ArrayField.
150+
return [
151+
{
152+
"$facet": {
153+
"gathered_data": [
154+
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
155+
# To concatenate all the values from the RHS subquery,
156+
# use an $unwind followed by a $group.
157+
{
158+
"$unwind": "$tmp_name",
159+
},
160+
# The $group stage collects values into an array using
161+
# $addToSet. The use of {_id: null} results in a
162+
# single grouped array. However, because arrays from
163+
# multiple documents are aggregated, the result is a
164+
# list of lists.
165+
{
166+
"$group": {
167+
"_id": None,
168+
"tmp_name": {"$addToSet": "$tmp_name"},
169+
}
170+
},
171+
]
172+
}
173+
},
174+
{
175+
"$project": {
176+
field_name: {
177+
"$ifNull": [
178+
{
179+
"$getField": {
180+
"input": {"$arrayElemAt": ["$gathered_data", 0]},
181+
"field": "tmp_name",
182+
}
183+
},
184+
[],
185+
]
186+
}
187+
}
188+
},
189+
]
147190

148191

149192
@_EmbeddedModelArrayOutputField.register_lookup

docs/source/releases/5.2.x.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
Django MongoDB Backend 5.2.x
33
============================
44

5+
5.2.0 beta 2
6+
============
7+
8+
*Unreleased*
9+
10+
New features
11+
------------
12+
13+
- Added subquery support for :class:`~.fields.EmbeddedModelArrayField`.
14+
515
5.2.0 beta 1
616
============
717

tests/model_fields_/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ def __str__(self):
206206
return f"Restored by {self.restored_by} on {self.date}"
207207

208208

209+
# An audit of a section in the museum.
210+
class Audit(models.Model):
211+
section_number = models.IntegerField()
212+
reviewed = models.BooleanField()
213+
214+
def __str__(self):
215+
return f"Section {self.section_number} audit"
216+
217+
209218
# ForeignKey to a model with EmbeddedModelArrayField.
210219
class Tour(models.Model):
211220
guide = models.CharField(max_length=100)

tests/model_fields_/test_embedded_model_array.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
import unittest
12
from datetime import date
3+
from operator import attrgetter
24

35
from django.core.exceptions import FieldDoesNotExist
46
from django.db import connection, models
7+
from django.db.models.expressions import Value
58
from django.test import SimpleTestCase, TestCase
69
from django.test.utils import CaptureQueriesContext, isolate_apps
710

8-
from django_mongodb_backend.fields import EmbeddedModelArrayField
11+
from django_mongodb_backend.fields import ArrayField, EmbeddedModelArrayField
912
from django_mongodb_backend.models import EmbeddedModel
1013

11-
from .models import Artifact, Exhibit, Movie, Restoration, Review, Section, Tour
14+
from .models import Artifact, Audit, Exhibit, Movie, Restoration, Review, Section, Tour
1215

1316

1417
class MethodTests(SimpleTestCase):
@@ -116,6 +119,7 @@ def setUpTestData(cls):
116119
],
117120
)
118121
],
122+
main_section=Section(number=2),
119123
)
120124
cls.lost_empires = Exhibit.objects.create(
121125
name="Lost Empires",
@@ -146,6 +150,9 @@ def setUpTestData(cls):
146150
cls.egypt_tour = Tour.objects.create(guide="Amira", exhibit=cls.egypt)
147151
cls.wonders_tour = Tour.objects.create(guide="Carlos", exhibit=cls.wonders)
148152
cls.lost_tour = Tour.objects.create(guide="Yelena", exhibit=cls.lost_empires)
153+
cls.audit_1 = Audit.objects.create(section_number=1, reviewed=True)
154+
cls.audit_2 = Audit.objects.create(section_number=2, reviewed=True)
155+
cls.audit_3 = Audit.objects.create(section_number=5, reviewed=False)
149156

150157
def test_exact(self):
151158
self.assertCountEqual(
@@ -284,6 +291,71 @@ def test_foreign_field_with_slice(self):
284291
qs = Tour.objects.filter(exhibit__sections__0_2__number__in=[1, 2])
285292
self.assertCountEqual(qs, [self.wonders_tour, self.egypt_tour])
286293

294+
def test_subquery_numeric_lookups(self):
295+
subquery = Audit.objects.filter(
296+
section_number__in=models.OuterRef("sections__number")
297+
).values("section_number")[:1]
298+
tests = [
299+
("exact", [self.egypt, self.new_descoveries, self.wonders]),
300+
("lt", []),
301+
("lte", [self.egypt, self.new_descoveries, self.wonders]),
302+
("gt", [self.wonders]),
303+
("gte", [self.egypt, self.new_descoveries, self.wonders]),
304+
]
305+
for lookup, expected in tests:
306+
with self.subTest(lookup=lookup):
307+
kwargs = {f"sections__number__{lookup}": subquery}
308+
self.assertCountEqual(Exhibit.objects.filter(**kwargs), expected)
309+
310+
def test_subquery_in_lookup(self):
311+
subquery = Audit.objects.filter(reviewed=True).values_list("section_number", flat=True)
312+
result = Exhibit.objects.filter(sections__number__in=subquery)
313+
self.assertCountEqual(result, [self.wonders, self.new_descoveries, self.egypt])
314+
315+
def test_array_as_rhs(self):
316+
result = Exhibit.objects.filter(main_section__number__in=models.F("sections__number"))
317+
self.assertCountEqual(result, [self.new_descoveries])
318+
319+
def test_array_annotation_lookup(self):
320+
result = Exhibit.objects.annotate(section_numbers=models.F("main_section__number")).filter(
321+
section_numbers__in=models.F("sections__number")
322+
)
323+
self.assertCountEqual(result, [self.new_descoveries])
324+
325+
def test_array_as_rhs_for_arrayfield_lookups(self):
326+
tests = [
327+
("exact", [self.wonders]),
328+
("lt", [self.new_descoveries]),
329+
("lte", [self.wonders, self.new_descoveries]),
330+
("gt", [self.egypt, self.lost_empires]),
331+
("gte", [self.egypt, self.wonders, self.lost_empires]),
332+
("overlap", [self.egypt, self.wonders, self.new_descoveries]),
333+
("contained_by", [self.wonders]),
334+
("contains", [self.egypt, self.wonders, self.new_descoveries, self.lost_empires]),
335+
]
336+
for lookup, expected in tests:
337+
with self.subTest(lookup=lookup):
338+
kwargs = {f"section_numbers__{lookup}": models.F("sections__number")}
339+
result = Exhibit.objects.annotate(
340+
section_numbers=Value(
341+
[1, 2], output_field=ArrayField(base_field=models.IntegerField())
342+
)
343+
).filter(**kwargs)
344+
self.assertCountEqual(result, expected)
345+
346+
@unittest.expectedFailure
347+
def test_array_annotation_index(self):
348+
# Slicing and indexing over an annotated EmbeddedModelArrayField would
349+
# require a refactor of annotation handling.
350+
result = Exhibit.objects.annotate(section_numbers=models.F("sections__number")).filter(
351+
section_numbers__0=1
352+
)
353+
self.assertCountEqual(result, [self.new_descoveries, self.egypt])
354+
355+
def test_array_annotation(self):
356+
qs = Exhibit.objects.annotate(section_numbers=models.F("sections__number")).order_by("name")
357+
self.assertQuerySetEqual(qs, [[1], [], [2], [1, 2]], attrgetter("section_numbers"))
358+
287359

288360
@isolate_apps("model_fields_")
289361
class CheckTests(SimpleTestCase):

0 commit comments

Comments
 (0)