diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index b7e264f8..9f71eb17 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -746,7 +746,7 @@ def execute_sql(self, result_type): elif hasattr(value, "prepare_database_save"): if field.remote_field: value = value.prepare_database_save(field) - elif not hasattr(field, "embedded_model"): + elif not getattr(field, "value_is_model_instance", False): raise TypeError( f"Tried to update field {field} with a model " f"instance, {value!r}. Use a value compatible with " diff --git a/django_mongodb_backend/fields/__init__.py b/django_mongodb_backend/fields/__init__.py index be95fa5e..26678e1c 100644 --- a/django_mongodb_backend/fields/__init__.py +++ b/django_mongodb_backend/fields/__init__.py @@ -5,6 +5,7 @@ from .embedded_model_array import EmbeddedModelArrayField from .json import register_json_field from .objectid import ObjectIdField +from .polymorphic_embedded_model import PolymorphicEmbeddedModelField __all__ = [ "register_fields", @@ -13,6 +14,7 @@ "EmbeddedModelField", "ObjectIdAutoField", "ObjectIdField", + "PolymorphicEmbeddedModelField", ] diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 87b195b1..093df7d9 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -12,6 +12,8 @@ class EmbeddedModelField(models.Field): """Field that stores a model instance.""" + value_is_model_instance = True + def __init__(self, embedded_model, *args, **kwargs): """ `embedded_model` is the model class of the instance to be stored. diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index 77e91f80..bd4c1283 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -13,6 +13,8 @@ class EmbeddedModelArrayField(ArrayField): + value_is_model_instance = True + def __init__(self, embedded_model, **kwargs): if "size" in kwargs: raise ValueError("EmbeddedModelArrayField does not support size.") diff --git a/django_mongodb_backend/fields/polymorphic_embedded_model.py b/django_mongodb_backend/fields/polymorphic_embedded_model.py new file mode 100644 index 00000000..7ef2bb9d --- /dev/null +++ b/django_mongodb_backend/fields/polymorphic_embedded_model.py @@ -0,0 +1,221 @@ +import contextlib +import difflib + +from django.core import checks +from django.core.exceptions import FieldDoesNotExist, ValidationError +from django.db import models +from django.db.models.fields.related import lazy_related_operation +from django.db.models.lookups import Transform + + +class PolymorphicEmbeddedModelField(models.Field): + """Field that stores a model instance.""" + + value_is_model_instance = True + + def __init__(self, embedded_models, *args, **kwargs): + """ + `embedded_models` is a list of possible model classes to be stored. + Like other relational fields, each model may also be passed as a + string. + """ + self.embedded_models = embedded_models + kwargs["editable"] = False + super().__init__(*args, **kwargs) + + def db_type(self, connection): + return "embeddedDocuments" + + def check(self, **kwargs): + from ..models import EmbeddedModel + + errors = super().check(**kwargs) + for model in self.embedded_models: + if not issubclass(model, EmbeddedModel): + return [ + checks.Error( + "Embedded models must be a subclass of " + "django_mongodb_backend.models.EmbeddedModel.", + obj=self, + hint="{model} doesn't subclass EmbeddedModel.", + id="django_mongodb_backend.embedded_model.E002", + ) + ] + for field in model._meta.fields: + if field.remote_field: + errors.append( + checks.Error( + "Embedded models cannot have relational fields " + f"({model().__class__.__name__}.{field.name} " + f"is a {field.__class__.__name__}).", + obj=self, + id="django_mongodb_backend.embedded_model.E001", + ) + ) + return errors + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if path.startswith("django_mongodb_backend.fields.polymorphic_embedded_model"): + path = path.replace( + "django_mongodb_backend.fields.polymorphic_embedded_model", + "django_mongodb_backend.fields", + ) + kwargs["embedded_models"] = self.embedded_models + del kwargs["editable"] + return name, path, args, kwargs + + def get_internal_type(self): + return "PolymorphicEmbeddedModelField" + + def _set_model(self, model): + """ + Resolve embedded model classes once the field knows the model it + belongs to. If any of the items in __init__()'s embedded_models + argument are strings, resolve each to the actual model class, + similar to relation fields. + """ + self._model = model + if model is not None: + for embedded_model in self.embedded_models: + if isinstance(embedded_model, str): + + def _resolve_lookup(_, *resolved_models): + self.embedded_models = resolved_models + + lazy_related_operation(_resolve_lookup, model, *self.embedded_models) + + model = property(lambda self: self._model, _set_model) + + def from_db_value(self, value, expression, connection): + return self.to_python(value) + + def to_python(self, value): + """ + Pass embedded model fields' values through each field's to_python() and + reinstantiate the embedded instance. + """ + if value is None: + return None + if not isinstance(value, dict): + return value + model_class = self._get_model_from_label(value.pop("_label")) + instance = model_class( + **{ + field.attname: field.to_python(value[field.attname]) + for field in model_class._meta.fields + if field.attname in value + } + ) + instance._state.adding = False + return instance + + def get_db_prep_save(self, embedded_instance, connection): + """ + Apply pre_save() and get_db_prep_save() of embedded instance fields and + create the {field: value} dict to be saved. + """ + if embedded_instance is None: + return None + if not isinstance(embedded_instance, self.embedded_models): + raise TypeError( + f"Expected instance of type {self.embedded_models!r}, not " + f"{type(embedded_instance)!r}." + ) + field_values = {} + add = embedded_instance._state.adding + for field in embedded_instance._meta.fields: + value = field.get_db_prep_save( + field.pre_save(embedded_instance, add), connection=connection + ) + # Exclude unset primary keys (e.g. {'id': None}). + if field.primary_key and value is None: + continue + field_values[field.attname] = value + field_values["_label"] = embedded_instance._meta.label + # This instance will exist in the database soon. + embedded_instance._state.adding = False + return field_values + + def get_transform(self, name): + transform = super().get_transform(name) + if transform: + return transform + field = None + for model in self.embedded_models: + with contextlib.suppress(FieldDoesNotExist): + field = model._meta.get_field(name) + if field is None: + raise FieldDoesNotExist( + f"The models of field '{self.name}' have no field named '{name}'." + ) + return KeyTransformFactory(name, field) + + def validate(self, value, model_instance): + super().validate(value, model_instance) + if not isinstance(value, self.embedded_models): + raise ValidationError( + f"Expected instance of type {self.embedded_models!r}, not {type(value)!r}." + ) + for field in value._meta.fields: + attname = field.attname + field.validate(getattr(value, attname), model_instance) + + def formfield(self, **kwargs): + raise NotImplementedError("PolymorphicEmbeddedModelField does not support forms.") + + def _get_model_from_label(self, label): + return {model._meta.label: model for model in self.embedded_models}[label] + + +class KeyTransform(Transform): + def __init__(self, key_name, ref_field, *args, **kwargs): + super().__init__(*args, **kwargs) + self.key_name = str(key_name) + self.ref_field = ref_field + + def get_lookup(self, name): + return self.ref_field.get_lookup(name) + + def get_transform(self, name): + """ + Validate that `name` is either a field of an embedded model or a + lookup on an embedded model's field. + """ + if transform := self.ref_field.get_transform(name): + return transform + suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups()) + if suggested_lookups: + suggested_lookups = " or ".join(suggested_lookups) + suggestion = f", perhaps you meant {suggested_lookups}?" + else: + suggestion = "." + raise FieldDoesNotExist( + f"Unsupported lookup '{name}' for " + f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'" + f"{suggestion}" + ) + + def as_mql(self, compiler, connection): + previous = self + key_transforms = [] + while isinstance(previous, KeyTransform): + key_transforms.insert(0, previous.key_name) + previous = previous.lhs + mql = previous.as_mql(compiler, connection) + for key in key_transforms: + mql = {"$getField": {"input": mql, "field": key}} + return mql + + @property + def output_field(self): + return self.ref_field + + +class KeyTransformFactory: + def __init__(self, key_name, ref_field): + self.key_name = key_name + self.ref_field = ref_field + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, self.ref_field, *args, **kwargs) diff --git a/django_mongodb_backend/operations.py b/django_mongodb_backend/operations.py index f03d30b3..8c3e1e59 100644 --- a/django_mongodb_backend/operations.py +++ b/django_mongodb_backend/operations.py @@ -122,6 +122,8 @@ def get_db_converters(self, expression): ) elif internal_type == "JSONField": converters.append(self.convert_jsonfield_value) + elif internal_type == "PolymorphicEmbeddedModelField": + converters.append(self.convert_polymorphicembeddedmodelfield_value) elif internal_type == "TimeField": # Trunc(... output_field="TimeField") values must remain datetime # until Trunc.convert_value() so they can be converted from UTC @@ -182,6 +184,19 @@ def convert_jsonfield_value(self, value, expression, connection): """ return json.dumps(value) + def convert_polymorphicembeddedmodelfield_value(self, value, expression, connection): + if value is not None: + model_class = expression.output_field._get_model_from_label(value["_label"]) + # Apply database converters to each field of the embedded model. + for field in model_class._meta.fields: + field_expr = Expression(output_field=field) + converters = connection.ops.get_db_converters( + field_expr + ) + field_expr.get_db_converters(connection) + for converter in converters: + value[field.attname] = converter(value[field.attname], field_expr, connection) + return value + def convert_timefield_value(self, value, expression, connection): if value is not None: value = value.time() diff --git a/docs/source/ref/models/fields.rst b/docs/source/ref/models/fields.rst index 79cafe3d..1d59db64 100644 --- a/docs/source/ref/models/fields.rst +++ b/docs/source/ref/models/fields.rst @@ -313,3 +313,32 @@ These indexes use 0-based indexing. .. class:: ObjectIdField Stores an :class:`~bson.objectid.ObjectId`. + +``PolymorphicEmbeddedModelField`` +--------------------------------- + +.. class:: PolymorphicEmbeddedModelField(embedded_models, **kwargs) + + .. versionadded:: 5.2.0b2 + + Stores a model of type ``embedded_models``. + + .. attribute:: embedded_models + + This is a required argument that specifies a list of model classes + that may be embedded. + + Each model class reference works just like + :attr:`.EmbeddedModelField.embedded_model`. + + See :ref:`the embedded model topic guide + ` for more details and examples. + +.. admonition:: Migrations support is limited + + :djadmin:`makemigrations` does not yet detect changes to embedded models, + nor does it create indexes or constraints for embedded models. + +.. admonition:: Forms are not supported + + ``PolymorphicEmbeddedModelField``\s don't appear in model forms. diff --git a/docs/source/topics/embedded-models.rst b/docs/source/topics/embedded-models.rst index 0daa483b..e2a3c149 100644 --- a/docs/source/topics/embedded-models.rst +++ b/docs/source/topics/embedded-models.rst @@ -181,3 +181,84 @@ For example, if the ``Tag`` model had an ``EmbeddedModelArrayField`` called >>> Post.objects.filter(tags__colors__name="blue") ... ValueError: Cannot perform multiple levels of array traversal in a query. + +.. _polymorphic-embedded-model-field-example: + +``PolymorphicEmbeddedModelField`` +--------------------------------- + +The basics +~~~~~~~~~~ + +Let's consider this example:: + + from django.db import models + + from django_mongodb_backend.fields import PolymorphicEmbeddedModelField + from django_mongodb_backend.models import EmbeddedModel + + + class Person(models.Model): + name = models.CharField(max_length=255) + pet = PolymorphicEmbeddedModelField(["Cat", "Dog"]) + + def __str__(self): + return self.name + + + class Cat(EmbeddedModel): + name = models.CharField(max_length=255) + purrs = models.BooleanField(default=True) + + def __str__(self): + return self.name + + + class Dog(EmbeddedModel): + name = models.CharField(max_length=255) + barks = models.BooleanField(default=True) + + def __str__(self): + return self.name + + +The API is similar to that of Django's relational fields:: + + >>> bob = Person.objects.create(name="Bob", pet=Dog(name="Woofer")) + >>> bob.pet + + >>> bob.pet.name + 'Woofer' + >>> bob = Person.objects.create(name="Fred", pet=Cat(name="Pheobe")) + +Represented in BSON, the person structures looks like this: + +.. code-block:: js + + { + _id: ObjectId('685da4895e42adade0c8db29'), + name: 'Bob', + pet: { name: 'Woofer', barks: true, _label: 'myapp.Dog' } + }, + { + _id: ObjectId('685da4925e42adade0c8db2a'), + name: 'Fred', + pet: { name: 'Pheobe', purrs: true, _label: 'myapp.Cat' } + } + +The ``_label`` field contains the model's +:attr:`~django.db.models.Options.label`. + +Querying ``PolymorphicEmbeddedModelField`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can query into a polymorphic embedded model field using the same double +underscore syntax as relational fields. For example, to retrieve all people +who have a pet named "Lassy":: + + >>> Person.objects.filter(pet__name="Lassy") + +You can also filter on fields that aren't shared among the embedded models. For +example, if you filter on ``barks``, you'll only get back people with dogs:: + + >>> Person.objects.filter(pet__barks=True) diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 43522565..78362916 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -7,6 +7,7 @@ EmbeddedModelArrayField, EmbeddedModelField, ObjectIdField, + PolymorphicEmbeddedModelField, ) from django_mongodb_backend.models import EmbeddedModel @@ -222,3 +223,31 @@ class Tour(models.Model): def __str__(self): return f"Tour by {self.guide}" + + +# PolymorphicEmbeddedModelField +class Person(models.Model): + name = models.CharField(max_length=100) + pet = PolymorphicEmbeddedModelField(("Dog", "Cat"), blank=True, null=True) + + def __str__(self): + return self.name + + +class Dog(EmbeddedModel): + name = models.CharField(max_length=100) + barks = models.BooleanField(default=True) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + def __str__(self): + return self.name + + +class Cat(EmbeddedModel): + name = models.CharField(max_length=100) + purs = models.BooleanField(default=True) + weight = models.DecimalField(max_digits=4, decimal_places=2, blank=True, null=True) + + def __str__(self): + return self.name diff --git a/tests/model_fields_/test_polymorphic_embedded_model.py b/tests/model_fields_/test_polymorphic_embedded_model.py new file mode 100644 index 00000000..27ba289b --- /dev/null +++ b/tests/model_fields_/test_polymorphic_embedded_model.py @@ -0,0 +1,223 @@ +from datetime import timedelta +from decimal import Decimal +from unittest import expectedFailure + +from django.core.exceptions import FieldDoesNotExist, ValidationError +from django.db import models +from django.test import SimpleTestCase, TestCase +from django.test.utils import isolate_apps + +from django_mongodb_backend.fields import PolymorphicEmbeddedModelField +from django_mongodb_backend.models import EmbeddedModel + +from .models import Cat, Dog, Library, Person +from .utils import truncate_ms + + +class MethodTests(SimpleTestCase): + def test_not_editable(self): + field = PolymorphicEmbeddedModelField(["Data"], null=True) + self.assertIs(field.editable, False) + + def test_deconstruct(self): + field = PolymorphicEmbeddedModelField(["Data"], null=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django_mongodb_backend.fields.PolymorphicEmbeddedModelField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"embedded_models": ["Data"], "null": True}) + + def test_get_db_prep_save_invalid(self): + msg = ( + "Expected instance of type (, " + "), " + "not ." + ) + with self.assertRaisesMessage(TypeError, msg): + Person(pet=42).save() + + def test_validate(self): + obj = Person(name="Bob", pet=Dog(name="Woofer", barks=None)) + # This isn't quite right because "barks" is the subfield of data + # that's non-null. + msg = "{'pet': ['This field cannot be null.']}" + with self.assertRaisesMessage(ValidationError, msg): + obj.full_clean() + + def test_validate_wrong_model_type(self): + obj = Person(name="Bob", pet=Library()) + msg = ( + "{'pet': [\"Expected instance of type " + "(, " + "), not " + ".\"]}" + ) + with self.assertRaisesMessage(ValidationError, msg): + obj.full_clean() + + +class ModelTests(TestCase): + def test_save_load(self): + Person.objects.create(name="Jim", pet=Dog(name="Woofer")) + obj = Person.objects.get() + self.assertIsInstance(obj.pet, Dog) + # get_prep_value() is called, transforming string to int. + self.assertEqual(obj.pet.name, "Woofer") + # Primary keys should not be populated... + self.assertEqual(obj.pet.id, None) + # ... unless set explicitly. + obj.pet.id = obj.id + obj.save() + obj = Person.objects.get() + self.assertEqual(obj.pet.id, obj.id) + + def test_save_load_null(self): + Person.objects.create(pet=None) + obj = Person.objects.get() + self.assertIsNone(obj.pet) + + def test_save_load_decimal(self): + obj = Person.objects.create(pet=Cat(name="Phoebe", weight="5.5")) + obj.refresh_from_db() + self.assertEqual(obj.pet.weight, Decimal("5.5")) + + def test_pre_save(self): + """Field.pre_save() is called on embedded model fields.""" + obj = Person.objects.create(name="Bob", pet=Dog(name="Woofer")) + created_at = truncate_ms(obj.pet.created_at) + updated_at = truncate_ms(obj.pet.updated_at) + self.assertIsNotNone(obj.pet.created_at) + # The values may differ by a millisecond since they aren't generated + # simultaneously. + self.assertAlmostEqual(updated_at, created_at, delta=timedelta(microseconds=1000)) + # save() updates auto_now but not auto_now_add. + obj.save() + self.assertEqual(truncate_ms(obj.pet.created_at), created_at) + self.assertGreater(truncate_ms(obj.pet.updated_at), updated_at) + + +class QueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.cat_owners = [ + Person.objects.create( + name=f"Cat Owner {x}", + pet=Cat( + name=f"Cat {x}", + weight=f"{x}.5", + ), + ) + for x in range(6) + ] + cls.dog_owners = [ + Person.objects.create( + name=f"Dog Owner {x}", + pet=Dog( + name=f"Dog {x}", + barks=x % 2 == 0, + ), + ) + for x in range(6) + ] + + def test_exact(self): + self.assertCountEqual(Person.objects.filter(pet__weight="3.5"), [self.cat_owners[3]]) + + # lt/lte don't exclude nonexistent fields. (In these tests, all Dogs are + # also returned since they don't have a weight field.) + @expectedFailure + def test_lt(self): + self.assertCountEqual(Person.objects.filter(pet__weight__lt="3.5"), self.cat_owners[:3]) + + @expectedFailure + def test_lte(self): + self.assertCountEqual(Person.objects.filter(pet__weight__lte="3.5"), self.cat_owners[:4]) + + def test_gt(self): + self.assertCountEqual(Person.objects.filter(pet__weight__gt=3.5), self.cat_owners[4:]) + + def test_gte(self): + self.assertCountEqual(Person.objects.filter(pet__weight__gte=3.5), self.cat_owners[3:]) + + def test_range(self): + self.assertCountEqual( + Person.objects.filter(pet__weight__range=(2, 4)), self.cat_owners[2:4] + ) + + def test_order_by_embedded_field(self): + qs = Person.objects.filter(pet__weight__gt=3).order_by("-pet__weight") + self.assertSequenceEqual(qs, list(reversed(self.cat_owners[3:]))) + + def test_boolean(self): + self.assertCountEqual( + Person.objects.filter(pet__barks=True), + [x for i, x in enumerate(self.dog_owners, 1) if i % 2 == 1], + ) + + +class InvalidLookupTests(SimpleTestCase): + def test_invalid_field(self): + msg = "The models of field 'pet' have no field named 'first_name'." + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Person.objects.filter(pet__first_name="Bob") + + # def test_invalid_field_nested(self): + # msg = "Address has no field named 'floor'" + # with self.assertRaisesMessage(FieldDoesNotExist, msg): + # Book.objects.filter(author__address__floor="NYC") + + def test_invalid_lookup(self): + msg = "Unsupported lookup 'foo' for CharField 'name'." + with self.assertRaisesMessage(FieldDoesNotExist, msg): + Person.objects.filter(pet__name__foo="Bob") + + def test_invalid_lookup_with_suggestions(self): + msg = ( + "Unsupported lookup '{lookup}' for CharField 'name', " + "perhaps you meant {suggested_lookups}?" + ) + with self.assertRaisesMessage( + FieldDoesNotExist, msg.format(lookup="exactly", suggested_lookups="exact or iexact") + ): + Person.objects.filter(pet__name__exactly="Woof") + with self.assertRaisesMessage( + FieldDoesNotExist, msg.format(lookup="gti", suggested_lookups="gt or gte") + ): + Person.objects.filter(pet__name__gti="Woof") + with self.assertRaisesMessage( + FieldDoesNotExist, msg.format(lookup="is_null", suggested_lookups="isnull") + ): + Person.objects.filter(pet__name__is_null="Woof") + + +@isolate_apps("model_fields_") +class CheckTests(SimpleTestCase): + def test_no_relational_fields(self): + class Target(EmbeddedModel): + key = models.ForeignKey("MyModel", models.CASCADE) + + class MyModel(models.Model): + field = PolymorphicEmbeddedModelField([Target]) + + errors = MyModel().check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E001") + msg = errors[0].msg + self.assertEqual( + msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)." + ) + + def test_embedded_model_subclass(self): + class Target(models.Model): + pass + + class MyModel(models.Model): + field = PolymorphicEmbeddedModelField([Target]) + + errors = MyModel().check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E002") + msg = errors[0].msg + self.assertEqual( + msg, + "Embedded models must be a subclass of django_mongodb_backend.models.EmbeddedModel.", + )