Skip to content

Commit 5702778

Browse files
committed
Refactor
1 parent 78448d5 commit 5702778

File tree

3 files changed

+53
-50
lines changed

3 files changed

+53
-50
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -58,55 +58,56 @@ class EMFArrayExact(EMFExact):
5858
def as_mql(self, compiler, connection):
5959
lhs_mql = process_lhs(self, compiler, connection)
6060
value = process_rhs(self, compiler, connection)
61-
if isinstance(self.lhs, Col | KeyTransform):
62-
if isinstance(self.lhs, Col):
63-
inner_lhs_mql = "$$item"
64-
else:
65-
lhs_mql, inner_lhs_mql = lhs_mql
66-
if isinstance(value, models.Model):
67-
value, emf_data = self.model_to_dict(value)
68-
# Get conditions for any nested EmbeddedModelFields.
69-
conditions = self.get_conditions({inner_lhs_mql: (value, emf_data)})
70-
return {
71-
"$anyElementTrue": {
72-
"$ifNull": [
73-
{
74-
"$map": {
75-
"input": lhs_mql,
76-
"as": "item",
77-
"in": {"$and": conditions},
78-
}
79-
},
80-
[],
81-
]
82-
}
83-
}
61+
if isinstance(self.lhs, KeyTransform):
62+
lhs_mql, inner_lhs_mql = lhs_mql
63+
else:
64+
inner_lhs_mql = "$$item"
65+
if isinstance(value, models.Model):
66+
value, emf_data = self.model_to_dict(value)
67+
# Get conditions for any nested EmbeddedModelFields.
68+
conditions = self.get_conditions({inner_lhs_mql: (value, emf_data)})
8469
return {
8570
"$anyElementTrue": {
8671
"$ifNull": [
8772
{
8873
"$map": {
8974
"input": lhs_mql,
9075
"as": "item",
91-
"in": {"$eq": [inner_lhs_mql, value]},
76+
"in": {"$and": conditions},
9277
}
9378
},
9479
[],
9580
]
9681
}
9782
}
98-
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
83+
return {
84+
"$anyElementTrue": {
85+
"$ifNull": [
86+
{
87+
"$map": {
88+
"input": lhs_mql,
89+
"as": "item",
90+
"in": {"$eq": [inner_lhs_mql, value]},
91+
}
92+
},
93+
[],
94+
]
95+
}
96+
}
9997

10098

10199
class KeyTransform(Transform):
102100
# it should be different class than EMF keytransform even most of the methods are equal.
103101
def __init__(self, key_name, base_field, *args, **kwargs):
104102
super().__init__(*args, **kwargs)
105103
self.base_field = base_field
106-
# TODO: Need to create a column, will refactor this thing.
104+
self.key_name = key_name
105+
# The iteration items begins from the base_field, a virtual column with
106+
# base field output type is created.
107107
column_target = base_field.clone()
108-
column_target.db_column = f"$item.{key_name}"
109-
column_target.set_attributes_from_name(f"$item.{key_name}")
108+
column_name = f"$item.{key_name}"
109+
column_target.db_column = column_name
110+
column_target.set_attributes_from_name(column_name)
110111
self._lhs = Col(None, column_target)
111112
self._sub_transform = None
112113

@@ -117,19 +118,8 @@ def __call__(self, this, *args, **kwargs):
117118
def get_lookup(self, name):
118119
return self.output_field.get_lookup(name)
119120

120-
def get_transform(self, name):
121-
"""
122-
Validate that `name` is either a field of an embedded model or a
123-
lookup on an embedded model's field.
124-
"""
125-
if isinstance(self._lhs, Transform):
126-
transform = self._lhs.get_transform(name)
127-
else:
128-
transform = self.base_field.get_transform(name)
129-
if transform:
130-
self._sub_transform = transform
131-
return self
132-
suggested_lookups = difflib.get_close_matches(name, self.base_field.get_lookups())
121+
def _get_missing_field_or_lookup_exception(self, lhs, name):
122+
suggested_lookups = difflib.get_close_matches(name, lhs.get_lookups())
133123
if suggested_lookups:
134124
suggested_lookups = " or ".join(suggested_lookups)
135125
suggestion = f", perhaps you meant {suggested_lookups}?"
@@ -141,6 +131,25 @@ def get_transform(self, name):
141131
f"{suggestion}"
142132
)
143133

134+
def get_transform(self, name):
135+
"""
136+
Validate that `name` is either a field of an embedded model or a
137+
lookup on an embedded model's field.
138+
"""
139+
# Once the sub lhs is a transform, all the filter are applied over it.
140+
141+
transform = (
142+
self._lhs.get_transform(name)
143+
if isinstance(self._lhs, Transform)
144+
else self.base_field.get_transform(name)
145+
)
146+
if transform:
147+
self._sub_transform = transform
148+
return self
149+
raise self._get_missing_field_or_lookup_exception(
150+
self._lhs if isinstance(self._lhs, Transform) else self.base_field, name
151+
)
152+
144153
def as_mql(self, compiler, connection):
145154
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
146155
lhs_mql = process_lhs(self, compiler, connection)

tests/model_fields_/models.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,17 +205,12 @@ class ExhibitSection(EmbeddedModel):
205205
artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True)
206206

207207

208-
class ExhibitMeta(EmbeddedModel):
209-
curator_name = models.CharField(max_length=255)
210-
artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True)
211-
212-
213208
class MuseumExhibit(models.Model):
214209
"""An exhibit in the museum, composed of multiple sections."""
215210

216211
exhibit_name = models.CharField(max_length=255)
217212
sections = EmbeddedModelArrayField(ExhibitSection, null=True)
218-
meta = EmbeddedModelField(ExhibitMeta, null=True)
213+
main_section = EmbeddedModelField(ExhibitSection, null=True)
219214

220215
def __str__(self):
221216
return self.exhibit_name

tests/model_fields_/test_embedded_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
D,
2929
Data,
3030
E,
31-
ExhibitMeta,
3231
ExhibitSection,
3332
Holder,
3433
Library,
@@ -201,8 +200,8 @@ def setUpTestData(cls):
201200
)
202201
cls.lost_empires = MuseumExhibit.objects.create(
203202
exhibit_name="Lost Empires",
204-
meta=ExhibitMeta(
205-
curator_name="Dr. Amina Hale",
203+
main_section=ExhibitSection(
204+
section_number=3,
206205
artifacts=[
207206
ArtifactDetail(
208207
name="Bronze Statue",

0 commit comments

Comments
 (0)