Skip to content

Commit 6b15c0f

Browse files
committed
add aggregation function to keytransforms.
1 parent 7620aa4 commit 6b15c0f

File tree

3 files changed

+44
-21
lines changed

3 files changed

+44
-21
lines changed

django_mongodb/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
9797
"gte": lambda a, b: {"$gte": [a, b]},
9898
"lt": lambda a, b: {"$lt": [a, b]},
9999
"lte": lambda a, b: {"$lte": [a, b]},
100+
"in": lambda a, b: {"$in": [a, b]},
100101
}
101102

102103
display_name = "MongoDB"

django_mongodb/compiler.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,18 @@ def get_columns(self):
153153
columns = (
154154
self.get_default_columns(select_mask) if self.query.default_cols else self.query.select
155155
)
156-
return tuple((column.target.column, column) for column in columns) + tuple(
157-
self.query.annotations.items()
158-
)
156+
annotations_cnt = 0
157+
result = []
158+
for column in columns:
159+
if hasattr(column, "target"):
160+
target = column.target.column
161+
else:
162+
annotations_cnt += 1
163+
target = f"annotation_{annotations_cnt}"
164+
165+
result.append((target, column))
166+
167+
return tuple(result) + tuple(self.query.annotations.items())
159168

160169
def _get_ordering(self):
161170
"""

django_mongodb/fields.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,9 @@ def from_db_value(self, value, expression, connection):
6464

6565
def json_process_rhs(node, compiler, connection):
6666
_, value = node.process_rhs(compiler, connection)
67-
6867
lookup_name = node.lookup_name
6968
if lookup_name not in ("in", "range"):
7069
value = value[0] if len(value) > 0 else []
71-
else:
72-
result_value = []
73-
for ind, elem in enumerate(node.rhs):
74-
item = f"${value[ind]}" if isinstance(elem, KeyTransform) else value[ind]
75-
result_value.append(item)
76-
value = result_value
7770

7871
return value
7972

@@ -97,35 +90,31 @@ def contained_by(self, compiler, connection): # noqa: ARG001
9790

9891

9992
def json_exact(self, compiler, connection):
100-
lhs_mql = process_lhs(self, compiler, connection)
93+
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
10194
rhs_mql = json_process_rhs(self, compiler, connection)
10295
if rhs_mql == "null":
10396
return {"$or": [{lhs_mql: {"$eq": None}}, {lhs_mql: {"$exists": False}}]}
104-
# return {lhs_mql: {"$eq": None, "$exists": True}}
105-
# return key_transform_isnull(self, compiler, connection)
10697
return {lhs_mql: {"$eq": rhs_mql, "$exists": True}}
10798

10899

109100
def key_transform_isnull(self, compiler, connection):
110-
lhs_mql = process_lhs(self, compiler, connection)
101+
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
111102
rhs_mql = json_process_rhs(self, compiler, connection)
112-
# if rhs_mql is False:
113-
# return {lhs_mql: {"$neq": None}}
114-
# return {"$or": [{lhs_mql: {"$eq": None}}, {lhs_mql: {"$exists": False}}]}
115103

116104
# https://code.djangoproject.com/ticket/32252
117105
return {lhs_mql: {"$exists": not rhs_mql}}
118106

119107

120108
def key_transform_in(self, compiler, connection):
121-
lhs_mql = process_lhs(self, compiler, connection)
109+
lhs_mql = key_transform_agg(self.lhs, compiler, connection)
110+
bare_lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
122111
value = json_process_rhs(self, compiler, connection)
123-
rhs_mql = connection.operators[self.lookup_name](value)
124-
return {"$expr": {lhs_mql: rhs_mql}}
112+
expr = connection.mongo_aggregations[self.lookup_name](lhs_mql, value)
113+
return {"$expr": expr, bare_lhs_mql: {"$exists": True}}
125114

126115

127116
def has_key_lookup(self, compiler, connection):
128-
lhs = process_lhs(self, compiler, connection)
117+
lhs = process_lhs(self, compiler, connection, bare_column_ref=True)
129118
rhs = self.rhs
130119
if not isinstance(rhs, (list | tuple)):
131120
rhs = [rhs]
@@ -148,10 +137,34 @@ def has_key_lookup(self, compiler, connection):
148137
return {self.mongo_operator: keys}
149138

150139

140+
def key_transform_agg(self, compiler, connection):
141+
key_transforms = [self.key_name]
142+
previous = self.lhs
143+
while isinstance(previous, KeyTransform):
144+
key_transforms.insert(0, previous.key_name)
145+
previous = previous.lhs
146+
lhs_mql = previous.as_mql(compiler, connection)
147+
result = f"{lhs_mql}"
148+
for key in key_transforms:
149+
get_field = {"$getField": {"input": result, "field": key}}
150+
if key.isdigit():
151+
result = {
152+
"$cond": {
153+
"if": {"$isArray": result},
154+
"then": {"$arrayElemAt": [result, int(key)]},
155+
"else": get_field,
156+
}
157+
}
158+
else:
159+
result = get_field
160+
return result
161+
162+
151163
def load_fields():
152164
JSONField.from_db_value = from_db_value
153165
DataContains.as_mql = data_contains
154166
KeyTransform.as_mql = key_transform
167+
KeyTransform.as_mql_agg = key_transform_agg
155168
JSONExact.as_mql = json_exact
156169
ContainedBy.as_mql = contained_by
157170
HasKeyLookup.as_mql = has_key_lookup

0 commit comments

Comments
 (0)