Skip to content

Commit 7372300

Browse files
committed
Edits.
1 parent dbbd7db commit 7372300

File tree

13 files changed

+101
-346
lines changed

13 files changed

+101
-346
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
MONGO_AGGREGATIONS = {Count: "sum"}
99

1010

11-
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
11+
def aggregate(
12+
self,
13+
compiler,
14+
connection,
15+
operator=None,
16+
resolve_inner_expression=False,
17+
**extra_context, # noqa: ARG001
18+
):
1219
if self.filter:
1320
node = self.copy()
1421
node.filter = None
@@ -24,7 +31,7 @@ def aggregate(self, compiler, connection, operator=None, resolve_inner_expressio
2431
return {f"${operator}": lhs_mql}
2532

2633

27-
def count(self, compiler, connection, resolve_inner_expression=False):
34+
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
2835
"""
2936
When resolve_inner_expression=True, return the MQL that resolves as a
3037
value. This is used to count different elements, so the inner values are
@@ -57,7 +64,7 @@ def count(self, compiler, connection, resolve_inner_expression=False):
5764
return {"$add": [{"$size": lhs_mql}, exits_null]}
5865

5966

60-
def stddev_variance(self, compiler, connection):
67+
def stddev_variance(self, compiler, connection, **extra_context): # noqa: ARG001
6168
if self.function.endswith("_SAMP"):
6269
operator = "stdDevSamp"
6370
elif self.function.endswith("_POP"):

django_mongodb_backend/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _isnull_operator_match(a, b):
113113
return {"$or": [{a: {"$exists": False}}, {a: None}]}
114114
return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]}
115115

116-
mongo_operators_expr = {
116+
mongo_expr_operators = {
117117
"exact": lambda a, b: {"$eq": [a, b]},
118118
"gt": lambda a, b: {"$gt": [a, b]},
119119
"gte": lambda a, b: {"$gte": [a, b]},
@@ -153,7 +153,8 @@ def range_match(a, b):
153153
return {"$literal": True}
154154
return {"$and": conditions}
155155

156-
mongo_operators_match = {
156+
# match, path, find? don't know which name use.
157+
mongo_match_operators = {
157158
"exact": lambda a, b: {a: b},
158159
"gt": lambda a, b: {a: {"$gt": b}},
159160
"gte": lambda a, b: {a: {"$gte": b}},

django_mongodb_backend/expressions/builtins.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@
2929
from ..query_utils import process_lhs
3030

3131

32+
def base_expression(self, compiler, connection, as_path=False, **extra):
33+
if (
34+
as_path
35+
and hasattr(self, "as_mql_path")
36+
and getattr(self, "is_simple_expression", lambda: False)()
37+
):
38+
return self.as_mql_path(compiler, connection, **extra)
39+
40+
expr = self.as_mql_expr(compiler, connection, **extra)
41+
return {"$expr": expr} if as_path else expr
42+
43+
3244
def case(self, compiler, connection):
3345
case_parts = []
3446
for case in self.cases:
@@ -190,7 +202,7 @@ def exists(self, compiler, connection, get_wrapping_pipeline=None):
190202
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
191203
except EmptyResultSet:
192204
return Value(False).as_mql(compiler, connection)
193-
return connection.mongo_operators_expr["isnull"](lhs_mql, False)
205+
return connection.mongo_expr_operators["isnull"](lhs_mql, False)
194206

195207

196208
def when(self, compiler, connection, as_path=False):
@@ -221,18 +233,6 @@ def value(self, compiler, connection, as_path=False): # noqa: ARG001
221233
return value
222234

223235

224-
def base_expression(self, compiler, connection, as_path=False, **extra):
225-
if (
226-
as_path
227-
and hasattr(self, "as_mql_path")
228-
and getattr(self, "is_simple_expression", lambda: False)()
229-
):
230-
return self.as_mql_path(compiler, connection, **extra)
231-
232-
expr = self.as_mql_expr(compiler, connection, **extra)
233-
return {"$expr": expr} if as_path else expr
234-
235-
236236
def register_expressions():
237237
BaseExpression.as_mql = base_expression
238238
BaseExpression.is_simple_column = False

django_mongodb_backend/fields/array.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -254,13 +254,6 @@ def __init__(self, lhs, rhs):
254254
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
255255
lookup_name = "contains"
256256

257-
def as_mql_path(self, compiler, connection):
258-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
259-
value = process_rhs(self, compiler, connection, as_path=True)
260-
if value is None:
261-
return False
262-
return {lhs_mql: {"$all": value}}
263-
264257
def as_mql_expr(self, compiler, connection):
265258
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
266259
value = process_rhs(self, compiler, connection, as_path=False)
@@ -272,6 +265,13 @@ def as_mql_expr(self, compiler, connection):
272265
]
273266
}
274267

268+
def as_mql_path(self, compiler, connection):
269+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
270+
value = process_rhs(self, compiler, connection, as_path=True)
271+
if value is None:
272+
return False
273+
return {lhs_mql: {"$all": value}}
274+
275275

276276
@ArrayField.register_lookup
277277
class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
@@ -333,11 +333,6 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
333333
},
334334
]
335335

336-
def as_mql_path(self, compiler, connection):
337-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
338-
value = process_rhs(self, compiler, connection, as_path=True)
339-
return {lhs_mql: {"$in": value}}
340-
341336
def as_mql_expr(self, compiler, connection):
342337
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
343338
value = process_rhs(self, compiler, connection, as_path=False)
@@ -348,6 +343,11 @@ def as_mql_expr(self, compiler, connection):
348343
]
349344
}
350345

346+
def as_mql_path(self, compiler, connection):
347+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
348+
value = process_rhs(self, compiler, connection, as_path=True)
349+
return {lhs_mql: {"$in": value}}
350+
351351

352352
@ArrayField.register_lookup
353353
class ArrayLenTransform(Transform):
@@ -388,14 +388,14 @@ def is_simple_expression(self):
388388
def is_simple_column(self):
389389
return self.lhs.is_simple_column
390390

391-
def as_mql_path(self, compiler, connection):
392-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
393-
return f"{lhs_mql}.{self.index}"
394-
395391
def as_mql_expr(self, compiler, connection):
396392
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
397393
return {"$arrayElemAt": [lhs_mql, self.index]}
398394

395+
def as_mql_path(self, compiler, connection):
396+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
397+
return f"{lhs_mql}.{self.index}"
398+
399399
@property
400400
def output_field(self):
401401
return self.base_field

django_mongodb_backend/fields/embedded_model.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,26 +197,26 @@ def get_transform(self, name):
197197
f"{suggestion}"
198198
)
199199

200-
def as_mql_path(self, compiler, connection):
200+
def as_mql_expr(self, compiler, connection):
201201
previous = self
202202
key_transforms = []
203203
while isinstance(previous, KeyTransform):
204204
key_transforms.insert(0, previous.key_name)
205205
previous = previous.lhs
206-
mql = previous.as_mql(compiler, connection, as_path=True)
207-
mql_path = ".".join(key_transforms)
208-
return f"{mql}.{mql_path}"
206+
mql = previous.as_mql(compiler, connection)
207+
for key in key_transforms:
208+
mql = {"$getField": {"input": mql, "field": key}}
209+
return mql
209210

210-
def as_mql_expr(self, compiler, connection):
211+
def as_mql_path(self, compiler, connection):
211212
previous = self
212213
key_transforms = []
213214
while isinstance(previous, KeyTransform):
214215
key_transforms.insert(0, previous.key_name)
215216
previous = previous.lhs
216-
mql = previous.as_mql(compiler, connection)
217-
for key in key_transforms:
218-
mql = {"$getField": {"input": mql, "field": key}}
219-
return mql
217+
mql = previous.as_mql(compiler, connection, as_path=True)
218+
mql_path = ".".join(key_transforms)
219+
return f"{mql}.{mql_path}"
220220

221221
@property
222222
def output_field(self):

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def as_mql_expr(self, compiler, connection):
137137
lhs_mql = process_lhs(self, compiler, connection)
138138
inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"]
139139
values = process_rhs(self, compiler, connection)
140-
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators_expr[self.lookup_name](
140+
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_expr_operators[self.lookup_name](
141141
inner_lhs_mql, values
142142
)
143143
return {"$anyElementTrue": lhs_mql}
@@ -230,7 +230,6 @@ class EmbeddedModelArrayFieldLessThanOrEqual(
230230

231231
class KeyTransform(Transform):
232232
field_class_name = "EmbeddedModelArrayField"
233-
PREFIX_ITERABLE = "item"
234233

235234
def __init__(self, key_name, array_field, *args, **kwargs):
236235
super().__init__(*args, **kwargs)
@@ -239,7 +238,7 @@ def __init__(self, key_name, array_field, *args, **kwargs):
239238
# Lookups iterate over the array of embedded models. A virtual column
240239
# of the queried field's type represents each element.
241240
column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
242-
column_name = f"${self.PREFIX_ITERABLE}.{key_name}"
241+
column_name = f"$item.{key_name}"
243242
column_target.db_column = column_name
244243
column_target.set_attributes_from_name(column_name)
245244
self._lhs = Col(None, column_target)
@@ -293,13 +292,6 @@ def get_transform(self, name):
293292
f"{suggestion}"
294293
)
295294

296-
def as_mql_path(self, compiler, connection):
297-
inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True).removeprefix(
298-
f"${self.PREFIX_ITERABLE}."
299-
)
300-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
301-
return f"{lhs_mql}.{inner_lhs_mql}"
302-
303295
def as_mql_expr(self, compiler, connection):
304296
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
305297
lhs_mql = process_lhs(self, compiler, connection)
@@ -308,14 +300,19 @@ def as_mql_expr(self, compiler, connection):
308300
{
309301
"$map": {
310302
"input": lhs_mql,
311-
"as": self.PREFIX_ITERABLE,
303+
"as": "item",
312304
"in": inner_lhs_mql,
313305
}
314306
},
315307
[],
316308
]
317309
}
318310

311+
def as_mql_path(self, compiler, connection):
312+
inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True).removeprefix("$item.")
313+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
314+
return f"{lhs_mql}.{inner_lhs_mql}"
315+
319316
@property
320317
def output_field(self):
321318
return _EmbeddedModelArrayOutputField(self._lhs.output_field)

0 commit comments

Comments
 (0)