diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 83df2f01..7a330132 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,7 +61,10 @@ repos: rev: v1.1.1 hooks: - id: doc8 - args: ["--ignore=D001"] # ignore line length + # D000 Invalid class attribute value for "class" directive when using + # * (keyword-only parameters separator). + # D001 line length + args: ["--ignore=D000,D001"] stages: [manual] - repo: https://github.com/sirosen/check-jsonschema diff --git a/django_mongodb_backend/fields/__init__.py b/django_mongodb_backend/fields/__init__.py index 569c19be..be95fa5e 100644 --- a/django_mongodb_backend/fields/__init__.py +++ b/django_mongodb_backend/fields/__init__.py @@ -2,12 +2,14 @@ from .auto import ObjectIdAutoField from .duration import register_duration_field from .embedded_model import EmbeddedModelField +from .embedded_model_array import EmbeddedModelArrayField from .json import register_json_field from .objectid import ObjectIdField __all__ = [ "register_fields", "ArrayField", + "EmbeddedModelArrayField", "EmbeddedModelField", "ObjectIdAutoField", "ObjectIdField", diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 4f951514..8a9f7e0a 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -338,7 +338,7 @@ class ArrayLenTransform(Transform): def as_mql(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) - return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}} + return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}} @ArrayField.register_lookup diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 57bbd3f5..590fd5f8 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -186,8 +186,9 @@ def as_mql(self, compiler, connection): key_transforms.insert(0, previous.key_name) previous = previous.lhs mql = previous.as_mql(compiler, connection) - transforms = ".".join(key_transforms) - return f"{mql}.{transforms}" + for key in key_transforms: + mql = {"$getField": {"input": mql, "field": key}} + return mql @property def output_field(self): diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py new file mode 100644 index 00000000..97ed82ae --- /dev/null +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -0,0 +1,279 @@ +import difflib + +from django.core.exceptions import FieldDoesNotExist +from django.db.models import Field, lookups +from django.db.models.expressions import Col +from django.db.models.lookups import Lookup, Transform + +from .. import forms +from ..query_utils import process_lhs, process_rhs +from . import EmbeddedModelField +from .array import ArrayField, ArrayLenTransform + + +class EmbeddedModelArrayField(ArrayField): + def __init__(self, embedded_model, **kwargs): + if "size" in kwargs: + raise ValueError("EmbeddedModelArrayField does not support size.") + super().__init__(EmbeddedModelField(embedded_model), **kwargs) + self.embedded_model = embedded_model + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if path == "django_mongodb_backend.fields.embedded_model_array.EmbeddedModelArrayField": + path = "django_mongodb_backend.fields.EmbeddedModelArrayField" + kwargs["embedded_model"] = self.embedded_model + del kwargs["base_field"] + return name, path, args, kwargs + + def get_db_prep_value(self, value, connection, prepared=False): + if isinstance(value, list | tuple): + # Must call get_db_prep_save() rather than get_db_prep_value() + # to transform model instances to dicts. + return [self.base_field.get_db_prep_save(i, connection) for i in value] + if value is not None: + raise TypeError( + f"Expected list of {self.embedded_model!r} instances, not {type(value)!r}." + ) + return value + + def formfield(self, **kwargs): + # Skip ArrayField.formfield() which has some differences, including + # unneeded "base_field", and "max_length" instead of "max_num". + return Field.formfield( + self, + **{ + "form_class": forms.EmbeddedModelArrayField, + "model": self.embedded_model, + "max_num": self.max_size, + "prefix": self.name, + **kwargs, + }, + ) + + def get_transform(self, name): + transform = super().get_transform(name) + if transform: + return transform + return KeyTransformFactory(name, self) + + def _get_lookup(self, lookup_name): + lookup = super()._get_lookup(lookup_name) + if lookup is None or lookup is ArrayLenTransform: + return lookup + + class EmbeddedModelArrayFieldLookups(Lookup): + def as_mql(self, compiler, connection): + raise ValueError( + "Cannot apply this lookup directly to EmbeddedModelArrayField. " + "Try querying one of its embedded fields instead." + ) + + return EmbeddedModelArrayFieldLookups + + +class _EmbeddedModelArrayOutputField(ArrayField): + """ + Represents the output of an EmbeddedModelArrayField when traversed in a query path. + + This field is not meant to be used directly in model definitions. It exists solely to + support query output resolution; when an EmbeddedModelArrayField is accessed in a query, + the result should behave like an array of the embedded model's target type. + + While it mimics ArrayField's lookups behavior, the way those lookups are resolved + follows the semantics of EmbeddedModelArrayField rather than native array behavior. + """ + + ALLOWED_LOOKUPS = { + "in", + "exact", + "iexact", + "gt", + "gte", + "lt", + "lte", + "all", + "contained_by", + } + + def get_lookup(self, name): + return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None + + +class EmbeddedModelArrayFieldBuiltinLookup(Lookup): + def process_rhs(self, compiler, connection): + value = self.rhs + if not self.get_db_prep_lookup_value_is_iterable: + value = [value] + # Value must be serialized based on the query target. + # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output + # field of the subfield. Otherwise, use the base field of the array itself. + get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value + return None, [ + v if hasattr(v, "as_mql") else get_db_prep_value(v, connection, prepared=True) + for v in value + ] + + def as_mql(self, compiler, connection): + # Querying a subfield within the array elements (via nested KeyTransform). + # Replicates MongoDB's implicit ANY-match by mapping over the array and applying + # `$in` on the subfield. + lhs_mql = process_lhs(self, compiler, connection) + inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"] + values = process_rhs(self, compiler, connection) + lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name]( + inner_lhs_mql, values + ) + return {"$anyElementTrue": lhs_mql} + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact): + get_db_prep_lookup_value_is_iterable = False + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldGreaterThanOrEqual( + EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual +): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldLessThanOrEqual( + EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual +): + pass + + +@_EmbeddedModelArrayOutputField.register_lookup +class EmbeddedModelArrayFieldAll(EmbeddedModelArrayFieldBuiltinLookup, Lookup): + lookup_name = "all" + get_db_prep_lookup_value_is_iterable = False + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + values = process_rhs(self, compiler, connection) + return { + "$and": [ + {"$ne": [lhs_mql, None]}, + {"$ne": [values, None]}, + {"$setIsSubset": [values, lhs_mql]}, + ] + } + + +@_EmbeddedModelArrayOutputField.register_lookup +class ArrayContainedBy(EmbeddedModelArrayFieldBuiltinLookup, Lookup): + lookup_name = "contained_by" + get_db_prep_lookup_value_is_iterable = False + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + return { + "$and": [ + {"$ne": [lhs_mql, None]}, + {"$ne": [value, None]}, + {"$setIsSubset": [lhs_mql, value]}, + ] + } + + +class KeyTransform(Transform): + def __init__(self, key_name, array_field, *args, **kwargs): + super().__init__(*args, **kwargs) + self.array_field = array_field + self.key_name = key_name + # The iteration items begins from the base_field, a virtual column with + # base field output type is created. + column_target = array_field.embedded_model._meta.get_field(key_name).clone() + column_name = f"$item.{key_name}" + column_target.db_column = column_name + column_target.set_attributes_from_name(column_name) + self._lhs = Col(None, column_target) + self._sub_transform = None + + def __call__(self, this, *args, **kwargs): + self._lhs = self._sub_transform(self._lhs, *args, **kwargs) + return self + + def get_lookup(self, name): + return self.output_field.get_lookup(name) + + def get_transform(self, name): + """ + Validate that `name` is either a field of an embedded model or a + lookup on an embedded model's field. + """ + # Once the sub lhs is a transform, all the filter are applied over it. + # Otherwise get transform from EMF. + if transform := self._lhs.get_transform(name): + if isinstance(transform, KeyTransformFactory): + raise ValueError("Cannot perform multiple levels of array traversal in a query.") + self._sub_transform = transform + return self + output_field = self._lhs.output_field + allowed_lookups = self.output_field.ALLOWED_LOOKUPS.intersection( + set(output_field.get_lookups()) + ) + suggested_lookups = difflib.get_close_matches(name, allowed_lookups) + if suggested_lookups: + suggested_lookups = " or ".join(suggested_lookups) + suggestion = f", perhaps you meant {suggested_lookups}?" + else: + suggestion = "" + raise FieldDoesNotExist( + f"Unsupported lookup '{name}' for " + f"EmbeddedModelArrayField of '{output_field.__class__.__name__}'" + f"{suggestion}" + ) + + def as_mql(self, compiler, connection): + inner_lhs_mql = self._lhs.as_mql(compiler, connection) + lhs_mql = process_lhs(self, compiler, connection) + return { + "$ifNull": [ + { + "$map": { + "input": lhs_mql, + "as": "item", + "in": inner_lhs_mql, + } + }, + [], + ] + } + + @property + def output_field(self): + return _EmbeddedModelArrayOutputField(self._lhs.output_field) + + +class KeyTransformFactory: + def __init__(self, key_name, base_field): + self.key_name = key_name + self.base_field = base_field + + def __call__(self, *args, **kwargs): + return KeyTransform(self.key_name, self.base_field, *args, **kwargs) diff --git a/django_mongodb_backend/forms/__init__.py b/django_mongodb_backend/forms/__init__.py index 2adc8fbe..eb15e762 100644 --- a/django_mongodb_backend/forms/__init__.py +++ b/django_mongodb_backend/forms/__init__.py @@ -1,4 +1,5 @@ from .fields import ( + EmbeddedModelArrayField, EmbeddedModelField, ObjectIdField, SimpleArrayField, @@ -7,6 +8,7 @@ ) __all__ = [ + "EmbeddedModelArrayField", "EmbeddedModelField", "SimpleArrayField", "SplitArrayField", diff --git a/django_mongodb_backend/forms/fields/__init__.py b/django_mongodb_backend/forms/fields/__init__.py index 03cc2372..214c0f80 100644 --- a/django_mongodb_backend/forms/fields/__init__.py +++ b/django_mongodb_backend/forms/fields/__init__.py @@ -1,8 +1,10 @@ from .array import SimpleArrayField, SplitArrayField, SplitArrayWidget from .embedded_model import EmbeddedModelField +from .embedded_model_array import EmbeddedModelArrayField from .objectid import ObjectIdField __all__ = [ + "EmbeddedModelArrayField", "EmbeddedModelField", "SimpleArrayField", "SplitArrayField", diff --git a/django_mongodb_backend/forms/fields/embedded_model.py b/django_mongodb_backend/forms/fields/embedded_model.py index bbfa9c02..185be44b 100644 --- a/django_mongodb_backend/forms/fields/embedded_model.py +++ b/django_mongodb_backend/forms/fields/embedded_model.py @@ -4,36 +4,6 @@ from django.utils.translation import gettext_lazy as _ -class EmbeddedModelWidget(forms.MultiWidget): - def __init__(self, field_names, *args, **kwargs): - self.field_names = field_names - super().__init__(*args, **kwargs) - # The default widget names are "_0", "_1", etc. Use the field names - # instead since that's how they'll be rendered by the model form. - self.widgets_names = ["-" + name for name in field_names] - - def decompress(self, value): - if value is None: - return [] - # Get the data from `value` (a model) for each field. - return [getattr(value, name) for name in self.field_names] - - -class EmbeddedModelBoundField(forms.BoundField): - def __init__(self, form, field, name, prefix_override=None): - super().__init__(form, field, name) - # prefix_override overrides the prefix in self.field.form_kwargs so - # that nested embedded model form elements have the correct name. - self.prefix_override = prefix_override - - def __str__(self): - """Render the model form as the representation for this field.""" - form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs) - if self.prefix_override: - form.prefix = self.prefix_override - return mark_safe(f"{form.as_div()}") # noqa: S308 - - class EmbeddedModelField(forms.MultiValueField): default_error_messages = { "invalid": _("Enter a list of values."), @@ -79,3 +49,33 @@ def prepare_value(self, value): # (rather than a list) for initializing the form in # EmbeddedModelBoundField.__str__(). return self.compress(value) if isinstance(value, list) else value + + +class EmbeddedModelBoundField(forms.BoundField): + def __init__(self, form, field, name, prefix_override=None): + super().__init__(form, field, name) + # prefix_override overrides the prefix in self.field.form_kwargs so + # that nested embedded model form elements have the correct name. + self.prefix_override = prefix_override + + def __str__(self): + """Render the model form as the representation for this field.""" + form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs) + if self.prefix_override: + form.prefix = self.prefix_override + return mark_safe(f"{form.as_div()}") # noqa: S308 + + +class EmbeddedModelWidget(forms.MultiWidget): + def __init__(self, field_names, *args, **kwargs): + self.field_names = field_names + super().__init__(*args, **kwargs) + # The default widget names are "_0", "_1", etc. Use the field names + # instead since that's how they'll be rendered by the model form. + self.widgets_names = ["-" + name for name in field_names] + + def decompress(self, value): + if value is None: + return [] + # Get the data from `value` (a model) for each field. + return [getattr(value, name) for name in self.field_names] diff --git a/django_mongodb_backend/forms/fields/embedded_model_array.py b/django_mongodb_backend/forms/fields/embedded_model_array.py new file mode 100644 index 00000000..aebbd7f1 --- /dev/null +++ b/django_mongodb_backend/forms/fields/embedded_model_array.py @@ -0,0 +1,80 @@ +from django import forms +from django.core.exceptions import ValidationError +from django.forms import formset_factory, model_to_dict +from django.forms.models import modelform_factory +from django.utils.html import format_html, format_html_join + + +class EmbeddedModelArrayField(forms.Field): + def __init__(self, model, *, prefix, max_num=None, extra_forms=3, **kwargs): + self.model = model + self.prefix = prefix + self.formset = formset_factory( + form=modelform_factory(model, fields="__all__"), + can_delete=True, + max_num=max_num, + extra=extra_forms, + validate_max=True, + ) + kwargs["widget"] = EmbeddedModelArrayWidget() + super().__init__(**kwargs) + + def clean(self, value): + if not value: + return [] + formset = self.formset(value, prefix=self.prefix_override or self.prefix) + if not formset.is_valid(): + raise ValidationError(formset.errors + formset.non_form_errors()) + cleaned_data = [] + for data in formset.cleaned_data: + # The "delete" checkbox isn't part of model data and must be + # removed. The fallback to True skips empty forms. + if data.pop("DELETE", True): + continue + cleaned_data.append(self.model(**data)) + return cleaned_data + + def has_changed(self, initial, data): + formset = self.formset(data, initial=models_to_dicts(initial), prefix=self.prefix) + return formset.has_changed() + + def get_bound_field(self, form, field_name): + # Nested embedded model form fields need a double prefix. + # HACK: Setting self.prefix_override makes it available in clean() + # which doesn't have access to the form. + self.prefix_override = f"{form.prefix}-{self.prefix}" if form.prefix else None + return EmbeddedModelArrayBoundField(form, self, field_name, self.prefix_override) + + +class EmbeddedModelArrayBoundField(forms.BoundField): + def __init__(self, form, field, name, prefix_override): + super().__init__(form, field, name) + self.formset = field.formset( + self.data if form.is_bound else None, + initial=models_to_dicts(self.initial), + prefix=prefix_override if prefix_override else self.html_name, + ) + + def __str__(self): + body = format_html_join( + "\n", "{}", ((form.as_table(),) for form in self.formset) + ) + return format_html("\n{}\n
\n{}", body, self.formset.management_form) + + +class EmbeddedModelArrayWidget(forms.Widget): + """ + Extract the data for EmbeddedModelArrayFormField's formset. + This widget is never rendered. + """ + + def value_from_datadict(self, data, files, name): + return {field: value for field, value in data.items() if field.startswith(f"{name}-")} + + +def models_to_dicts(models): + """ + Convert initial data (which is a list of model instances or None) to a + list of dictionary data suitable for a formset. + """ + return [model_to_dict(model) for model in models or []] diff --git a/django_mongodb_backend/operations.py b/django_mongodb_backend/operations.py index df2b824a..f03d30b3 100644 --- a/django_mongodb_backend/operations.py +++ b/django_mongodb_backend/operations.py @@ -111,6 +111,15 @@ def get_db_converters(self, expression): converters.append(self.convert_decimalfield_value) elif internal_type == "EmbeddedModelField": converters.append(self.convert_embeddedmodelfield_value) + elif internal_type == "EmbeddedModelArrayField": + converters.extend( + [ + self._get_arrayfield_converter(converter) + for converter in self.get_db_converters( + Expression(output_field=expression.output_field.base_field) + ) + ] + ) elif internal_type == "JSONField": converters.append(self.convert_jsonfield_value) elif internal_type == "TimeField": diff --git a/docs/make.bat b/docs/make.bat old mode 100644 new mode 100755 diff --git a/docs/source/ref/forms.rst b/docs/source/ref/forms.rst index 74ef26f9..72c646a7 100644 --- a/docs/source/ref/forms.rst +++ b/docs/source/ref/forms.rst @@ -23,6 +23,36 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.forms``. in this field's subform will have so that the names don't collide with fields in the main form. +``EmbeddedModelArrayField`` +--------------------------- + +.. class:: EmbeddedModelArrayField(model, *, prefix, max_num=None, extra_forms=3, **kwargs) + + .. versionadded:: 5.2.0b1 + + A field which maps to a list of model instances. The field will render as a + :class:`ModelFormSet `. + + .. attribute:: model + + This is a required argument that specifies the model class. + + .. attribute:: prefix + + This is a required argument that specifies the prefix that all fields + in this field's formset will have so that the names don't collide with + fields in the main form. + + .. attribute:: max_num + + This is an optional argument which specifies the maximum number of + model instances that can be created. + + .. attribute:: extra_forms + + This argument specifies the number of blank forms that will be + rendered by the formset. + ``ObjectIdField`` ----------------- diff --git a/docs/source/ref/models/fields.rst b/docs/source/ref/models/fields.rst index be9ff6dd..f0f67cad 100644 --- a/docs/source/ref/models/fields.rst +++ b/docs/source/ref/models/fields.rst @@ -35,8 +35,8 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``. :class:`~django.db.models.OneToOneField` and :class:`~django.db.models.ManyToManyField`) and file fields ( :class:`~django.db.models.FileField` and - :class:`~django.db.models.ImageField`). :class:`EmbeddedModelField` is - also not (yet) supported. + :class:`~django.db.models.ImageField`). For + :class:`EmbeddedModelField`, use :class:`EmbeddedModelArrayField`. It is possible to nest array fields - you can specify an instance of ``ArrayField`` as the ``base_field``. For example:: @@ -256,7 +256,8 @@ These indexes use 0-based indexing. class Book(models.Model): author = EmbeddedModelField(Author) - See :doc:`/topics/embedded-models` for more details and examples. + See :ref:`the embedded model topic guide ` + for more details and examples. .. admonition:: Migrations support is limited @@ -268,6 +269,237 @@ These indexes use 0-based indexing. created these models and then added an indexed field to ``Address``, the index created in the nested ``Book`` embed is not created. +``EmbeddedModelArrayField`` +--------------------------- + +.. class:: EmbeddedModelArrayField(embedded_model, max_size=None, **kwargs) + + .. versionadded:: 5.2.0b1 + + Similar to :class:`EmbeddedModelField`, but stores a **list** of models of + type ``embedded_model`` rather than a single instance. + + .. attribute:: embedded_model + + This is a required argument that works just like + :attr:`EmbeddedModelField.embedded_model`. + + .. attribute:: max_size + + This is an optional argument. + + If passed, the list will have a maximum size as specified, validated + by forms and model validation, but not enforced by the database. + + See :ref:`the embedded model topic guide + ` for more details and examples. + +.. admonition:: Migrations support is limited + + As described above for :class:`EmbeddedModelField`, + :djadmin:`makemigrations` does not yet detect changes to embedded models. + +Querying ``EmbeddedModelArrayField`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +There are a number of custom lookups and a transform for +:class:`EmbeddedModelArrayField`, similar to those available +for :class:`ArrayField`. +We will use the following example model:: + + from django.db import models + from django_mongodb_backend.fields import EmbeddedModelArrayField + + + class Tag(EmbeddedModel): + label = models.CharField(max_length=100) + + class Post(models.Model): + name = models.CharField(max_length=200) + tags = EmbeddedModelArrayField(Tag) + + def __str__(self): + return self.name + +Embedded field lookup +^^^^^^^^^^^^^^^^^^^^^ + +Embedded field lookup for :class:`EmbeddedModelArrayField` allow querying +fields of the embedded model. This is done by composing the two involved paths: +the path to the ``EmbeddedModelArrayField`` and the path within the nested +embedded model. +This composition enables generating the appropriate query for the lookups. + +.. fieldlookup:: embeddedmodelarrayfield.in + +``in`` +^^^^^^ + +Returns objects where any of the embedded documents in the field match any of +the values passed. For example: + +.. code-block:: pycon + + >>> Post.objects.create( + ... name="First post", tags=[Tag(label="thoughts"), Tag(label="django")] + ... ) + >>> Post.objects.create(name="Second post", tags=[Tag(label="thoughts")]) + >>> Post.objects.create( + ... name="Third post", tags=[Tag(label="tutorial"), Tag(label="django")] + ... ) + + >>> Post.objects.filter(tags__label__in=["thoughts"]) + , ]> + + >>> Post.objects.filter(tags__label__in=["tutorial", "thoughts"]) + , , ]> + +.. fieldlookup:: embeddedmodelarrayfield.len + +``len`` +^^^^^^^ + +Returns the length of the embedded model array. The lookups available afterward +are those available for :class:`~django.db.models.IntegerField`. For example: + +.. code-block:: pycon + + >>> Post.objects.create( + ... name="First post", tags=[Tag(label="thoughts"), Tag(label="django")] + ... ) + >>> Post.objects.create(name="Second post", tags=[Tag(label="thoughts")]) + + >>> Post.objects.filter(tags__len=1) + ]> + +.. fieldlookup:: embeddedmodelarrayfield.exact + +``exact`` +^^^^^^^^^ + +Returns objects where **any** embedded model in the array exactly matches the +given value. This acts like an existence filter on matching embedded documents. + +.. code-block:: pycon + + >>> Post.objects.create( + ... name="First post", tags=[Tag(label="thoughts"), Tag(label="django")] + ... ) + >>> Post.objects.create(name="Second post", tags=[Tag(label="tutorial")]) + + >>> Post.objects.filter(tags__label__exact="tutorial") + ]> + +.. fieldlookup:: embeddedmodelarrayfield.iexact + +``iexact`` +^^^^^^^^^^ + +Returns objects where **any** embedded model in the array has a field that +matches the given value **case-insensitively**. This works like ``exact`` but +ignores letter casing. + +.. code-block:: pycon + + + >>> Post.objects.create( + ... name="First post", tags=[Tag(label="Thoughts"), Tag(label="Django")] + ... ) + >>> Post.objects.create(name="Second post", tags=[Tag(label="tutorial")]) + + >>> Post.objects.filter(tags__label__iexact="django") + ]> + + >>> Post.objects.filter(tags__label__iexact="TUTORIAL") + ]> + +.. fieldlookup:: embeddedmodelarrayfield.gt +.. fieldlookup:: embeddedmodelarrayfield.gte +.. fieldlookup:: embeddedmodelarrayfield.lt +.. fieldlookup:: embeddedmodelarrayfield.lte + +``Greater Than, Greater Than or Equal, Less Than, Less Than or Equal`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +These lookups return objects where **any** embedded document contains a value +that satisfies the corresponding comparison. These are typically used on +numeric or comparable fields within the embedded model. + +Examples: + +.. code-block:: pycon + + Post.objects.create( + name="First post", tags=[Tag(label="django", rating=5), Tag(label="rest", rating=3)] + ) + Post.objects.create( + name="Second post", tags=[Tag(label="python", rating=2)] + ) + + Post.objects.filter(tags__rating__gt=3) + ]> + + Post.objects.filter(tags__rating__gte=3) + , ]> + + Post.objects.filter(tags__rating__lt=3) + + + Post.objects.filter(tags__rating__lte=3) + , ]> + +.. fieldlookup:: embeddedmodelarrayfield.all + +``all`` +^^^^^^^ + +Returns objects where **all** values provided on the right-hand side are +present. It requires that *every* value be matched by some document in +the array. + +Example: + +.. code-block:: pycon + + Post.objects.create( + name="First post", tags=[Tag(label="django"), Tag(label="rest")] + ) + Post.objects.create( + name="Second post", tags=[Tag(label="django")] + ) + + Post.objects.filter(tags__label__all=["django", "rest"]) + ]> + + Post.objects.filter(tags__label__all=["django"]) + , ]> + +.. fieldlookup:: embeddedmodelarrayfield.contained_by + +``contained_by`` +^^^^^^^^^^^^^^^^ + +Returns objects where the embedded model array is **contained by** the list of +values on the right-hand side. In other words, every value in the embedded +array must be present in the given list. + +Example: + +.. code-block:: pycon + + Post.objects.create( + name="First post", tags=[Tag(label="django"), Tag(label="rest")] + ) + Post.objects.create( + name="Second post", tags=[Tag(label="django")] + ) + + Post.objects.filter(tags__label__contained_by=["django", "rest", "api"]) + , ]> + + Post.objects.filter(tags__label__contained_by=["django"]) + ]> + ``ObjectIdAutoField`` --------------------- diff --git a/docs/source/ref/models/models.rst b/docs/source/ref/models/models.rst index a4491228..32b5fc85 100644 --- a/docs/source/ref/models/models.rst +++ b/docs/source/ref/models/models.rst @@ -14,3 +14,6 @@ One MongoDB-specific model is available in ``django_mongodb_backend.models``. any of the normal ``QuerySet`` methods (``all()``, ``filter()``, ``delete()``, etc.) You also cannot call ``Model.save()`` and ``delete()`` on them. + + Embedded model instances won't have a value for their primary key unless + one is explicitly set. diff --git a/docs/source/releases/5.2.x.rst b/docs/source/releases/5.2.x.rst index db3bf67e..638e8e6a 100644 --- a/docs/source/releases/5.2.x.rst +++ b/docs/source/releases/5.2.x.rst @@ -7,6 +7,12 @@ Django MongoDB Backend 5.2.x *Unreleased* +New features +------------ + +- Added :class:`~.fields.EmbeddedModelArrayField` for storing a list of model + instances. + Bug fixes --------- diff --git a/docs/source/topics/embedded-models.rst b/docs/source/topics/embedded-models.rst index 94abecfd..2e314567 100644 --- a/docs/source/topics/embedded-models.rst +++ b/docs/source/topics/embedded-models.rst @@ -1,56 +1,117 @@ Embedded models =============== -Use :class:`~django_mongodb_backend.fields.EmbeddedModelField` to structure +Use :class:`~django_mongodb_backend.fields.EmbeddedModelField` and +:class:`~django_mongodb_backend.fields.EmbeddedModelArrayField` to structure your data using `embedded documents `_. +.. _embedded-model-field-example: + +``EmbeddedModelField`` +---------------------- + The basics ----------- +~~~~~~~~~~ Let's consider this example:: - from django_mongodb_backend.fields import EmbeddedModelField - from django_mongodb_backend.models import EmbeddedModel + from django.db import models + + from django_mongodb_backend.fields import EmbeddedModelField + from django_mongodb_backend.models import EmbeddedModel + - class Customer(models.Model): - name = models.CharField(...) - address = EmbeddedModelField("Address") - ... + class Customer(models.Model): + name = models.CharField(max_length=255) + address = EmbeddedModelField("Address") - class Address(EmbeddedModel): - ... - city = models.CharField(...) + def __str__(self): + return self.name + + + class Address(EmbeddedModel): + city = models.CharField(max_length=255) + + def __str__(self): + return self.city The API is similar to that of Django's relational fields:: - >>> Customer.objects.create(name="Bob", address=Address(city="New York", ...), ...) - >>> bob = Customer.objects.get(...) - >>> bob.address - - >>> bob.address.city - 'New York' + >>> bob = Customer.objects.create(name="Bob", address=Address(city="New York")) + >>> bob.address + + >>> bob.address.city + 'New York' -Represented in BSON, Bob's structure looks like this: +Represented in BSON, the customer structure looks like this: .. code-block:: js - { - "_id": ObjectId(...), - "name": "Bob", - "address": { - ... - "city": "New York" - }, - ... - } + { + _id: ObjectId('683df821ec4bbe0692d43388'), + name: 'Bob', + address: { city: 'New York' } + } Querying ``EmbeddedModelField`` -------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ You can query into an embedded model using the same double underscore syntax as relational fields. For example, to retrieve all customers who have an address with the city "New York":: >>> Customer.objects.filter(address__city="New York") + +.. _embedded-model-array-field-example: + +``EmbeddedModelArrayField`` +--------------------------- + +The basics +~~~~~~~~~~ + +Let's consider this example:: + + from django.db import models + + from django_mongodb_backend.fields import EmbeddedModelArrayField + from django_mongodb_backend.models import EmbeddedModel + + + class Post(models.Model): + name = models.CharField(max_length=200) + tags = EmbeddedModelArrayField("Tag") + + def __str__(self): + return self.name + + + class Tag(EmbeddedModel): + name = models.CharField(max_length=100) + + def __str__(self): + return self.name + + +The API is similar to that of Django's relational fields:: + + >>> post = Post.objects.create( + ... name="Hello world!", + ... tags=[Tag(name="welcome"), Tag(name="test")], + ... ) + >>> post.tags + [, ] + >>> post.tags[0].name + 'welcome' + +Represented in BSON, the post's structure looks like this: + +.. code-block:: js + + { + _id: ObjectId('683dee4c6b79670044c38e3f'), + name: 'Hello world!', + tags: [ { name: 'welcome' }, { name: 'test' } ] + } diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 9a4efc89..d6c384ee 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -2,7 +2,12 @@ from django.db import models -from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField +from django_mongodb_backend.fields import ( + ArrayField, + EmbeddedModelArrayField, + EmbeddedModelField, + ObjectIdField, +) from django_mongodb_backend.models import EmbeddedModel @@ -143,3 +148,57 @@ class Library(models.Model): def __str__(self): return self.name + + +# EmbeddedModelArrayField +class Review(EmbeddedModel): + title = models.CharField(max_length=255) + rating = models.DecimalField(max_digits=6, decimal_places=1) + + def __str__(self): + return self.title + + +class Movie(models.Model): + title = models.CharField(max_length=255) + reviews = EmbeddedModelArrayField(Review, null=True) + + def __str__(self): + return self.title + + +class RestorationRecord(EmbeddedModel): + date = models.DateField() + restored_by = models.CharField(max_length=255) + + +# Details about a specific artifact. +class ArtifactDetail(EmbeddedModel): + name = models.CharField(max_length=255) + metadata = models.JSONField() + restorations = EmbeddedModelArrayField(RestorationRecord, null=True) + last_restoration = EmbeddedModelField(RestorationRecord, null=True) + + +# A section within an exhibit, containing multiple artifacts. +class ExhibitSection(EmbeddedModel): + section_number = models.IntegerField() + artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True) + + +# An exhibit in the museum, composed of multiple sections. +class MuseumExhibit(models.Model): + exhibit_name = models.CharField(max_length=255) + sections = EmbeddedModelArrayField(ExhibitSection, null=True) + main_section = EmbeddedModelField(ExhibitSection, null=True) + + def __str__(self): + return self.exhibit_name + + +class Tour(models.Model): + guide = models.CharField(max_length=100) + exhibit = models.ForeignKey(MuseumExhibit, on_delete=models.CASCADE) + + def __str__(self): + return f"Tour by {self.guide}" diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py new file mode 100644 index 00000000..c7499b6a --- /dev/null +++ b/tests/model_fields_/test_embedded_model_array.py @@ -0,0 +1,350 @@ +from datetime import date + +from django.core.exceptions import FieldDoesNotExist +from django.db import connection, models +from django.test import SimpleTestCase, TestCase +from django.test.utils import CaptureQueriesContext, isolate_apps + +from django_mongodb_backend.fields import EmbeddedModelArrayField +from django_mongodb_backend.models import EmbeddedModel + +from .models import ( + ArtifactDetail, + ExhibitSection, + Movie, + MuseumExhibit, + RestorationRecord, + Review, + Tour, +) + + +class MethodTests(SimpleTestCase): + def test_deconstruct(self): + field = EmbeddedModelArrayField("Data", null=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django_mongodb_backend.fields.EmbeddedModelArrayField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"embedded_model": "Data", "null": True}) + + def test_size_not_supported(self): + msg = "EmbeddedModelArrayField does not support size." + with self.assertRaisesMessage(ValueError, msg): + EmbeddedModelArrayField("Data", size=1) + + def test_get_db_prep_save_invalid(self): + msg = "Expected list of instances, not ." + with self.assertRaisesMessage(TypeError, msg): + Movie(reviews=42).save() + + def test_get_db_prep_save_invalid_list(self): + msg = "Expected instance of type , not ." + with self.assertRaisesMessage(TypeError, msg): + Movie(reviews=[42]).save() + + +class ModelTests(TestCase): + def test_save_load(self): + reviews = [ + Review(title="The best", rating=10), + Review(title="Mediocre", rating=5), + Review(title="Horrible", rating=1), + ] + Movie.objects.create(title="Lion King", reviews=reviews) + movie = Movie.objects.get(title="Lion King") + self.assertEqual(movie.reviews[0].title, "The best") + self.assertEqual(movie.reviews[0].rating, 10) + self.assertEqual(movie.reviews[1].title, "Mediocre") + self.assertEqual(movie.reviews[1].rating, 5) + self.assertEqual(movie.reviews[2].title, "Horrible") + self.assertEqual(movie.reviews[2].rating, 1) + self.assertEqual(len(movie.reviews), 3) + + def test_save_load_null(self): + movie = Movie.objects.create(title="Lion King") + movie = Movie.objects.get(title="Lion King") + self.assertIsNone(movie.reviews) + + +class QueryingTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.egypt = MuseumExhibit.objects.create( + exhibit_name="Ancient Egypt", + sections=[ + ExhibitSection( + section_number=1, + artifacts=[ + ArtifactDetail( + name="Ptolemaic Crown", + metadata={ + "origin": "Egypt", + }, + ) + ], + ) + ], + ) + cls.wonders = MuseumExhibit.objects.create( + exhibit_name="Wonders of the Ancient World", + sections=[ + ExhibitSection( + section_number=1, + artifacts=[ + ArtifactDetail( + name="Statue of Zeus", + metadata={"location": "Olympia", "height_m": 12}, + ), + ArtifactDetail( + name="Hanging Gardens", + ), + ], + ), + ExhibitSection( + section_number=2, + artifacts=[ + ArtifactDetail( + name="Lighthouse of Alexandria", + metadata={"height_m": 100, "built": "3rd century BC"}, + ) + ], + ), + ], + ) + cls.new_descoveries = MuseumExhibit.objects.create( + exhibit_name="New Discoveries", + sections=[ + ExhibitSection( + section_number=2, + artifacts=[ + ArtifactDetail( + name="Lighthouse of Alexandria", + metadata={"height_m": 100, "built": "3rd century BC"}, + ) + ], + ) + ], + ) + cls.lost_empires = MuseumExhibit.objects.create( + exhibit_name="Lost Empires", + main_section=ExhibitSection( + section_number=3, + artifacts=[ + ArtifactDetail( + name="Bronze Statue", + metadata={"origin": "Pergamon"}, + restorations=[ + RestorationRecord( + date=date(1998, 4, 15), + restored_by="Zacarias", + ), + RestorationRecord( + date=date(2010, 7, 22), + restored_by="Vicente", + ), + ], + last_restoration=RestorationRecord( + date=date(2010, 7, 22), + restored_by="Monzon", + ), + ) + ], + ), + ) + cls.egypt_tour = Tour.objects.create(guide="Amira", exhibit=cls.egypt) + cls.wonders_tour = Tour.objects.create(guide="Carlos", exhibit=cls.wonders) + cls.lost_tour = Tour.objects.create(guide="Yelena", exhibit=cls.lost_empires) + + def test_filter_with_field(self): + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__section_number=1), [self.egypt, self.wonders] + ) + + def test_filter_with_embeddedfield_path(self): + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__0__section_number=1), + [self.egypt, self.wonders], + ) + + def test_filter_with_embeddedfield_array_path(self): + self.assertCountEqual( + MuseumExhibit.objects.filter( + main_section__artifacts__restorations__0__restored_by="Zacarias" + ), + [self.lost_empires], + ) + + def test_filter_unsupported_lookups(self): + # handle the unsupported lookups as key in a keytransform + for lookup in ["contains", "range"]: + kwargs = {f"main_section__artifacts__metadata__origin__{lookup}": ["Pergamon", "Egypt"]} + with CaptureQueriesContext(connection) as captured_queries: + self.assertCountEqual(MuseumExhibit.objects.filter(**kwargs), []) + self.assertIn(f"'field': '{lookup}'", captured_queries[0]["sql"]) + + def test_all_filter(self): + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__section_number__all=[1, 2]), [self.wonders] + ) + + def test_contained_by(self): + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__section_number__contained_by=[1, 2, 3]), + [self.egypt, self.new_descoveries, self.wonders, self.lost_empires], + ) + + def test_len_filter(self): + self.assertCountEqual(MuseumExhibit.objects.filter(sections__len=10), []) + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__len=1), + [self.egypt, self.new_descoveries], + ) + # Nested EMF + self.assertCountEqual( + MuseumExhibit.objects.filter(main_section__artifacts__len=1), [self.lost_empires] + ) + self.assertCountEqual(MuseumExhibit.objects.filter(main_section__artifacts__len=2), []) + # Nested Indexed Array + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__0__artifacts__len=2), [self.wonders] + ) + self.assertCountEqual(MuseumExhibit.objects.filter(sections__0__artifacts__len=0), []) + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__1__artifacts__len=1), [self.wonders] + ) + + def test_in_filter(self): + self.assertCountEqual(MuseumExhibit.objects.filter(sections__section_number__in=[10]), []) + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__section_number__in=[1]), + [self.egypt, self.wonders], + ) + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__section_number__in=[2]), + [self.new_descoveries, self.wonders], + ) + self.assertCountEqual(MuseumExhibit.objects.filter(sections__section_number__in=[3]), []) + + def test_iexact_filter(self): + self.assertCountEqual( + MuseumExhibit.objects.filter( + sections__artifacts__0__name__iexact="lightHOuse of aLexandriA" + ), + [self.new_descoveries, self.wonders], + ) + + def test_gt_filter(self): + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__section_number__gt=1), + [self.new_descoveries, self.wonders], + ) + + def test_gte_filter(self): + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__section_number__gte=1), + [self.egypt, self.new_descoveries, self.wonders], + ) + + def test_lt_filter(self): + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__section_number__lt=2), [self.egypt, self.wonders] + ) + + def test_lte_filter(self): + self.assertCountEqual( + MuseumExhibit.objects.filter(sections__section_number__lte=2), + [self.egypt, self.wonders, self.new_descoveries], + ) + + def test_query_array_not_allowed(self): + msg = ( + "Cannot apply this lookup directly to EmbeddedModelArrayField. " + "Try querying one of its embedded fields instead." + ) + with self.assertRaisesMessage(ValueError, msg): + MuseumExhibit.objects.filter(sections=10).first() + + with self.assertRaisesMessage(ValueError, msg): + MuseumExhibit.objects.filter(sections__0_1=10).first() + + def test_missing_field(self): + msg = "ExhibitSection has no field named 'section'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + MuseumExhibit.objects.filter(sections__section__in=[10]).first() + + def test_missing_lookup(self): + msg = "Unsupported lookup 'return' for EmbeddedModelArrayField of 'IntegerField'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + MuseumExhibit.objects.filter(sections__section_number__return=3) + + def test_missing_operation(self): + msg = "Unsupported lookup 'rage' for EmbeddedModelArrayField of 'IntegerField'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + MuseumExhibit.objects.filter(sections__section_number__rage=[10]) + + def test_missing_lookup_suggestions(self): + msg = ( + "Unsupported lookup 'ltee' for EmbeddedModelArrayField of 'IntegerField', " + "perhaps you meant lte or lt?" + ) + with self.assertRaisesMessage(FieldDoesNotExist, msg): + MuseumExhibit.objects.filter(sections__section_number__ltee=3) + + def test_double_emfarray_transform(self): + msg = "Cannot perform multiple levels of array traversal in a query." + with self.assertRaisesMessage(ValueError, msg): + MuseumExhibit.objects.filter(sections__artifacts__name="") + + def test_slice(self): + self.assertSequenceEqual( + MuseumExhibit.objects.filter(sections__0_1__section_number=2), [self.new_descoveries] + ) + + def test_foreign_field_exact(self): + qs = Tour.objects.filter(exhibit__sections__section_number=1) + self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + + def test_foreign_field_with_slice(self): + # Only wonders exhibit has exactly two sections, and this slice matches first two + qs = Tour.objects.filter(exhibit__sections__0_2__section_number__all=[1, 2]) + self.assertEqual(list(qs), [self.wonders_tour]) + + +@isolate_apps("model_fields_") +class CheckTests(SimpleTestCase): + def test_no_relational_fields(self): + class Target(EmbeddedModel): + key = models.ForeignKey("MyModel", models.CASCADE) + + class MyModel(models.Model): + field = EmbeddedModelArrayField(Target) + + errors = MyModel().check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "django_mongodb_backend.array.E001") + msg = errors[0].msg + self.assertEqual( + msg, + "Base field for array has errors:\n " + "Embedded models cannot have relational fields (Target.key is a ForeignKey). " + "(django_mongodb_backend.embedded_model.E001)", + ) + + def test_embedded_model_subclass(self): + class Target(models.Model): + pass + + class MyModel(models.Model): + field = EmbeddedModelArrayField(Target) + + errors = MyModel().check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "django_mongodb_backend.array.E001") + msg = errors[0].msg + self.assertEqual( + msg, + "Base field for array has errors:\n " + "Embedded models must be a subclass of " + "django_mongodb_backend.models.EmbeddedModel. " + "(django_mongodb_backend.embedded_model.E002)", + ) diff --git a/tests/model_forms_/forms.py b/tests/model_forms_/forms.py index 1ac7b92a..ff5eb413 100644 --- a/tests/model_forms_/forms.py +++ b/tests/model_forms_/forms.py @@ -1,6 +1,6 @@ from django import forms -from .models import Author, Book +from .models import Author, Book, Movie, Store class AuthorForm(forms.ModelForm): @@ -13,3 +13,15 @@ class BookForm(forms.ModelForm): class Meta: fields = "__all__" model = Book + + +class MovieForm(forms.ModelForm): + class Meta: + fields = "__all__" + model = Movie + + +class StoreForm(forms.ModelForm): + class Meta: + fields = "__all__" + model = Store diff --git a/tests/model_forms_/models.py b/tests/model_forms_/models.py index 4e7cd0d6..8a1dfe14 100644 --- a/tests/model_forms_/models.py +++ b/tests/model_forms_/models.py @@ -1,6 +1,6 @@ from django.db import models -from django_mongodb_backend.fields import EmbeddedModelField +from django_mongodb_backend.fields import EmbeddedModelArrayField, EmbeddedModelField from django_mongodb_backend.models import EmbeddedModel @@ -26,3 +26,31 @@ class Publisher(EmbeddedModel): class Book(models.Model): title = models.CharField(max_length=50) publisher = EmbeddedModelField(Publisher) + + +# EmbeddedModelArrayField +class Review(EmbeddedModel): + title = models.CharField(max_length=255) + rating = models.IntegerField() + + def __str__(self): + return self.title + + +class Movie(models.Model): + title = models.CharField(max_length=255) + reviews = EmbeddedModelArrayField(Review) + featured_reviews = EmbeddedModelArrayField(Review, null=True, blank=True, max_size=2) + + def __str__(self): + return self.title + + +class Product(EmbeddedModel): + name = models.CharField(max_length=255) + reviews = EmbeddedModelArrayField(Review) + + +class Store(models.Model): + name = models.CharField(max_length=255) + products = EmbeddedModelArrayField(Product) diff --git a/tests/model_forms_/test_embedded_model_array.py b/tests/model_forms_/test_embedded_model_array.py new file mode 100644 index 00000000..0654f4ae --- /dev/null +++ b/tests/model_forms_/test_embedded_model_array.py @@ -0,0 +1,1240 @@ +from django.test import TestCase + +from django_mongodb_backend.forms import EmbeddedModelArrayField + +from .forms import MovieForm, StoreForm +from .models import Movie, Product, Review, Store + + +class ModelFormTests(TestCase): + def test_add_another(self): + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "Great!", + "reviews-0-rating": "10", + "reviews-1-title": "Not so great", + "reviews-1-rating": "1", + "reviews-TOTAL_FORMS": 2, + "reviews-INITIAL_FORMS": 1, + } + form = MovieForm(data, instance=movie) + self.assertTrue(form.is_valid()) + form.save() + self.assertEqual(form.changed_data, ["reviews"]) + movie.refresh_from_db() + self.assertEqual(len(movie.reviews), 2) + review = movie.reviews[0] + self.assertEqual(review.title, "Great!") + self.assertEqual(review.rating, 10) + review = movie.reviews[1] + self.assertEqual(review.title, "Not so great") + self.assertEqual(review.rating, 1) + + def test_no_change(self): + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "Great!", + "reviews-0-rating": "10", + "reviews-TOTAL_FORMS": 2, + "reviews-INITIAL_FORMS": 1, + } + form = MovieForm(data, instance=movie) + self.assertTrue(form.is_valid()) + form.save() + self.assertEqual(form.changed_data, []) + movie.refresh_from_db() + self.assertEqual(len(movie.reviews), 1) + review = movie.reviews[0] + self.assertEqual(review.title, "Great!") + self.assertEqual(review.rating, 10) + + def test_update(self): + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "Not so great", + "reviews-0-rating": "1", + "reviews-TOTAL_FORMS": 2, + "reviews-INITIAL_FORMS": 1, + } + form = MovieForm(data, instance=movie) + self.assertTrue(form.is_valid()) + form.save() + self.assertEqual(form.changed_data, ["reviews"]) + movie.refresh_from_db() + self.assertEqual(len(movie.reviews), 1) + review = movie.reviews[0] + self.assertEqual(review.title, "Not so great") + self.assertEqual(review.rating, 1) + + def test_some_missing_data(self): + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "", + "reviews-0-rating": "1", + "reviews-TOTAL_FORMS": 2, + "reviews-INITIAL_FORMS": 1, + } + form = MovieForm(data, instance=movie) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["reviews"], ["This field is required."]) + + def test_invalid_field_data(self): + """A field's data (rating) is invalid.""" + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "Great!", + "reviews-0-rating": "not a number", + "reviews-TOTAL_FORMS": 2, + "reviews-INITIAL_FORMS": 1, + } + form = MovieForm(data, instance=movie) + self.assertFalse(form.is_valid()) + self.assertEqual( + form.errors["reviews"], + ["Enter a whole number."], + ) + + def test_all_missing_data(self): + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "", + "reviews-0-rating": "", + "reviews-TOTAL_FORMS": 2, + "reviews-INITIAL_FORMS": 1, + } + form = MovieForm(data, instance=movie) + self.assertFalse(form.is_valid()) + self.assertEqual( + form.errors["reviews"], ["This field is required.", "This field is required."] + ) + + def test_delete(self): + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10), Review(title="Okay", rating=5)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "Not so great", + "reviews-0-rating": "1", + "reviews-0-DELETE": "1", + "reviews-1-title": "Okay", + "reviews-1-rating": "5", + "reviews-1-DELETE": "", + "reviews-TOTAL_FORMS": 3, + "reviews-INITIAL_FORMS": 2, + } + form = MovieForm(data, instance=movie) + self.assertTrue(form.is_valid()) + form.save() + movie.refresh_from_db() + self.assertEqual(len(movie.reviews), 1) + review = movie.reviews[0] + self.assertEqual(review.title, "Okay") + self.assertEqual(review.rating, 5) + + def test_delete_required(self): + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "Not so great", + "reviews-0-rating": "1", + "reviews-0-DELETE": "1", + "reviews-TOTAL_FORMS": 2, + "reviews-INITIAL_FORMS": 1, + } + form = MovieForm(data, instance=movie) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["reviews"], ["This field cannot be blank."]) + + def test_max_size(self): + """ + Submitting more than the allowed number of items (three featured + reviews for max_size=2) is prohibited. + """ + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + featured_reviews=[Review(title="Okay", rating=5)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "Not so great", + "reviews-0-rating": "1", + "reviews-0-DELETE": "", + "reviews-TOTAL_FORMS": 2, + "reviews-INITIAL_FORMS": 1, + "featured_reviews-0-title": "Okay", + "featured_reviews-0-rating": "5", + "featured_reviews-1-title": "Okay", + "featured_reviews-1-rating": "5", + "featured_reviews-2-title": "Okay", + "featured_reviews-2-rating": "5", + "featured_reviews-TOTAL_FORMS": 3, + "featured_reviews-INITIAL_FORMS": 0, + } + form = MovieForm(data, instance=movie) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["featured_reviews"], ["Please submit at most 2 forms."]) + + def test_nullable_field(self): + """A nullable field is emptied if all rows are deleted.""" + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + featured_reviews=[Review(title="Okay", rating=5)], + ) + data = { + "title": "Lion King", + "reviews-0-title": "Not so great", + "reviews-0-rating": "1", + "reviews-0-DELETE": "", + "reviews-TOTAL_FORMS": 2, + "reviews-INITIAL_FORMS": 1, + "featured_reviews-0-title": "Okay", + "featured_reviews-0-rating": "5", + "featured_reviews-0-DELETE": "1", + "featured_reviews-TOTAL_FORMS": 2, + "featured_reviews-INITIAL_FORMS": 1, + } + form = MovieForm(data, instance=movie) + self.assertTrue(form.is_valid()) + form.save() + movie.refresh_from_db() + self.assertEqual(len(movie.featured_reviews), 0) + + def test_rendering(self): + form = MovieForm() + self.assertHTMLEqual( + str(form.fields["reviews"].get_bound_field(form, "reviews").label_tag()), + '', + ) + self.assertHTMLEqual( + str(form.fields["reviews"].get_bound_field(form, "reviews")), + """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +""", + ) + + def test_rendering_initial(self): + movie = Movie.objects.create( + title="Lion King", + reviews=[Review(title="Great!", rating=10)], + ) + form = MovieForm(instance=movie) + self.assertHTMLEqual( + str(form.fields["reviews"].get_bound_field(form, "reviews")), + """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +""", + ) + + def test_extra_forms(self): + """The extra_forms argument specifies the number of extra forms.""" + + class ExtraMovieForm(MovieForm): + reviews = EmbeddedModelArrayField(Review, prefix="reviews", extra_forms=2) + + form = ExtraMovieForm() + self.assertHTMLEqual( + str(form.fields["reviews"].get_bound_field(form, "reviews")), + """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+ +
+ +
+ +
+ +
+ +""", + ) + + +class NestedFormTests(TestCase): + def test_update(self): + store = Store.objects.create( + name="Best Buy", + products=[Product(name="TV", reviews=[Review(title="Great", rating=9)])], + ) + data = { + "name": "Best Buy!", + "products-0-name": "TV!", + "products-0-reviews-0-title": "Great!", + "products-0-reviews-0-rating": "9", + "products-TOTAL_FORMS": 3, + "products-INITIAL_FORMS": 1, + "products-0-reviews-TOTAL_FORMS": 3, + "products-0-reviews-INITIAL_FORMS": 1, + } + form = StoreForm(data, instance=store) + self.assertTrue(form.is_valid()) + form.save() + store.refresh_from_db() + self.assertEqual(store.name, "Best Buy!") + self.assertEqual(store.products[0].name, "TV!") + self.assertEqual(store.products[0].reviews[0].title, "Great!") + + def test_delete(self): + """The "Bad" review is deleted.""" + store = Store.objects.create( + name="Best Buy", + products=[ + Product( + name="TV", + reviews=[ + Review(title="Great", rating=9), + Review(title="Bad", rating=1), + ], + ) + ], + ) + data = { + "name": "Best Buy", + "products-0-name": "TV", + "products-0-reviews-0-title": "Great", + "products-0-reviews-0-rating": "9", + "products-0-reviews-1-title": "Bad", + "products-0-reviews-1-rating": "1", + "products-0-reviews-1-DELETE": "1", + "products-TOTAL_FORMS": 3, + "products-INITIAL_FORMS": 1, + "products-0-reviews-TOTAL_FORMS": 3, + "products-0-reviews-INITIAL_FORMS": 2, + } + form = StoreForm(data, instance=store) + self.assertTrue(form.is_valid()) + form.save() + store.refresh_from_db() + self.assertEqual(len(store.products[0].reviews), 1) + self.assertEqual(store.products[0].reviews[0].title, "Great") + + def test_some_missing_data(self): + """A required field (Review.title) is missing.""" + store = Store.objects.create( + name="Best Buy", + products=[Product(name="TV", reviews=[Review(title="Great", rating=9)])], + ) + data = { + "name": "Best Buy!", + "products-0-name": "TV!", + "products-0-reviews-0-title": "", + "products-0-reviews-0-rating": "9", + "products-TOTAL_FORMS": 3, + "products-INITIAL_FORMS": 1, + "products-0-reviews-TOTAL_FORMS": 3, + "products-0-reviews-INITIAL_FORMS": 1, + "products-1-reviews-TOTAL_FORMS": 3, + "products-1-reviews-INITIAL_FORMS": 0, + "products-2-reviews-TOTAL_FORMS": 3, + "products-2-reviews-INITIAL_FORMS": 0, + } + form = StoreForm(data, instance=store) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["products"], ["This field is required."]) + self.assertHTMLEqual( + str(form), + """ +
+ + +
+
+ +
    +
  • This field is required.
  • +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
    +
  • This field is required.
  • +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
    +
  • This field is required.
  • +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ + + + +
+ +
+ + + + +
""", + ) + + def test_invalid_field_data(self): + """A field's data (Review.title) is too long.""" + data = { + "name": "Best Buy!", + "products-0-name": "TV!", + "products-0-reviews-0-title": "A" * 256, + "products-0-reviews-0-rating": "9", + "products-TOTAL_FORMS": 3, + "products-INITIAL_FORMS": 1, + "products-0-reviews-TOTAL_FORMS": 3, + "products-0-reviews-INITIAL_FORMS": 1, + "products-1-reviews-TOTAL_FORMS": 3, + "products-1-reviews-INITIAL_FORMS": 0, + "products-2-reviews-TOTAL_FORMS": 3, + "products-2-reviews-INITIAL_FORMS": 0, + } + form = StoreForm(data) + self.assertFalse(form.is_valid()) + self.assertEqual( + form.errors["products"], + ["Ensure this value has at most 255 characters (it has 256)."], + ) + + def test_all_missing_data(self): + """ + An embedded model array field (reviews) with all data missing triggers + a required error. + """ + store = Store.objects.create( + name="Best Buy", + products=[Product(name="TV", reviews=[Review(title="Great", rating=9)])], + ) + data = { + "name": "Best Buy!", + "products-0-name": "TV!", + "products-0-reviews-0-title": "", + "products-0-reviews-0-rating": "", + "products-TOTAL_FORMS": 3, + "products-INITIAL_FORMS": 1, + "products-0-reviews-TOTAL_FORMS": 3, + "products-0-reviews-INITIAL_FORMS": 1, + "products-1-reviews-TOTAL_FORMS": 3, + "products-1-reviews-INITIAL_FORMS": 0, + "products-2-reviews-TOTAL_FORMS": 3, + "products-2-reviews-INITIAL_FORMS": 0, + } + form = StoreForm(data, instance=store) + self.assertFalse(form.is_valid()) + self.assertEqual( + form.errors["products"], ["This field is required.", "This field is required."] + ) + + def test_rendering(self): + form = StoreForm() + self.assertHTMLEqual( + str(form.fields["products"].get_bound_field(form, "products")), + """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ + + +
+ +
+ + +""", + )