Skip to content

Commit dd451be

Browse files
committed
Simplify and rename EmbeddedModelArrayField's transform classes
1 parent 50a5b65 commit dd451be

File tree

2 files changed

+30
-32
lines changed

2 files changed

+30
-32
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def get_transform(self, name):
6767
transform = super().get_transform(name)
6868
if transform:
6969
return transform
70-
return KeyTransformFactory(name, self)
70+
field = self.base_field.embedded_model._meta.get_field(name)
71+
return EmbeddedModelArrayFieldTransformFactory(field)
7172

7273
def _get_lookup(self, lookup_name):
7374
lookup = super()._get_lookup(lookup_name)
@@ -223,17 +224,15 @@ class EmbeddedModelArrayFieldLessThanOrEqual(
223224
pass
224225

225226

226-
class KeyTransform(Transform):
227+
class EmbeddedModelArrayFieldTransform(Transform):
227228
field_class_name = "EmbeddedModelArrayField"
228229

229-
def __init__(self, key_name, array_field, *args, **kwargs):
230+
def __init__(self, field, *args, **kwargs):
230231
super().__init__(*args, **kwargs)
231-
self.array_field = array_field
232-
self.key_name = key_name
233232
# Lookups iterate over the array of embedded models. A virtual column
234233
# of the queried field's type represents each element.
235-
column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
236-
column_name = f"$item.{key_name}"
234+
column_target = field.clone()
235+
column_name = f"$item.{field.column}"
237236
column_target.db_column = column_name
238237
column_target.set_attributes_from_name(column_name)
239238
self._lhs = Col(None, column_target)
@@ -254,7 +253,7 @@ def get_transform(self, name):
254253
# Once the sub-lhs is a transform, all the filters are applied over it.
255254
# Otherwise get the transform from the nested embedded model field.
256255
if transform := self._lhs.get_transform(name):
257-
if isinstance(transform, KeyTransformFactory):
256+
if isinstance(transform, EmbeddedModelArrayFieldTransformFactory):
258257
raise ValueError("Cannot perform multiple levels of array traversal in a query.")
259258
self._sub_transform = transform
260259
return self
@@ -296,10 +295,9 @@ def output_field(self):
296295
return _EmbeddedModelArrayOutputField(self._lhs.output_field)
297296

298297

299-
class KeyTransformFactory:
300-
def __init__(self, key_name, base_field):
301-
self.key_name = key_name
302-
self.base_field = base_field
298+
class EmbeddedModelArrayFieldTransformFactory:
299+
def __init__(self, field):
300+
self.field = field
303301

304302
def __call__(self, *args, **kwargs):
305-
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)
303+
return EmbeddedModelArrayFieldTransform(self.field, *args, **kwargs)

django_mongodb_backend/fields/polymorphic_embedded_model_array.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
from . import PolymorphicEmbeddedModelField
99
from .array import ArrayField, ArrayLenTransform
10-
from .embedded_model_array import KeyTransform as ArrayFieldKeyTransform
11-
from .embedded_model_array import KeyTransformFactory as ArrayFieldKeyTransformFactory
10+
from .embedded_model_array import (
11+
EmbeddedModelArrayFieldTransform,
12+
EmbeddedModelArrayFieldTransformFactory,
13+
)
1214

1315

1416
class PolymorphicEmbeddedModelArrayField(ArrayField):
@@ -62,7 +64,15 @@ def get_transform(self, name):
6264
transform = super().get_transform(name)
6365
if transform:
6466
return transform
65-
return KeyTransformFactory(name, self)
67+
for model in self.base_field.embedded_models:
68+
with contextlib.suppress(FieldDoesNotExist):
69+
field = model._meta.get_field(name)
70+
break
71+
else:
72+
raise FieldDoesNotExist(
73+
f"The models of field '{self.name}' have no field named '{name}'."
74+
)
75+
return PolymorphicArrayFieldTransformFactory(field)
6676

6777
def _get_lookup(self, lookup_name):
6878
lookup = super()._get_lookup(lookup_name)
@@ -79,32 +89,22 @@ def as_mql(self, compiler, connection):
7989
return EmbeddedModelArrayFieldLookups
8090

8191

82-
class KeyTransform(ArrayFieldKeyTransform):
92+
class PolymorphicArrayFieldTransform(EmbeddedModelArrayFieldTransform):
8393
field_class_name = "PolymorphicEmbeddedModelArrayField"
8494

85-
def __init__(self, key_name, array_field, *args, **kwargs):
86-
# Skip ArrayFieldKeyTransform.__init__()
95+
def __init__(self, field, *args, **kwargs):
96+
# Skip EmbeddedModelArrayFieldTransform.__init__()
8797
Transform.__init__(self, *args, **kwargs)
88-
self.array_field = array_field
89-
self.key_name = key_name
90-
for model in array_field.base_field.embedded_models:
91-
with contextlib.suppress(FieldDoesNotExist):
92-
field = model._meta.get_field(key_name)
93-
break
94-
else:
95-
raise FieldDoesNotExist(
96-
f"The models of field '{array_field.name}' have no field named '{key_name}'."
97-
)
9898
# Lookups iterate over the array of embedded models. A virtual column
9999
# of the queried field's type represents each element.
100100
column_target = field.clone()
101-
column_name = f"$item.{key_name}"
101+
column_name = f"$item.{field.column}"
102102
column_target.db_column = column_name
103103
column_target.set_attributes_from_name(column_name)
104104
self._lhs = Col(None, column_target)
105105
self._sub_transform = None
106106

107107

108-
class KeyTransformFactory(ArrayFieldKeyTransformFactory):
108+
class PolymorphicArrayFieldTransformFactory(EmbeddedModelArrayFieldTransformFactory):
109109
def __call__(self, *args, **kwargs):
110-
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)
110+
return PolymorphicArrayFieldTransform(self.field, *args, **kwargs)

0 commit comments

Comments
 (0)