Skip to content

Commit 78448d5

Browse files
committed
Fix unit test
1 parent 9c5eaf3 commit 78448d5

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,11 @@ def as_mql(self, compiler, connection):
248248
key_transforms.insert(0, previous.key_name)
249249
previous = previous.lhs
250250
mql = previous.as_mql(compiler, connection)
251+
# transform = ".".join(key_transforms)
251252
for key in key_transforms:
252253
mql = {"$getField": {"input": mql, "field": key}}
253254
return mql
255+
# return f"{mql}.{transform}"
254256

255257
@property
256258
def output_field(self):

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,20 @@ def as_mql(self, compiler, connection):
5959
lhs_mql = process_lhs(self, compiler, connection)
6060
value = process_rhs(self, compiler, connection)
6161
if isinstance(self.lhs, Col | KeyTransform):
62+
if isinstance(self.lhs, Col):
63+
inner_lhs_mql = "$$item"
64+
else:
65+
lhs_mql, inner_lhs_mql = lhs_mql
6266
if isinstance(value, models.Model):
6367
value, emf_data = self.model_to_dict(value)
6468
# Get conditions for any nested EmbeddedModelFields.
65-
conditions = self.get_conditions({lhs_mql[1]: (value, emf_data)})
69+
conditions = self.get_conditions({inner_lhs_mql: (value, emf_data)})
6670
return {
6771
"$anyElementTrue": {
6872
"$ifNull": [
6973
{
7074
"$map": {
71-
"input": lhs_mql[0],
75+
"input": lhs_mql,
7276
"as": "item",
7377
"in": {"$and": conditions},
7478
}
@@ -82,9 +86,9 @@ def as_mql(self, compiler, connection):
8286
"$ifNull": [
8387
{
8488
"$map": {
85-
"input": lhs_mql[0],
89+
"input": lhs_mql,
8690
"as": "item",
87-
"in": {"$eq": [lhs_mql[1], value]},
91+
"in": {"$eq": [inner_lhs_mql, value]},
8892
}
8993
},
9094
[],
@@ -138,10 +142,7 @@ def get_transform(self, name):
138142
)
139143

140144
def as_mql(self, compiler, connection):
141-
if isinstance(self._lhs, Transform):
142-
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
143-
else:
144-
inner_lhs_mql = None
145+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
145146
lhs_mql = process_lhs(self, compiler, connection)
146147
return lhs_mql, inner_lhs_mql
147148

0 commit comments

Comments
 (0)