Skip to content

Add subquery support for EmbeddedModelArrayField #314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion django_mongodb_backend/fields/embedded_model_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,50 @@ def as_mql(self, compiler, connection):

@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
pass
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
# This pipeline is adapted from that of ArrayField, because the
# structure of EmbeddedModelArrayField on the RHS behaves similar to
# ArrayField.
return [
{
"$facet": {
"gathered_data": [
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
# To concatenate all the values from the RHS subquery,
# use an $unwind followed by a $group.
{
"$unwind": "$tmp_name",
},
# The $group stage collects values into an array using
# $addToSet. The use of {_id: null} results in a
# single grouped array. However, because arrays from
# multiple documents are aggregated, the result is a
# list of lists.
{
"$group": {
"_id": None,
"tmp_name": {"$addToSet": "$tmp_name"},
}
},
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$gathered_data", 0]},
"field": "tmp_name",
}
},
[],
]
}
}
},
]


@_EmbeddedModelArrayOutputField.register_lookup
Expand Down
10 changes: 10 additions & 0 deletions docs/source/releases/5.2.x.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
Django MongoDB Backend 5.2.x
============================

5.2.0 beta 2
============

*Unreleased*

New features
------------

- Added subquery support for :class:`~.fields.EmbeddedModelArrayField`.

5.2.0 beta 1
============

Expand Down
17 changes: 13 additions & 4 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,21 @@ def __str__(self):

# An exhibit in the museum, composed of multiple sections.
class Exhibit(models.Model):
exhibit_name = models.CharField(max_length=255)
name = models.CharField(max_length=255)
sections = EmbeddedModelArrayField("Section", null=True)
main_section = EmbeddedModelField("Section", null=True)

def __str__(self):
return self.exhibit_name
return self.name


# A section within an exhibit, containing multiple artifacts.
class Section(EmbeddedModel):
section_number = models.IntegerField()
number = models.IntegerField()
artifacts = EmbeddedModelArrayField("Artifact", null=True)

def __str__(self):
return "Section %d" % self.section_number
return "Section %d" % self.number


# Details about a specific artifact.
Expand All @@ -206,6 +206,15 @@ def __str__(self):
return f"Restored by {self.restored_by} on {self.date}"


# An audit of a section in the museum.
class Audit(models.Model):
section_number = models.IntegerField()
reviewed = models.BooleanField()

def __str__(self):
return f"Section {self.section_number} audit"


# ForeignKey to a model with EmbeddedModelArrayField.
class Tour(models.Model):
guide = models.CharField(max_length=100)
Expand Down
126 changes: 99 additions & 27 deletions tests/model_fields_/test_embedded_model_array.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import unittest
from datetime import date
from operator import attrgetter

from django.core.exceptions import FieldDoesNotExist
from django.db import connection, models
from django.db.models.expressions import Value
from django.test import SimpleTestCase, TestCase
from django.test.utils import CaptureQueriesContext, isolate_apps

from django_mongodb_backend.fields import EmbeddedModelArrayField
from django_mongodb_backend.fields import ArrayField, EmbeddedModelArrayField
from django_mongodb_backend.models import EmbeddedModel

from .models import Artifact, Exhibit, Movie, Restoration, Review, Section, Tour
from .models import Artifact, Audit, Exhibit, Movie, Restoration, Review, Section, Tour


class MethodTests(SimpleTestCase):
Expand Down Expand Up @@ -62,10 +65,10 @@ class QueryingTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.egypt = Exhibit.objects.create(
exhibit_name="Ancient Egypt",
name="Ancient Egypt",
sections=[
Section(
section_number=1,
number=1,
artifacts=[
Artifact(
name="Ptolemaic Crown",
Expand All @@ -78,10 +81,10 @@ def setUpTestData(cls):
],
)
cls.wonders = Exhibit.objects.create(
exhibit_name="Wonders of the Ancient World",
name="Wonders of the Ancient World",
sections=[
Section(
section_number=1,
number=1,
artifacts=[
Artifact(
name="Statue of Zeus",
Expand All @@ -93,7 +96,7 @@ def setUpTestData(cls):
],
),
Section(
section_number=2,
number=2,
artifacts=[
Artifact(
name="Lighthouse of Alexandria",
Expand All @@ -104,10 +107,10 @@ def setUpTestData(cls):
],
)
cls.new_descoveries = Exhibit.objects.create(
exhibit_name="New Discoveries",
name="New Discoveries",
sections=[
Section(
section_number=2,
number=2,
artifacts=[
Artifact(
name="Lighthouse of Alexandria",
Expand All @@ -116,11 +119,12 @@ def setUpTestData(cls):
],
)
],
main_section=Section(number=2),
)
cls.lost_empires = Exhibit.objects.create(
exhibit_name="Lost Empires",
name="Lost Empires",
main_section=Section(
section_number=3,
number=3,
artifacts=[
Artifact(
name="Bronze Statue",
Expand All @@ -146,15 +150,18 @@ def setUpTestData(cls):
cls.egypt_tour = Tour.objects.create(guide="Amira", exhibit=cls.egypt)
cls.wonders_tour = Tour.objects.create(guide="Carlos", exhibit=cls.wonders)
cls.lost_tour = Tour.objects.create(guide="Yelena", exhibit=cls.lost_empires)
cls.audit_1 = Audit.objects.create(section_number=1, reviewed=True)
cls.audit_2 = Audit.objects.create(section_number=2, reviewed=True)
cls.audit_3 = Audit.objects.create(section_number=5, reviewed=False)

def test_exact(self):
self.assertCountEqual(
Exhibit.objects.filter(sections__section_number=1), [self.egypt, self.wonders]
Exhibit.objects.filter(sections__number=1), [self.egypt, self.wonders]
)

def test_array_index(self):
self.assertCountEqual(
Exhibit.objects.filter(sections__0__section_number=1),
Exhibit.objects.filter(sections__0__number=1),
[self.egypt, self.wonders],
)

Expand All @@ -168,7 +175,7 @@ def test_nested_array_index(self):

def test_array_slice(self):
self.assertSequenceEqual(
Exhibit.objects.filter(sections__0_1__section_number=2), [self.new_descoveries]
Exhibit.objects.filter(sections__0_1__number=2), [self.new_descoveries]
)

def test_filter_unsupported_lookups_in_json(self):
Expand Down Expand Up @@ -196,16 +203,16 @@ def test_len(self):
self.assertCountEqual(Exhibit.objects.filter(sections__1__artifacts__len=1), [self.wonders])

def test_in(self):
self.assertCountEqual(Exhibit.objects.filter(sections__section_number__in=[10]), [])
self.assertCountEqual(Exhibit.objects.filter(sections__number__in=[10]), [])
self.assertCountEqual(
Exhibit.objects.filter(sections__section_number__in=[1]),
Exhibit.objects.filter(sections__number__in=[1]),
[self.egypt, self.wonders],
)
self.assertCountEqual(
Exhibit.objects.filter(sections__section_number__in=[2]),
Exhibit.objects.filter(sections__number__in=[2]),
[self.new_descoveries, self.wonders],
)
self.assertCountEqual(Exhibit.objects.filter(sections__section_number__in=[3]), [])
self.assertCountEqual(Exhibit.objects.filter(sections__number__in=[3]), [])

def test_iexact(self):
self.assertCountEqual(
Expand All @@ -215,24 +222,24 @@ def test_iexact(self):

def test_gt(self):
self.assertCountEqual(
Exhibit.objects.filter(sections__section_number__gt=1),
Exhibit.objects.filter(sections__number__gt=1),
[self.new_descoveries, self.wonders],
)

def test_gte(self):
self.assertCountEqual(
Exhibit.objects.filter(sections__section_number__gte=1),
Exhibit.objects.filter(sections__number__gte=1),
[self.egypt, self.new_descoveries, self.wonders],
)

def test_lt(self):
self.assertCountEqual(
Exhibit.objects.filter(sections__section_number__lt=2), [self.egypt, self.wonders]
Exhibit.objects.filter(sections__number__lt=2), [self.egypt, self.wonders]
)

def test_lte(self):
self.assertCountEqual(
Exhibit.objects.filter(sections__section_number__lte=2),
Exhibit.objects.filter(sections__number__lte=2),
[self.egypt, self.wonders, self.new_descoveries],
)

Expand All @@ -255,20 +262,20 @@ def test_invalid_field(self):
def test_invalid_lookup(self):
msg = "Unsupported lookup 'return' for EmbeddedModelArrayField of 'IntegerField'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Exhibit.objects.filter(sections__section_number__return=3)
Exhibit.objects.filter(sections__number__return=3)

def test_invalid_operation(self):
msg = "Unsupported lookup 'rage' for EmbeddedModelArrayField of 'IntegerField'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Exhibit.objects.filter(sections__section_number__rage=[10])
Exhibit.objects.filter(sections__number__rage=[10])

def test_missing_lookup_suggestions(self):
msg = (
"Unsupported lookup 'ltee' for EmbeddedModelArrayField of 'IntegerField', "
"perhaps you meant lte or lt?"
)
with self.assertRaisesMessage(FieldDoesNotExist, msg):
Exhibit.objects.filter(sections__section_number__ltee=3)
Exhibit.objects.filter(sections__number__ltee=3)

def test_nested_lookup(self):
msg = "Cannot perform multiple levels of array traversal in a query."
Expand All @@ -277,13 +284,78 @@ def test_nested_lookup(self):

def test_foreign_field_exact(self):
"""Querying from a foreign key to an EmbeddedModelArrayField."""
qs = Tour.objects.filter(exhibit__sections__section_number=1)
qs = Tour.objects.filter(exhibit__sections__number=1)
self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour])

def test_foreign_field_with_slice(self):
qs = Tour.objects.filter(exhibit__sections__0_2__section_number__in=[1, 2])
qs = Tour.objects.filter(exhibit__sections__0_2__number__in=[1, 2])
self.assertCountEqual(qs, [self.wonders_tour, self.egypt_tour])

def test_subquery_numeric_lookups(self):
subquery = Audit.objects.filter(
section_number__in=models.OuterRef("sections__number")
).values("section_number")[:1]
tests = [
("exact", [self.egypt, self.new_descoveries, self.wonders]),
("lt", []),
("lte", [self.egypt, self.new_descoveries, self.wonders]),
("gt", [self.wonders]),
("gte", [self.egypt, self.new_descoveries, self.wonders]),
]
for lookup, expected in tests:
with self.subTest(lookup=lookup):
kwargs = {f"sections__number__{lookup}": subquery}
self.assertCountEqual(Exhibit.objects.filter(**kwargs), expected)

def test_subquery_in_lookup(self):
subquery = Audit.objects.filter(reviewed=True).values_list("section_number", flat=True)
result = Exhibit.objects.filter(sections__number__in=subquery)
self.assertCountEqual(result, [self.wonders, self.new_descoveries, self.egypt])

def test_array_as_rhs(self):
result = Exhibit.objects.filter(main_section__number__in=models.F("sections__number"))
self.assertCountEqual(result, [self.new_descoveries])

def test_array_annotation_lookup(self):
result = Exhibit.objects.annotate(section_numbers=models.F("main_section__number")).filter(
section_numbers__in=models.F("sections__number")
)
self.assertCountEqual(result, [self.new_descoveries])

def test_array_as_rhs_for_arrayfield_lookups(self):
tests = [
("exact", [self.wonders]),
("lt", [self.new_descoveries]),
("lte", [self.wonders, self.new_descoveries]),
("gt", [self.egypt, self.lost_empires]),
("gte", [self.egypt, self.wonders, self.lost_empires]),
("overlap", [self.egypt, self.wonders, self.new_descoveries]),
("contained_by", [self.wonders]),
("contains", [self.egypt, self.wonders, self.new_descoveries, self.lost_empires]),
]
for lookup, expected in tests:
with self.subTest(lookup=lookup):
kwargs = {f"section_numbers__{lookup}": models.F("sections__number")}
result = Exhibit.objects.annotate(
section_numbers=Value(
[1, 2], output_field=ArrayField(base_field=models.IntegerField())
)
).filter(**kwargs)
self.assertCountEqual(result, expected)

@unittest.expectedFailure
def test_array_annotation_index(self):
# Slicing and indexing over an annotated EmbeddedModelArrayField would
# require a refactor of annotation handling.
result = Exhibit.objects.annotate(section_numbers=models.F("sections__number")).filter(
section_numbers__0=1
)
self.assertCountEqual(result, [self.new_descoveries, self.egypt])

def test_array_annotation(self):
qs = Exhibit.objects.annotate(section_numbers=models.F("sections__number")).order_by("name")
self.assertQuerySetEqual(qs, [[1], [], [2], [1, 2]], attrgetter("section_numbers"))


@isolate_apps("model_fields_")
class CheckTests(SimpleTestCase):
Expand Down