diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 4f951514..8a9f7e0a 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -338,7 +338,7 @@ class ArrayLenTransform(Transform): def as_mql(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) - return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}} + return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}} @ArrayField.register_lookup diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 57bbd3f5..590fd5f8 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -186,8 +186,9 @@ def as_mql(self, compiler, connection): key_transforms.insert(0, previous.key_name) previous = previous.lhs mql = previous.as_mql(compiler, connection) - transforms = ".".join(key_transforms) - return f"{mql}.{transforms}" + for key in key_transforms: + mql = {"$getField": {"input": mql, "field": key}} + return mql @property def output_field(self): diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index 8b14a498..f10e3f86 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -1,9 +1,15 @@ -from django.db.models import Field +import difflib + +from django.core.exceptions import FieldDoesNotExist +from django.db.models import Field, lookups +from django.db.models.expressions import Col from django.db.models.fields.related import lazy_related_operation +from django.db.models.lookups import Lookup, Transform from .. import forms +from ..query_utils import process_lhs, process_rhs from . 
import EmbeddedModelField -from .array import ArrayField +from .array import ArrayField, ArrayLenTransform class EmbeddedModelArrayField(ArrayField): @@ -56,3 +62,199 @@ def formfield(self, **kwargs): **kwargs, }, ) + + def get_transform(self, name): + transform = super().get_transform(name) + if transform: + return transform + return KeyTransformFactory(name, self) + + def _get_lookup(self, lookup_name): + lookup = super()._get_lookup(lookup_name) + if lookup is None or lookup is ArrayLenTransform: + return lookup + + class EmbeddedModelArrayFieldLookups(Lookup): + def as_mql(self, compiler, connection): + raise ValueError( + "Lookups aren't supported on EmbeddedModelArrayField. " + "Try querying one of its embedded fields instead." + ) + + return EmbeddedModelArrayFieldLookups + + +class _EmbeddedModelArrayOutputField(ArrayField): + """ + Represent the output of an EmbeddedModelArrayField when traversed in a + query path. + + This field is not meant to be used in model definitions. It exists solely + to support query output resolution. When an EmbeddedModelArrayField is + accessed in a query, the result should behave like an array of the embedded + model's target type. + + While it mimics ArrayField's lookup behavior, the way those lookups are + resolved follows the semantics of EmbeddedModelArrayField rather than + ArrayField. + """ + + ALLOWED_LOOKUPS = { + "in", + "exact", + "iexact", + "gt", + "gte", + "lt", + "lte", + } + + def get_lookup(self, name): + return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None + + +class EmbeddedModelArrayFieldBuiltinLookup(Lookup): + def process_rhs(self, compiler, connection): + value = self.rhs + if not self.get_db_prep_lookup_value_is_iterable: + value = [value] + # Value must be serialized based on the query target. If querying a + # subfield inside the array (i.e., a nested KeyTransform), use the + # output field of the subfield. Otherwise, use the base field of the + # array itself. 
+ get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value + return None, [ + v if hasattr(v, "as_mql") else get_db_prep_value(v, connection, prepared=True) + for v in value + ] + + def as_mql(self, compiler, connection): + # Querying a subfield within the array elements (via nested + # KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over + # the array and applying $in on the subfield. + lhs_mql = process_lhs(self, compiler, connection) + inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"] + values = process_rhs(self, compiler, connection) + lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name]( + inner_lhs_mql, values + ) + return {"$anyElementTrue": lhs_mql} + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact): + get_db_prep_lookup_value_is_iterable = False + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldGreaterThanOrEqual( + EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual +): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldLessThanOrEqual( + EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual +): + pass + + +class KeyTransform(Transform): + def __init__(self, key_name, array_field, *args, **kwargs): + 
super().__init__(*args, **kwargs) + self.array_field = array_field + self.key_name = key_name + # Lookups iterate over the array of embedded models. A virtual column + # of the queried field's type represents each element. + column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone() + column_name = f"$item.{key_name}" + column_target.db_column = column_name + column_target.set_attributes_from_name(column_name) + self._lhs = Col(None, column_target) + self._sub_transform = None + + def __call__(self, this, *args, **kwargs): + self._lhs = self._sub_transform(self._lhs, *args, **kwargs) + return self + + def get_lookup(self, name): + return self.output_field.get_lookup(name) + + def get_transform(self, name): + """ + Validate that `name` is either a field of an embedded model or an + allowed lookup on an embedded model's field. + """ + # Once the sub-lhs is a transform, all the filters are applied over it. + # Otherwise get the transform from the nested embedded model field. + if transform := self._lhs.get_transform(name): + if isinstance(transform, KeyTransformFactory): + raise ValueError("Cannot perform multiple levels of array traversal in a query.") + self._sub_transform = transform + return self + output_field = self._lhs.output_field + # The lookup must be allowed AND a valid lookup for the field. + allowed_lookups = self.output_field.ALLOWED_LOOKUPS.intersection( + set(output_field.get_lookups()) + ) + suggested_lookups = difflib.get_close_matches(name, allowed_lookups) + if suggested_lookups: + suggested_lookups = " or ".join(suggested_lookups) + suggestion = f", perhaps you meant {suggested_lookups}?" 
+ else: + suggestion = "" + raise FieldDoesNotExist( + f"Unsupported lookup '{name}' for " + f"EmbeddedModelArrayField of '{output_field.__class__.__name__}'" + f"{suggestion}" + ) + + def as_mql(self, compiler, connection): + inner_lhs_mql = self._lhs.as_mql(compiler, connection) + lhs_mql = process_lhs(self, compiler, connection) + return { + "$ifNull": [ + { + "$map": { + "input": lhs_mql, + "as": "item", + "in": inner_lhs_mql, + } + }, + [], + ] + } + + @property + def output_field(self): + return _EmbeddedModelArrayOutputField(self._lhs.output_field) + + +class KeyTransformFactory: + def __init__(self, key_name, base_field): + self.key_name = key_name + self.base_field = base_field + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, self.base_field, *args, **kwargs) diff --git a/docs/source/ref/models/fields.rst b/docs/source/ref/models/fields.rst index a4de529a..79cafe3d 100644 --- a/docs/source/ref/models/fields.rst +++ b/docs/source/ref/models/fields.rst @@ -91,7 +91,7 @@ We will use the following example model:: def __str__(self): return self.name -.. fieldlookup:: arrayfield.contains +.. fieldlookup:: mongo-arrayfield.contains ``contains`` ^^^^^^^^^^^^ @@ -134,7 +134,7 @@ passed. It uses the ``$setIntersection`` operator. For example: >>> Post.objects.filter(tags__contained_by=["thoughts", "django", "tutorial"]) , , ]> -.. fieldlookup:: arrayfield.overlap +.. fieldlookup:: mongo-arrayfield.overlap ``overlap`` ~~~~~~~~~~~ @@ -154,7 +154,7 @@ uses the ``$setIntersection`` operator. For example: >>> Post.objects.filter(tags__overlap=["thoughts", "tutorial"]) , , ]> -.. fieldlookup:: arrayfield.len +.. fieldlookup:: mongo-arrayfield.len ``len`` ^^^^^^^ @@ -170,7 +170,7 @@ available for :class:`~django.db.models.IntegerField`. For example: >>> Post.objects.filter(tags__len=1) ]> -.. fieldlookup:: arrayfield.index +.. fieldlookup:: mongo-arrayfield.index Index transforms ^^^^^^^^^^^^^^^^ @@ -196,7 +196,7 @@ array. 
The lookups available after the transform are those from the These indexes use 0-based indexing. -.. fieldlookup:: arrayfield.slice +.. fieldlookup:: mongo-arrayfield.slice Slice transforms ^^^^^^^^^^^^^^^^ diff --git a/docs/source/topics/embedded-models.rst b/docs/source/topics/embedded-models.rst index 2e314567..0daa483b 100644 --- a/docs/source/topics/embedded-models.rst +++ b/docs/source/topics/embedded-models.rst @@ -115,3 +115,69 @@ Represented in BSON, the post's structure looks like this: name: 'Hello world!', tags: [ { name: 'welcome' }, { name: 'test' } ] } + +Querying ``EmbeddedModelArrayField`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can query into an embedded model array using the same double underscore +syntax as relational fields. For example, to find posts that have a tag with +name "test":: + + >>> Post.objects.filter(tags__name="test") + +There is a limited set of lookups you can chain after an embedded field: + +* :lookup:`exact`, :lookup:`iexact` +* :lookup:`in` +* :lookup:`gt`, :lookup:`gte`, :lookup:`lt`, :lookup:`lte` + +For example, to find posts that have tags with name "test", "TEST", "tEsT", +etc:: + +>>> Post.objects.filter(tags__name__iexact="test") + +.. fieldlookup:: embeddedmodelarrayfield.len + +``len`` transform +^^^^^^^^^^^^^^^^^ + +You can use the ``len`` transform to filter on the length of the array. The +lookups available afterward are those available for +:class:`~django.db.models.IntegerField`. For example, to match posts with one +tag:: + + >>> Post.objects.filter(tags__len=1) + +or at least one tag:: + + >>> Post.objects.filter(tags__len__gte=1) + +Index and slice transforms +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Like :class:`~django_mongodb_backend.fields.ArrayField`, you can use +:lookup:`index <mongo-arrayfield.index>` and :lookup:`slice +<mongo-arrayfield.slice>` transforms to filter on particular items in an array. 
+ +For example, to find posts where the first tag is named "test":: + +>>> Post.objects.filter(tags__0__name="test") + +Or to find posts where one of the first two tags is named "test":: + +>>> Post.objects.filter(tags__0_2__name="test") + +These indexes use 0-based indexing. + +Nested ``EmbeddedModelArrayField``\s +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If your models use nested ``EmbeddedModelArrayField``\s, you can't use double +underscores to query into the second level. + +For example, if the ``Tag`` model had an ``EmbeddedModelArrayField`` called +``colors``: + + >>> Post.objects.filter(tags__colors__name="blue") + ... + ValueError: Cannot perform multiple levels of array traversal in a query. diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 3b4fd4e1..3d3a1584 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -165,3 +165,51 @@ class Review(EmbeddedModel): def __str__(self): return self.title + + +# An exhibit in the museum, composed of multiple sections. +class Exhibit(models.Model): + exhibit_name = models.CharField(max_length=255) + sections = EmbeddedModelArrayField("Section", null=True) + main_section = EmbeddedModelField("Section", null=True) + + def __str__(self): + return self.exhibit_name + + +# A section within an exhibit, containing multiple artifacts. +class Section(EmbeddedModel): + section_number = models.IntegerField() + artifacts = EmbeddedModelArrayField("Artifact", null=True) + + def __str__(self): + return "Section %d" % self.section_number + + +# Details about a specific artifact. +class Artifact(EmbeddedModel): + name = models.CharField(max_length=255) + metadata = models.JSONField() + restorations = EmbeddedModelArrayField("Restoration", null=True) + last_restoration = EmbeddedModelField("Restoration", null=True) + + def __str__(self): + return self.name + + +# Details about when an artifact was restored. 
+class Restoration(EmbeddedModel): + date = models.DateField() + restored_by = models.CharField(max_length=255) + + def __str__(self): + return f"Restored by {self.restored_by} on {self.date}" + + +# ForeignKey to a model with EmbeddedModelArrayField. +class Tour(models.Model): + guide = models.CharField(max_length=100) + exhibit = models.ForeignKey(Exhibit, models.CASCADE) + + def __str__(self): + return f"Tour by {self.guide}" diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index 892d2e18..caab244b 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -1,11 +1,14 @@ -from django.db import models +from datetime import date + +from django.core.exceptions import FieldDoesNotExist +from django.db import connection, models from django.test import SimpleTestCase, TestCase -from django.test.utils import isolate_apps +from django.test.utils import CaptureQueriesContext, isolate_apps from django_mongodb_backend.fields import EmbeddedModelArrayField from django_mongodb_backend.models import EmbeddedModel -from .models import Movie, Review +from .models import Artifact, Exhibit, Movie, Restoration, Review, Section, Tour class MethodTests(SimpleTestCase): @@ -55,6 +58,233 @@ def test_save_load_null(self): self.assertIsNone(movie.reviews) +class QueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.egypt = Exhibit.objects.create( + exhibit_name="Ancient Egypt", + sections=[ + Section( + section_number=1, + artifacts=[ + Artifact( + name="Ptolemaic Crown", + metadata={ + "origin": "Egypt", + }, + ) + ], + ) + ], + ) + cls.wonders = Exhibit.objects.create( + exhibit_name="Wonders of the Ancient World", + sections=[ + Section( + section_number=1, + artifacts=[ + Artifact( + name="Statue of Zeus", + metadata={"location": "Olympia", "height_m": 12}, + ), + Artifact( + name="Hanging Gardens", + ), + ], + ), + Section( + 
section_number=2, + artifacts=[ + Artifact( + name="Lighthouse of Alexandria", + metadata={"height_m": 100, "built": "3rd century BC"}, + ) + ], + ), + ], + ) + cls.new_descoveries = Exhibit.objects.create( + exhibit_name="New Discoveries", + sections=[ + Section( + section_number=2, + artifacts=[ + Artifact( + name="Lighthouse of Alexandria", + metadata={"height_m": 100, "built": "3rd century BC"}, + ) + ], + ) + ], + ) + cls.lost_empires = Exhibit.objects.create( + exhibit_name="Lost Empires", + main_section=Section( + section_number=3, + artifacts=[ + Artifact( + name="Bronze Statue", + metadata={"origin": "Pergamon"}, + restorations=[ + Restoration( + date=date(1998, 4, 15), + restored_by="Zacarias", + ), + Restoration( + date=date(2010, 7, 22), + restored_by="Vicente", + ), + ], + last_restoration=Restoration( + date=date(2010, 7, 22), + restored_by="Monzon", + ), + ) + ], + ), + ) + cls.egypt_tour = Tour.objects.create(guide="Amira", exhibit=cls.egypt) + cls.wonders_tour = Tour.objects.create(guide="Carlos", exhibit=cls.wonders) + cls.lost_tour = Tour.objects.create(guide="Yelena", exhibit=cls.lost_empires) + + def test_exact(self): + self.assertCountEqual( + Exhibit.objects.filter(sections__section_number=1), [self.egypt, self.wonders] + ) + + def test_array_index(self): + self.assertCountEqual( + Exhibit.objects.filter(sections__0__section_number=1), + [self.egypt, self.wonders], + ) + + def test_nested_array_index(self): + self.assertCountEqual( + Exhibit.objects.filter( + main_section__artifacts__restorations__0__restored_by="Zacarias" + ), + [self.lost_empires], + ) + + def test_array_slice(self): + self.assertSequenceEqual( + Exhibit.objects.filter(sections__0_1__section_number=2), [self.new_descoveries] + ) + + def test_filter_unsupported_lookups_in_json(self): + """Unsupported lookups can be used as keys in a JSONField.""" + for lookup in ["contains", "range"]: + kwargs = {f"main_section__artifacts__metadata__origin__{lookup}": ["Pergamon", "Egypt"]} 
+ with CaptureQueriesContext(connection) as captured_queries: + self.assertCountEqual(Exhibit.objects.filter(**kwargs), []) + self.assertIn(f"'field': '{lookup}'", captured_queries[0]["sql"]) + + def test_len(self): + self.assertCountEqual(Exhibit.objects.filter(sections__len=10), []) + self.assertCountEqual( + Exhibit.objects.filter(sections__len=1), + [self.egypt, self.new_descoveries], + ) + # Nested EMF + self.assertCountEqual( + Exhibit.objects.filter(main_section__artifacts__len=1), [self.lost_empires] + ) + self.assertCountEqual(Exhibit.objects.filter(main_section__artifacts__len=2), []) + # Nested Indexed Array + self.assertCountEqual(Exhibit.objects.filter(sections__0__artifacts__len=2), [self.wonders]) + self.assertCountEqual(Exhibit.objects.filter(sections__0__artifacts__len=0), []) + self.assertCountEqual(Exhibit.objects.filter(sections__1__artifacts__len=1), [self.wonders]) + + def test_in(self): + self.assertCountEqual(Exhibit.objects.filter(sections__section_number__in=[10]), []) + self.assertCountEqual( + Exhibit.objects.filter(sections__section_number__in=[1]), + [self.egypt, self.wonders], + ) + self.assertCountEqual( + Exhibit.objects.filter(sections__section_number__in=[2]), + [self.new_descoveries, self.wonders], + ) + self.assertCountEqual(Exhibit.objects.filter(sections__section_number__in=[3]), []) + + def test_iexact(self): + self.assertCountEqual( + Exhibit.objects.filter(sections__artifacts__0__name__iexact="lightHOuse of aLexandriA"), + [self.new_descoveries, self.wonders], + ) + + def test_gt(self): + self.assertCountEqual( + Exhibit.objects.filter(sections__section_number__gt=1), + [self.new_descoveries, self.wonders], + ) + + def test_gte(self): + self.assertCountEqual( + Exhibit.objects.filter(sections__section_number__gte=1), + [self.egypt, self.new_descoveries, self.wonders], + ) + + def test_lt(self): + self.assertCountEqual( + Exhibit.objects.filter(sections__section_number__lt=2), [self.egypt, self.wonders] + ) + + def 
test_lte(self): + self.assertCountEqual( + Exhibit.objects.filter(sections__section_number__lte=2), + [self.egypt, self.wonders, self.new_descoveries], + ) + + def test_querying_array_not_allowed(self): + msg = ( + "Lookups aren't supported on EmbeddedModelArrayField. " + "Try querying one of its embedded fields instead." + ) + with self.assertRaisesMessage(ValueError, msg): + Exhibit.objects.filter(sections=10).first() + + with self.assertRaisesMessage(ValueError, msg): + Exhibit.objects.filter(sections__0_1=10).first() + + def test_invalid_field(self): + msg = "Section has no field named 'section'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Exhibit.objects.filter(sections__section__in=[10]).first() + + def test_invalid_lookup(self): + msg = "Unsupported lookup 'return' for EmbeddedModelArrayField of 'IntegerField'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Exhibit.objects.filter(sections__section_number__return=3) + + def test_invalid_operation(self): + msg = "Unsupported lookup 'rage' for EmbeddedModelArrayField of 'IntegerField'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Exhibit.objects.filter(sections__section_number__rage=[10]) + + def test_missing_lookup_suggestions(self): + msg = ( + "Unsupported lookup 'ltee' for EmbeddedModelArrayField of 'IntegerField', " + "perhaps you meant lte or lt?" + ) + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Exhibit.objects.filter(sections__section_number__ltee=3) + + def test_nested_lookup(self): + msg = "Cannot perform multiple levels of array traversal in a query." 
+ with self.assertRaisesMessage(ValueError, msg): + Exhibit.objects.filter(sections__artifacts__name="") + + def test_foreign_field_exact(self): + """Querying from a foreign key to an EmbeddedModelArrayField.""" + qs = Tour.objects.filter(exhibit__sections__section_number=1) + self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + + def test_foreign_field_with_slice(self): + qs = Tour.objects.filter(exhibit__sections__0_2__section_number__in=[1, 2]) + self.assertCountEqual(qs, [self.wonders_tour, self.egypt_tour]) + + @isolate_apps("model_fields_") class CheckTests(SimpleTestCase): def test_no_relational_fields(self):