Skip to content

Commit 5257688

Browse files
committed
INTPYTHON-624 Add PolymorphicEmbeddedModelField
1 parent 84c8299 commit 5257688

File tree

7 files changed

+507
-1
lines changed

7 files changed

+507
-1
lines changed

django_mongodb_backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ def execute_sql(self, result_type):
746746
elif hasattr(value, "prepare_database_save"):
747747
if field.remote_field:
748748
value = value.prepare_database_save(field)
749-
elif not hasattr(field, "embedded_model"):
749+
elif not (hasattr(field, "embedded_model") or hasattr(field, "embedded_models")):
750750
raise TypeError(
751751
f"Tried to update field {field} with a model "
752752
f"instance, {value!r}. Use a value compatible with "

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .embedded_model_array import EmbeddedModelArrayField
66
from .json import register_json_field
77
from .objectid import ObjectIdField
8+
from .polymorphic_embedded_model import PolymorphicEmbeddedModelField
89

910
__all__ = [
1011
"register_fields",
@@ -13,6 +14,7 @@
1314
"EmbeddedModelField",
1415
"ObjectIdAutoField",
1516
"ObjectIdField",
17+
"PolymorphicEmbeddedModelField",
1618
]
1719

1820

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
import contextlib
2+
import difflib
3+
4+
from django.core import checks
5+
from django.core.exceptions import FieldDoesNotExist, ValidationError
6+
from django.db import models
7+
from django.db.models.fields.related import lazy_related_operation
8+
from django.db.models.lookups import Transform
9+
10+
11+
class PolymorphicEmbeddedModelField(models.Field):
12+
"""Field that stores a model instance."""
13+
14+
def __init__(self, embedded_models, *args, **kwargs):
15+
"""
16+
`embedded_models` is a list of possible model classes to be stored.
17+
Like other relational fields, each model may also be passed as a
18+
string.
19+
"""
20+
self.embedded_models = embedded_models
21+
super().__init__(*args, **kwargs)
22+
23+
def db_type(self, connection):
24+
return "embeddedDocuments"
25+
26+
def check(self, **kwargs):
27+
from ..models import EmbeddedModel
28+
29+
errors = super().check(**kwargs)
30+
for model in self.embedded_models:
31+
if not issubclass(model, EmbeddedModel):
32+
return [
33+
checks.Error(
34+
"Embedded models must be a subclass of "
35+
"django_mongodb_backend.models.EmbeddedModel.",
36+
obj=self,
37+
hint="{model} doesn't subclass EmbeddedModel.",
38+
id="django_mongodb_backend.embedded_model.E002",
39+
)
40+
]
41+
for field in model._meta.fields:
42+
if field.remote_field:
43+
errors.append(
44+
checks.Error(
45+
"Embedded models cannot have relational fields "
46+
f"({model().__class__.__name__}.{field.name} "
47+
f"is a {field.__class__.__name__}).",
48+
obj=self,
49+
id="django_mongodb_backend.embedded_model.E001",
50+
)
51+
)
52+
return errors
53+
54+
def deconstruct(self):
55+
name, path, args, kwargs = super().deconstruct()
56+
if path.startswith("django_mongodb_backend.fields.polymorphic_embedded_model"):
57+
path = path.replace(
58+
"django_mongodb_backend.fields.polymorphic_embedded_model",
59+
"django_mongodb_backend.fields",
60+
)
61+
kwargs["embedded_models"] = self.embedded_models
62+
return name, path, args, kwargs
63+
64+
def get_internal_type(self):
65+
return "PolymorphicEmbeddedModelField"
66+
67+
def _set_model(self, model):
68+
"""
69+
Resolve embedded model classes once the field knows the model it
70+
belongs to. If any of the items in __init__()'s embedded_models
71+
argument are strings, resolve each to the actual model class,
72+
similar to relation fields.
73+
"""
74+
self._model = model
75+
if model is not None:
76+
for embedded_model in self.embedded_models:
77+
if isinstance(embedded_model, str):
78+
79+
def _resolve_lookup(_, *resolved_models):
80+
self.embedded_models = resolved_models
81+
82+
lazy_related_operation(_resolve_lookup, model, *self.embedded_models)
83+
84+
model = property(lambda self: self._model, _set_model)
85+
86+
def from_db_value(self, value, expression, connection):
87+
return self.to_python(value)
88+
89+
def to_python(self, value):
90+
"""
91+
Pass embedded model fields' values through each field's to_python() and
92+
reinstantiate the embedded instance.
93+
"""
94+
if value is None:
95+
return None
96+
if not isinstance(value, dict):
97+
return value
98+
model_class = self._get_model_from_label(value.pop("_label"))
99+
instance = model_class(
100+
**{
101+
field.attname: field.to_python(value[field.attname])
102+
for field in model_class._meta.fields
103+
if field.attname in value
104+
}
105+
)
106+
instance._state.adding = False
107+
return instance
108+
109+
def get_db_prep_save(self, embedded_instance, connection):
110+
"""
111+
Apply pre_save() and get_db_prep_save() of embedded instance fields and
112+
create the {field: value} dict to be saved.
113+
"""
114+
if embedded_instance is None:
115+
return None
116+
if not isinstance(embedded_instance, self.embedded_models):
117+
raise TypeError(
118+
f"Expected instance of type {self.embedded_models!r}, not "
119+
f"{type(embedded_instance)!r}."
120+
)
121+
field_values = {}
122+
add = embedded_instance._state.adding
123+
for field in embedded_instance._meta.fields:
124+
value = field.get_db_prep_save(
125+
field.pre_save(embedded_instance, add), connection=connection
126+
)
127+
# Exclude unset primary keys (e.g. {'id': None}).
128+
if field.primary_key and value is None:
129+
continue
130+
field_values[field.attname] = value
131+
field_values["_label"] = embedded_instance._meta.label
132+
# This instance will exist in the database soon.
133+
embedded_instance._state.adding = False
134+
return field_values
135+
136+
def get_transform(self, name):
137+
transform = super().get_transform(name)
138+
if transform:
139+
return transform
140+
field = None
141+
for model in self.embedded_models:
142+
with contextlib.suppress(FieldDoesNotExist):
143+
field = model._meta.get_field(name)
144+
if field is None:
145+
raise FieldDoesNotExist(
146+
f"The models of field '{self.name}' have no field named '{name}'."
147+
)
148+
return KeyTransformFactory(name, field)
149+
150+
def validate(self, value, model_instance):
151+
super().validate(value, model_instance)
152+
if not isinstance(value, self.embedded_models):
153+
raise ValidationError(
154+
f"Expected instance of type {self.embedded_models!r}, not {type(value)!r}."
155+
)
156+
for field in value._meta.fields:
157+
attname = field.attname
158+
field.validate(getattr(value, attname), model_instance)
159+
160+
def formfield(self, **kwargs):
161+
raise NotImplementedError("PolymorphicEmbeddedModelField does not support forms.")
162+
163+
def _get_model_from_label(self, label):
164+
return {model._meta.label: model for model in self.embedded_models}[label]
165+
166+
167+
class KeyTransform(Transform):
168+
def __init__(self, key_name, ref_field, *args, **kwargs):
169+
super().__init__(*args, **kwargs)
170+
self.key_name = str(key_name)
171+
self.ref_field = ref_field
172+
173+
def get_lookup(self, name):
174+
return self.ref_field.get_lookup(name)
175+
176+
def get_transform(self, name):
177+
"""
178+
Validate that `name` is either a field of an embedded model or a
179+
lookup on an embedded model's field.
180+
"""
181+
if transform := self.ref_field.get_transform(name):
182+
return transform
183+
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
184+
if suggested_lookups:
185+
suggested_lookups = " or ".join(suggested_lookups)
186+
suggestion = f", perhaps you meant {suggested_lookups}?"
187+
else:
188+
suggestion = "."
189+
raise FieldDoesNotExist(
190+
f"Unsupported lookup '{name}' for "
191+
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
192+
f"{suggestion}"
193+
)
194+
195+
def as_mql(self, compiler, connection):
196+
previous = self
197+
key_transforms = []
198+
while isinstance(previous, KeyTransform):
199+
key_transforms.insert(0, previous.key_name)
200+
previous = previous.lhs
201+
mql = previous.as_mql(compiler, connection)
202+
for key in key_transforms:
203+
mql = {"$getField": {"input": mql, "field": key}}
204+
return mql
205+
206+
@property
207+
def output_field(self):
208+
return self.ref_field
209+
210+
211+
class KeyTransformFactory:
212+
def __init__(self, key_name, ref_field):
213+
self.key_name = key_name
214+
self.ref_field = ref_field
215+
216+
def __call__(self, *args, **kwargs):
217+
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)

django_mongodb_backend/operations.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def get_db_converters(self, expression):
122122
)
123123
elif internal_type == "JSONField":
124124
converters.append(self.convert_jsonfield_value)
125+
elif internal_type == "PolymorphicEmbeddedModelField":
126+
converters.append(self.convert_polymorphicembeddedmodelfield_value)
125127
elif internal_type == "TimeField":
126128
# Trunc(... output_field="TimeField") values must remain datetime
127129
# until Trunc.convert_value() so they can be converted from UTC
@@ -182,6 +184,19 @@ def convert_jsonfield_value(self, value, expression, connection):
182184
"""
183185
return json.dumps(value)
184186

187+
def convert_polymorphicembeddedmodelfield_value(self, value, expression, connection):
188+
if value is not None:
189+
model_class = expression.output_field._get_model_from_label(value["_label"])
190+
# Apply database converters to each field of the embedded model.
191+
for field in model_class._meta.fields:
192+
field_expr = Expression(output_field=field)
193+
converters = connection.ops.get_db_converters(
194+
field_expr
195+
) + field_expr.get_db_converters(connection)
196+
for converter in converters:
197+
value[field.attname] = converter(value[field.attname], field_expr, connection)
198+
return value
199+
185200
def convert_timefield_value(self, value, expression, connection):
186201
if value is not None:
187202
value = value.time()

docs/source/ref/models/fields.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,19 @@ These indexes use 0-based indexing.
313313
.. class:: ObjectIdField
314314

315315
Stores an :class:`~bson.objectid.ObjectId`.
316+
317+
``PolymorphicEmbeddedModelField``
318+
---------------------------------
319+
320+
.. class:: PolymorphicEmbeddedModelField(embedded_models, **kwargs)
321+
322+
Stores a model of type ``embedded_models``.
323+
324+
.. attribute:: embedded_models
325+
326+
This is a required argument.
327+
328+
Specifies a list of model classes that may be embedded.
329+
330+
Each model class reference works just like
331+
:attr:`.EmbeddedModelField.embedded_model`.

tests/model_fields_/models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
EmbeddedModelArrayField,
88
EmbeddedModelField,
99
ObjectIdField,
10+
PolymorphicEmbeddedModelField,
1011
)
1112
from django_mongodb_backend.models import EmbeddedModel
1213

@@ -222,3 +223,30 @@ class Tour(models.Model):
222223

223224
def __str__(self):
224225
return f"Tour by {self.guide}"
226+
227+
228+
# PolymorphicEmbeddedModelField
229+
class Person(models.Model):
230+
name = models.CharField(max_length=100)
231+
pet = PolymorphicEmbeddedModelField(("Dog", "Cat"), blank=True, null=True)
232+
233+
def __str__(self):
234+
return self.name
235+
236+
237+
class Dog(EmbeddedModel):
238+
name = models.CharField(max_length=100)
239+
barks = models.BooleanField(default=True)
240+
data = models.JSONField(default=dict)
241+
242+
def __str__(self):
243+
return self.name
244+
245+
246+
class Cat(EmbeddedModel):
247+
name = models.CharField(max_length=100)
248+
purs = models.BooleanField(default=True)
249+
weight = models.DecimalField(max_digits=4, decimal_places=2, blank=True, null=True)
250+
251+
def __str__(self):
252+
return self.name

0 commit comments

Comments
 (0)