diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index f10e3f86..77e91f80 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -143,7 +143,50 @@ def as_mql(self, compiler, connection): @_EmbeddedModelArrayOutputField.register_lookup class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In): - pass + def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): + # This pipeline is adapted from that of ArrayField, because the + # structure of EmbeddedModelArrayField on the RHS behaves similar to + # ArrayField. + return [ + { + "$facet": { + "gathered_data": [ + {"$project": {"tmp_name": expr.as_mql(compiler, connection)}}, + # To concatenate all the values from the RHS subquery, + # use an $unwind followed by a $group. + { + "$unwind": "$tmp_name", + }, + # The $group stage collects values into an array using + # $addToSet. The use of {_id: null} results in a + # single grouped array. However, because arrays from + # multiple documents are aggregated, the result is a + # list of lists. + { + "$group": { + "_id": None, + "tmp_name": {"$addToSet": "$tmp_name"}, + } + }, + ] + } + }, + { + "$project": { + field_name: { + "$ifNull": [ + { + "$getField": { + "input": {"$arrayElemAt": ["$gathered_data", 0]}, + "field": "tmp_name", + } + }, + [], + ] + } + } + }, + ] @_EmbeddedModelArrayOutputField.register_lookup diff --git a/docs/source/releases/5.2.x.rst b/docs/source/releases/5.2.x.rst index de4b6efc..cf7b2f1b 100644 --- a/docs/source/releases/5.2.x.rst +++ b/docs/source/releases/5.2.x.rst @@ -2,6 +2,16 @@ Django MongoDB Backend 5.2.x ============================ +5.2.0 beta 2 +============ + +*Unreleased* + +New features +------------ + +- Added subquery support for :class:`~.fields.EmbeddedModelArrayField`. + 5.2.0 beta 1 ============ diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 3d3a1584..43522565 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -169,21 +169,21 @@ def __str__(self): # An exhibit in the museum, composed of multiple sections. class Exhibit(models.Model): - exhibit_name = models.CharField(max_length=255) + name = models.CharField(max_length=255) sections = EmbeddedModelArrayField("Section", null=True) main_section = EmbeddedModelField("Section", null=True) def __str__(self): - return self.exhibit_name + return self.name # A section within an exhibit, containing multiple artifacts. class Section(EmbeddedModel): - section_number = models.IntegerField() + number = models.IntegerField() artifacts = EmbeddedModelArrayField("Artifact", null=True) def __str__(self): - return "Section %d" % self.section_number + return "Section %d" % self.number # Details about a specific artifact. @@ -206,6 +206,15 @@ def __str__(self): return f"Restored by {self.restored_by} on {self.date}" +# An audit of a section in the museum. +class Audit(models.Model): + section_number = models.IntegerField() + reviewed = models.BooleanField() + + def __str__(self): + return f"Section {self.section_number} audit" + + # ForeignKey to a model with EmbeddedModelArrayField. class Tour(models.Model): guide = models.CharField(max_length=100) diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index caab244b..81dba503 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -1,14 +1,17 @@ +import unittest from datetime import date +from operator import attrgetter from django.core.exceptions import FieldDoesNotExist from django.db import connection, models +from django.db.models.expressions import Value from django.test import SimpleTestCase, TestCase from django.test.utils import CaptureQueriesContext, isolate_apps -from django_mongodb_backend.fields import EmbeddedModelArrayField +from django_mongodb_backend.fields import ArrayField, EmbeddedModelArrayField from django_mongodb_backend.models import EmbeddedModel -from .models import Artifact, Exhibit, Movie, Restoration, Review, Section, Tour +from .models import Artifact, Audit, Exhibit, Movie, Restoration, Review, Section, Tour class MethodTests(SimpleTestCase): @@ -62,10 +65,10 @@ class QueryingTests(TestCase): @classmethod def setUpTestData(cls): cls.egypt = Exhibit.objects.create( - exhibit_name="Ancient Egypt", + name="Ancient Egypt", sections=[ Section( - section_number=1, + number=1, artifacts=[ Artifact( name="Ptolemaic Crown", @@ -78,10 +81,10 @@ def setUpTestData(cls): ], ) cls.wonders = Exhibit.objects.create( - exhibit_name="Wonders of the Ancient World", + name="Wonders of the Ancient World", sections=[ Section( - section_number=1, + number=1, artifacts=[ Artifact( name="Statue of Zeus", @@ -93,7 +96,7 @@ def setUpTestData(cls): ], ), Section( - section_number=2, + number=2, artifacts=[ Artifact( name="Lighthouse of Alexandria", @@ -104,10 +107,10 @@ def setUpTestData(cls): ], ) cls.new_descoveries = Exhibit.objects.create( - exhibit_name="New Discoveries", + name="New Discoveries", sections=[ Section( - section_number=2, + number=2, artifacts=[ Artifact( name="Lighthouse of Alexandria", @@ -116,11 +119,12 @@ def setUpTestData(cls): ], ) ], + main_section=Section(number=2), ) cls.lost_empires = Exhibit.objects.create( - exhibit_name="Lost Empires", + name="Lost Empires", main_section=Section( - section_number=3, + number=3, artifacts=[ Artifact( name="Bronze Statue", @@ -146,15 +150,18 @@ def setUpTestData(cls): cls.egypt_tour = Tour.objects.create(guide="Amira", exhibit=cls.egypt) cls.wonders_tour = Tour.objects.create(guide="Carlos", exhibit=cls.wonders) cls.lost_tour = Tour.objects.create(guide="Yelena", exhibit=cls.lost_empires) + cls.audit_1 = Audit.objects.create(section_number=1, reviewed=True) + cls.audit_2 = Audit.objects.create(section_number=2, reviewed=True) + cls.audit_3 = Audit.objects.create(section_number=5, reviewed=False) def test_exact(self): self.assertCountEqual( - Exhibit.objects.filter(sections__section_number=1), [self.egypt, self.wonders] + Exhibit.objects.filter(sections__number=1), [self.egypt, self.wonders] ) def test_array_index(self): self.assertCountEqual( - Exhibit.objects.filter(sections__0__section_number=1), + Exhibit.objects.filter(sections__0__number=1), [self.egypt, self.wonders], ) @@ -168,7 +175,7 @@ def test_nested_array_index(self): def test_array_slice(self): self.assertSequenceEqual( - Exhibit.objects.filter(sections__0_1__section_number=2), [self.new_descoveries] + Exhibit.objects.filter(sections__0_1__number=2), [self.new_descoveries] ) def test_filter_unsupported_lookups_in_json(self): @@ -196,16 +203,16 @@ def test_len(self): self.assertCountEqual(Exhibit.objects.filter(sections__1__artifacts__len=1), [self.wonders]) def test_in(self): - self.assertCountEqual(Exhibit.objects.filter(sections__section_number__in=[10]), []) + self.assertCountEqual(Exhibit.objects.filter(sections__number__in=[10]), []) self.assertCountEqual( - Exhibit.objects.filter(sections__section_number__in=[1]), + Exhibit.objects.filter(sections__number__in=[1]), [self.egypt, self.wonders], ) self.assertCountEqual( - Exhibit.objects.filter(sections__section_number__in=[2]), + Exhibit.objects.filter(sections__number__in=[2]), [self.new_descoveries, self.wonders], ) - self.assertCountEqual(Exhibit.objects.filter(sections__section_number__in=[3]), []) + self.assertCountEqual(Exhibit.objects.filter(sections__number__in=[3]), []) def test_iexact(self): self.assertCountEqual( @@ -215,24 +222,24 @@ def test_iexact(self): def test_gt(self): self.assertCountEqual( - Exhibit.objects.filter(sections__section_number__gt=1), + Exhibit.objects.filter(sections__number__gt=1), [self.new_descoveries, self.wonders], ) def test_gte(self): self.assertCountEqual( - Exhibit.objects.filter(sections__section_number__gte=1), + Exhibit.objects.filter(sections__number__gte=1), [self.egypt, self.new_descoveries, self.wonders], ) def test_lt(self): self.assertCountEqual( - Exhibit.objects.filter(sections__section_number__lt=2), [self.egypt, self.wonders] + Exhibit.objects.filter(sections__number__lt=2), [self.egypt, self.wonders] ) def test_lte(self): self.assertCountEqual( - Exhibit.objects.filter(sections__section_number__lte=2), + Exhibit.objects.filter(sections__number__lte=2), [self.egypt, self.wonders, self.new_descoveries], ) @@ -255,12 +262,12 @@ def test_invalid_field(self): def test_invalid_lookup(self): msg = "Unsupported lookup 'return' for EmbeddedModelArrayField of 'IntegerField'" with self.assertRaisesMessage(FieldDoesNotExist, msg): - Exhibit.objects.filter(sections__section_number__return=3) + Exhibit.objects.filter(sections__number__return=3) def test_invalid_operation(self): msg = "Unsupported lookup 'rage' for EmbeddedModelArrayField of 'IntegerField'" with self.assertRaisesMessage(FieldDoesNotExist, msg): - Exhibit.objects.filter(sections__section_number__rage=[10]) + Exhibit.objects.filter(sections__number__rage=[10]) def test_missing_lookup_suggestions(self): msg = ( @@ -268,7 +275,7 @@ def test_missing_lookup_suggestions(self): "perhaps you meant lte or lt?" ) with self.assertRaisesMessage(FieldDoesNotExist, msg): - Exhibit.objects.filter(sections__section_number__ltee=3) + Exhibit.objects.filter(sections__number__ltee=3) def test_nested_lookup(self): msg = "Cannot perform multiple levels of array traversal in a query." @@ -277,13 +284,78 @@ def test_nested_lookup(self): def test_foreign_field_exact(self): """Querying from a foreign key to an EmbeddedModelArrayField.""" - qs = Tour.objects.filter(exhibit__sections__section_number=1) + qs = Tour.objects.filter(exhibit__sections__number=1) self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) def test_foreign_field_with_slice(self): - qs = Tour.objects.filter(exhibit__sections__0_2__section_number__in=[1, 2]) + qs = Tour.objects.filter(exhibit__sections__0_2__number__in=[1, 2]) self.assertCountEqual(qs, [self.wonders_tour, self.egypt_tour]) + def test_subquery_numeric_lookups(self): + subquery = Audit.objects.filter( + section_number__in=models.OuterRef("sections__number") + ).values("section_number")[:1] + tests = [ + ("exact", [self.egypt, self.new_descoveries, self.wonders]), + ("lt", []), + ("lte", [self.egypt, self.new_descoveries, self.wonders]), + ("gt", [self.wonders]), + ("gte", [self.egypt, self.new_descoveries, self.wonders]), + ] + for lookup, expected in tests: + with self.subTest(lookup=lookup): + kwargs = {f"sections__number__{lookup}": subquery} + self.assertCountEqual(Exhibit.objects.filter(**kwargs), expected) + + def test_subquery_in_lookup(self): + subquery = Audit.objects.filter(reviewed=True).values_list("section_number", flat=True) + result = Exhibit.objects.filter(sections__number__in=subquery) + self.assertCountEqual(result, [self.wonders, self.new_descoveries, self.egypt]) + + def test_array_as_rhs(self): + result = Exhibit.objects.filter(main_section__number__in=models.F("sections__number")) + self.assertCountEqual(result, [self.new_descoveries]) + + def test_array_annotation_lookup(self): + result = Exhibit.objects.annotate(section_numbers=models.F("main_section__number")).filter( + section_numbers__in=models.F("sections__number") + ) + self.assertCountEqual(result, [self.new_descoveries]) + + def test_array_as_rhs_for_arrayfield_lookups(self): + tests = [ + ("exact", [self.wonders]), + ("lt", [self.new_descoveries]), + ("lte", [self.wonders, self.new_descoveries]), + ("gt", [self.egypt, self.lost_empires]), + ("gte", [self.egypt, self.wonders, self.lost_empires]), + ("overlap", [self.egypt, self.wonders, self.new_descoveries]), + ("contained_by", [self.wonders]), + ("contains", [self.egypt, self.wonders, self.new_descoveries, self.lost_empires]), + ] + for lookup, expected in tests: + with self.subTest(lookup=lookup): + kwargs = {f"section_numbers__{lookup}": models.F("sections__number")} + result = Exhibit.objects.annotate( + section_numbers=Value( + [1, 2], output_field=ArrayField(base_field=models.IntegerField()) + ) + ).filter(**kwargs) + self.assertCountEqual(result, expected) + + @unittest.expectedFailure + def test_array_annotation_index(self): + # Slicing and indexing over an annotated EmbeddedModelArrayField would + # require a refactor of annotation handling. + result = Exhibit.objects.annotate(section_numbers=models.F("sections__number")).filter( + section_numbers__0=1 + ) + self.assertCountEqual(result, [self.new_descoveries, self.egypt]) + + def test_array_annotation(self): + qs = Exhibit.objects.annotate(section_numbers=models.F("sections__number")).order_by("name") + self.assertQuerySetEqual(qs, [[1], [], [2], [1, 2]], attrgetter("section_numbers")) + @isolate_apps("model_fields_") class CheckTests(SimpleTestCase):