Skip to content

Commit 9c5eaf3

Browse files
committed
POC: Manage sub array queries with a different transform path.
1 parent bf4cec1 commit 9c5eaf3

File tree

1 file changed

+51
-20
lines changed

1 file changed

+51
-20
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,31 +62,53 @@ def as_mql(self, compiler, connection):
6262
if isinstance(value, models.Model):
6363
value, emf_data = self.model_to_dict(value)
6464
# Get conditions for any nested EmbeddedModelFields.
65-
conditions = self.get_conditions({"$$item": (value, emf_data)})
65+
conditions = self.get_conditions({lhs_mql[1]: (value, emf_data)})
6666
return {
6767
"$anyElementTrue": {
68-
"$map": {"input": lhs_mql, "as": "item", "in": {"$and": conditions}}
68+
"$ifNull": [
69+
{
70+
"$map": {
71+
"input": lhs_mql[0],
72+
"as": "item",
73+
"in": {"$and": conditions},
74+
}
75+
},
76+
[],
77+
]
6978
}
7079
}
71-
lhs_mql = process_lhs(self.lhs, compiler, connection)
7280
return {
7381
"$anyElementTrue": {
74-
"$map": {
75-
"input": lhs_mql,
76-
"as": "item",
77-
"in": {"$eq": [f"$$item.{self.lhs.key_name}", value]},
78-
}
82+
"$ifNull": [
83+
{
84+
"$map": {
85+
"input": lhs_mql[0],
86+
"as": "item",
87+
"in": {"$eq": [lhs_mql[1], value]},
88+
}
89+
},
90+
[],
91+
]
7992
}
8093
}
8194
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
8295

8396

8497
class KeyTransform(Transform):
8598
# it should be different class than EMF keytransform even most of the methods are equal.
86-
def __init__(self, key_name, ref_field, *args, **kwargs):
99+
def __init__(self, key_name, base_field, *args, **kwargs):
87100
super().__init__(*args, **kwargs)
88-
self.key_name = str(key_name)
89-
self.ref_field = ref_field
101+
self.base_field = base_field
102+
# TODO: Need to create a column, will refactor this thing.
103+
column_target = base_field.clone()
104+
column_target.db_column = f"$item.{key_name}"
105+
column_target.set_attributes_from_name(f"$item.{key_name}")
106+
self._lhs = Col(None, column_target)
107+
self._sub_transform = None
108+
109+
def __call__(self, this, *args, **kwargs):
110+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
111+
return self
90112

91113
def get_lookup(self, name):
92114
return self.output_field.get_lookup(name)
@@ -96,33 +118,42 @@ def get_transform(self, name):
96118
Validate that `name` is either a field of an embedded model or a
97119
lookup on an embedded model's field.
98120
"""
99-
if transform := self.ref_field.get_transform(name):
100-
return transform
101-
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
121+
if isinstance(self._lhs, Transform):
122+
transform = self._lhs.get_transform(name)
123+
else:
124+
transform = self.base_field.get_transform(name)
125+
if transform:
126+
self._sub_transform = transform
127+
return self
128+
suggested_lookups = difflib.get_close_matches(name, self.base_field.get_lookups())
102129
if suggested_lookups:
103130
suggested_lookups = " or ".join(suggested_lookups)
104131
suggestion = f", perhaps you meant {suggested_lookups}?"
105132
else:
106133
suggestion = "."
107134
raise FieldDoesNotExist(
108135
f"Unsupported lookup '{name}' for "
109-
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
136+
f"{self.base_field.__class__.__name__} '{self.base_field.name}'"
110137
f"{suggestion}"
111138
)
112139

113140
def as_mql(self, compiler, connection):
141+
if isinstance(self._lhs, Transform):
142+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
143+
else:
144+
inner_lhs_mql = None
114145
lhs_mql = process_lhs(self, compiler, connection)
115-
return f"{lhs_mql}.{self.key_name}"
146+
return lhs_mql, inner_lhs_mql
116147

117148
@property
118149
def output_field(self):
119-
return EmbeddedModelArrayField(self.ref_field)
150+
return EmbeddedModelArrayField(self.base_field)
120151

121152

122153
class KeyTransformFactory:
123-
def __init__(self, key_name, ref_field):
154+
def __init__(self, key_name, base_field):
124155
self.key_name = key_name
125-
self.ref_field = ref_field
156+
self.base_field = base_field
126157

127158
def __call__(self, *args, **kwargs):
128-
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)
159+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)

0 commit comments

Comments
 (0)