Skip to content

Commit 30cecad

Browse files
committed
INTPYTHON-624 Add PolymorphicEmbeddedModelField
1 parent 4dba2ed commit 30cecad

File tree

8 files changed

+600
-0
lines changed

8 files changed

+600
-0
lines changed

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .embedded_model_array import EmbeddedModelArrayField
66
from .json import register_json_field
77
from .objectid import ObjectIdField
8+
from .polymorphic_embedded_model import PolymorphicEmbeddedModelField
89

910
__all__ = [
1011
"register_fields",
@@ -13,6 +14,7 @@
1314
"EmbeddedModelField",
1415
"ObjectIdAutoField",
1516
"ObjectIdField",
17+
"PolymorphicEmbeddedModelField",
1618
]
1719

1820

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
import contextlib
2+
import difflib
3+
4+
from django.core import checks
5+
from django.core.exceptions import FieldDoesNotExist, ValidationError
6+
from django.db import models
7+
from django.db.models.fields.related import lazy_related_operation
8+
from django.db.models.lookups import Transform
9+
10+
11+
class PolymorphicEmbeddedModelField(models.Field):
12+
"""
13+
Field that stores a model instance that may be of more than one model
14+
class.
15+
"""
16+
17+
value_is_model_instance = True
18+
19+
def __init__(self, embedded_models, *args, **kwargs):
20+
"""
21+
`embedded_models` is a list of possible model classes to be stored.
22+
Like other relational fields, each model may also be passed as a
23+
string.
24+
"""
25+
self.embedded_models = embedded_models
26+
kwargs["editable"] = False
27+
super().__init__(*args, **kwargs)
28+
29+
def db_type(self, connection):
30+
return "embeddedDocuments"
31+
32+
def check(self, **kwargs):
33+
from ..models import EmbeddedModel
34+
35+
errors = super().check(**kwargs)
36+
for model in self.embedded_models:
37+
if not issubclass(model, EmbeddedModel):
38+
return [
39+
checks.Error(
40+
"Embedded models must be a subclass of "
41+
"django_mongodb_backend.models.EmbeddedModel.",
42+
obj=self,
43+
hint="{model} doesn't subclass EmbeddedModel.",
44+
id="django_mongodb_backend.embedded_model.E002",
45+
)
46+
]
47+
for field in model._meta.fields:
48+
if field.remote_field:
49+
errors.append(
50+
checks.Error(
51+
"Embedded models cannot have relational fields "
52+
f"({model().__class__.__name__}.{field.name} "
53+
f"is a {field.__class__.__name__}).",
54+
obj=self,
55+
id="django_mongodb_backend.embedded_model.E001",
56+
)
57+
)
58+
return errors
59+
60+
def deconstruct(self):
61+
name, path, args, kwargs = super().deconstruct()
62+
if path.startswith("django_mongodb_backend.fields.polymorphic_embedded_model"):
63+
path = path.replace(
64+
"django_mongodb_backend.fields.polymorphic_embedded_model",
65+
"django_mongodb_backend.fields",
66+
)
67+
kwargs["embedded_models"] = self.embedded_models
68+
del kwargs["editable"]
69+
return name, path, args, kwargs
70+
71+
def get_internal_type(self):
72+
return "PolymorphicEmbeddedModelField"
73+
74+
def _set_model(self, model):
75+
"""
76+
Resolve embedded model classes once the field knows the model it
77+
belongs to. If any of the items in __init__()'s embedded_models
78+
argument are strings, resolve each to the actual model class,
79+
similar to relation fields.
80+
"""
81+
self._model = model
82+
if model is not None:
83+
for embedded_model in self.embedded_models:
84+
if isinstance(embedded_model, str):
85+
86+
def _resolve_lookup(_, *resolved_models):
87+
self.embedded_models = resolved_models
88+
89+
lazy_related_operation(_resolve_lookup, model, *self.embedded_models)
90+
91+
model = property(lambda self: self._model, _set_model)
92+
93+
def from_db_value(self, value, expression, connection):
94+
return self.to_python(value)
95+
96+
def to_python(self, value):
97+
"""
98+
Pass embedded model fields' values through each field's to_python() and
99+
reinstantiate the embedded instance.
100+
"""
101+
if value is None:
102+
return None
103+
if not isinstance(value, dict):
104+
return value
105+
model_class = self._get_model_from_label(value.pop("_label"))
106+
instance = model_class(
107+
**{
108+
field.attname: field.to_python(value[field.attname])
109+
for field in model_class._meta.fields
110+
if field.attname in value
111+
}
112+
)
113+
instance._state.adding = False
114+
return instance
115+
116+
def get_db_prep_save(self, embedded_instance, connection):
117+
"""
118+
Apply pre_save() and get_db_prep_save() of embedded instance fields and
119+
create the {field: value} dict to be saved.
120+
"""
121+
if embedded_instance is None:
122+
return None
123+
if not isinstance(embedded_instance, self.embedded_models):
124+
raise TypeError(
125+
f"Expected instance of type {self.embedded_models!r}, not "
126+
f"{type(embedded_instance)!r}."
127+
)
128+
field_values = {}
129+
add = embedded_instance._state.adding
130+
for field in embedded_instance._meta.fields:
131+
value = field.get_db_prep_save(
132+
field.pre_save(embedded_instance, add), connection=connection
133+
)
134+
# Exclude unset primary keys (e.g. {'id': None}).
135+
if field.primary_key and value is None:
136+
continue
137+
field_values[field.attname] = value
138+
field_values["_label"] = embedded_instance._meta.label
139+
# This instance will exist in the database soon.
140+
embedded_instance._state.adding = False
141+
return field_values
142+
143+
def get_transform(self, name):
144+
transform = super().get_transform(name)
145+
if transform:
146+
return transform
147+
field = None
148+
for model in self.embedded_models:
149+
with contextlib.suppress(FieldDoesNotExist):
150+
field = model._meta.get_field(name)
151+
if field is None:
152+
raise FieldDoesNotExist(
153+
f"The models of field '{self.name}' have no field named '{name}'."
154+
)
155+
return KeyTransformFactory(name, field)
156+
157+
def validate(self, value, model_instance):
158+
super().validate(value, model_instance)
159+
if not isinstance(value, self.embedded_models):
160+
raise ValidationError(
161+
f"Expected instance of type {self.embedded_models!r}, not {type(value)!r}."
162+
)
163+
for field in value._meta.fields:
164+
attname = field.attname
165+
field.validate(getattr(value, attname), model_instance)
166+
167+
def formfield(self, **kwargs):
168+
raise NotImplementedError("PolymorphicEmbeddedModelField does not support forms.")
169+
170+
def _get_model_from_label(self, label):
171+
return {model._meta.label: model for model in self.embedded_models}[label]
172+
173+
174+
class KeyTransform(Transform):
175+
def __init__(self, key_name, ref_field, *args, **kwargs):
176+
super().__init__(*args, **kwargs)
177+
self.key_name = str(key_name)
178+
self.ref_field = ref_field
179+
180+
def get_lookup(self, name):
181+
return self.ref_field.get_lookup(name)
182+
183+
def get_transform(self, name):
184+
"""
185+
Validate that `name` is either a field of an embedded model or a
186+
lookup on an embedded model's field.
187+
"""
188+
if transform := self.ref_field.get_transform(name):
189+
return transform
190+
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
191+
if suggested_lookups:
192+
suggested_lookups = " or ".join(suggested_lookups)
193+
suggestion = f", perhaps you meant {suggested_lookups}?"
194+
else:
195+
suggestion = "."
196+
raise FieldDoesNotExist(
197+
f"Unsupported lookup '{name}' for "
198+
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
199+
f"{suggestion}"
200+
)
201+
202+
def as_mql(self, compiler, connection):
203+
previous = self
204+
key_transforms = []
205+
while isinstance(previous, KeyTransform):
206+
key_transforms.insert(0, previous.key_name)
207+
previous = previous.lhs
208+
mql = previous.as_mql(compiler, connection)
209+
for key in key_transforms:
210+
mql = {"$getField": {"input": mql, "field": key}}
211+
return mql
212+
213+
@property
214+
def output_field(self):
215+
return self.ref_field
216+
217+
218+
class KeyTransformFactory:
219+
def __init__(self, key_name, ref_field):
220+
self.key_name = key_name
221+
self.ref_field = ref_field
222+
223+
def __call__(self, *args, **kwargs):
224+
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)

django_mongodb_backend/operations.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def get_db_converters(self, expression):
122122
)
123123
elif internal_type == "JSONField":
124124
converters.append(self.convert_jsonfield_value)
125+
elif internal_type == "PolymorphicEmbeddedModelField":
126+
converters.append(self.convert_polymorphicembeddedmodelfield_value)
125127
elif internal_type == "TimeField":
126128
# Trunc(... output_field="TimeField") values must remain datetime
127129
# until Trunc.convert_value() so they can be converted from UTC
@@ -182,6 +184,19 @@ def convert_jsonfield_value(self, value, expression, connection):
182184
"""
183185
return json.dumps(value)
184186

187+
def convert_polymorphicembeddedmodelfield_value(self, value, expression, connection):
188+
if value is not None:
189+
model_class = expression.output_field._get_model_from_label(value["_label"])
190+
# Apply database converters to each field of the embedded model.
191+
for field in model_class._meta.fields:
192+
field_expr = Expression(output_field=field)
193+
converters = connection.ops.get_db_converters(
194+
field_expr
195+
) + field_expr.get_db_converters(connection)
196+
for converter in converters:
197+
value[field.attname] = converter(value[field.attname], field_expr, connection)
198+
return value
199+
185200
def convert_timefield_value(self, value, expression, connection):
186201
if value is not None:
187202
value = value.time()

docs/source/ref/models/fields.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,32 @@ These indexes use 0-based indexing.
313313
.. class:: ObjectIdField
314314

315315
Stores an :class:`~bson.objectid.ObjectId`.
316+
317+
``PolymorphicEmbeddedModelField``
318+
---------------------------------
319+
320+
.. class:: PolymorphicEmbeddedModelField(embedded_models, **kwargs)
321+
322+
.. versionadded:: 5.2.0b2
323+
324+
Stores a model of type ``embedded_models``.
325+
326+
.. attribute:: embedded_models
327+
328+
This is a required argument that specifies a list of model classes
329+
that may be embedded.
330+
331+
Each model class reference works just like
332+
:attr:`.EmbeddedModelField.embedded_model`.
333+
334+
See :ref:`the embedded model topic guide
335+
<polymorphic-embedded-model-field-example>` for more details and examples.
336+
337+
.. admonition:: Migrations support is limited
338+
339+
:djadmin:`makemigrations` does not yet detect changes to embedded models,
340+
nor does it create indexes or constraints for embedded models.
341+
342+
.. admonition:: Forms are not supported
343+
344+
``PolymorphicEmbeddedModelField``\s don't appear in model forms.

docs/source/releases/5.2.x.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ New features
1414
- Added the ``options`` parameter to
1515
:func:`~django_mongodb_backend.utils.parse_uri`.
1616
- Added support for :ref:`database transactions <transactions>`.
17+
- Added :class:`~.fields.PolymorphicEmbeddedModelArrayField` for storing a
18+
model instance that may be of more than one model class.
1719

1820
5.2.0 beta 1
1921
============

docs/source/topics/embedded-models.rst

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,84 @@ For example, if the ``Tag`` model had an ``EmbeddedModelArrayField`` called
181181
>>> Post.objects.filter(tags__colors__name="blue")
182182
...
183183
ValueError: Cannot perform multiple levels of array traversal in a query.
184+
185+
.. _polymorphic-embedded-model-field-example:
186+
187+
``PolymorphicEmbeddedModelField``
188+
---------------------------------
189+
190+
The basics
191+
~~~~~~~~~~
192+
193+
Let's consider this example::
194+
195+
from django.db import models
196+
197+
from django_mongodb_backend.fields import PolymorphicEmbeddedModelField
198+
from django_mongodb_backend.models import EmbeddedModel
199+
200+
201+
class Person(models.Model):
202+
name = models.CharField(max_length=255)
203+
pet = PolymorphicEmbeddedModelField(["Cat", "Dog"])
204+
205+
def __str__(self):
206+
return self.name
207+
208+
209+
class Cat(EmbeddedModel):
210+
name = models.CharField(max_length=255)
211+
purrs = models.BooleanField(default=True)
212+
213+
def __str__(self):
214+
return self.name
215+
216+
217+
class Dog(EmbeddedModel):
218+
name = models.CharField(max_length=255)
219+
barks = models.BooleanField(default=True)
220+
221+
def __str__(self):
222+
return self.name
223+
224+
225+
The API is similar to that of Django's relational fields::
226+
227+
>>> bob = Person.objects.create(name="Bob", pet=Dog(name="Woofer"))
228+
>>> bob.pet
229+
<Dog: Woofer>
230+
>>> bob.pet.name
231+
'Woofer'
232+
>>> bob = Person.objects.create(name="Fred", pet=Cat(name="Pheobe"))
233+
234+
Represented in BSON, the person structures looks like this:
235+
236+
.. code-block:: js
237+
238+
{
239+
_id: ObjectId('685da4895e42adade0c8db29'),
240+
name: 'Bob',
241+
pet: { name: 'Woofer', barks: true, _label: 'myapp.Dog' }
242+
},
243+
{
244+
_id: ObjectId('685da4925e42adade0c8db2a'),
245+
name: 'Fred',
246+
pet: { name: 'Pheobe', purrs: true, _label: 'myapp.Cat' }
247+
}
248+
249+
The ``_label`` field contains the model's
250+
:attr:`~django.db.models.Options.label`.
251+
252+
Querying ``PolymorphicEmbeddedModelField``
253+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
254+
255+
You can query into a polymorphic embedded model field using the same double
256+
underscore syntax as relational fields. For example, to retrieve all people
257+
who have a pet named "Lassy"::
258+
259+
>>> Person.objects.filter(pet__name="Lassy")
260+
261+
You can also filter on fields that aren't shared among the embedded models. For
262+
example, if you filter on ``barks``, you'll only get back people with dogs::
263+
264+
>>> Person.objects.filter(pet__barks=True)

0 commit comments

Comments
 (0)