Skip to content

Commit 6e5995e

Browse files
committed
add support for ExpressionWrapper
1 parent 6ce79ac commit 6e5995e

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

django_mongodb/expressions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1-
from django.db.models.expressions import Col, Value
1+
from django.db.models.expressions import Col, ExpressionWrapper, Value
22

33

44
def col(self, compiler, connection): # noqa: ARG001
55
return f"${self.target.column}"
66

77

8+
def expression_wrapper(self, compiler, connection):
9+
return self.expression.as_mql(compiler, connection)
10+
11+
812
def value(self, compiler, connection): # noqa: ARG001
913
return {"$literal": self.value}
1014

1115

1216
def register_expressions():
1317
Col.as_mql = col
18+
ExpressionWrapper.as_mql = expression_wrapper
1419
Value.as_mql = value

django_mongodb/features.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,21 +175,18 @@ class DatabaseFeatures(BaseDatabaseFeatures):
175175
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_extract_outerref",
176176
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_subquery_with_parameters",
177177
"lookup.tests.LookupQueryingTests.test_filter_subquery_lhs",
178-
# ExpressionWrapper not supported.
179-
"annotations.tests.NonAggregateAnnotationTestCase.test_combined_expression_annotation_with_aggregation",
180-
"annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation",
181-
"annotations.tests.NonAggregateAnnotationTestCase.test_empty_expression_annotation",
182-
"annotations.tests.NonAggregateAnnotationTestCase.test_full_expression_annotation",
178+
# Invalid $project :: caused by :: Unknown expression $count,
183179
"annotations.tests.NonAggregateAnnotationTestCase.test_full_expression_annotation_with_aggregation",
184180
"annotations.tests.NonAggregateAnnotationTestCase.test_grouping_by_q_expression_annotation",
185-
"annotations.tests.NonAggregateAnnotationTestCase.test_mixed_type_annotation_numbers",
186181
"annotations.tests.NonAggregateAnnotationTestCase.test_q_expression_annotation_with_aggregation",
187-
"lookup.tests.LookupQueryingTests.test_filter_wrapped_lookup_lhs",
188182
# CombinedExpression not implemented.
189183
"annotations.tests.NonAggregateAnnotationTestCase.test_combined_annotation_commutative",
184+
"annotations.tests.NonAggregateAnnotationTestCase.test_combined_expression_annotation_with_aggregation",
185+
"annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation",
190186
"annotations.tests.NonAggregateAnnotationTestCase.test_decimal_annotation",
191187
"annotations.tests.NonAggregateAnnotationTestCase.test_defer_annotation",
192188
"annotations.tests.NonAggregateAnnotationTestCase.test_filter_decimal_annotation",
189+
"annotations.tests.NonAggregateAnnotationTestCase.test_mixed_type_annotation_numbers",
193190
"annotations.tests.NonAggregateAnnotationTestCase.test_values_annotation",
194191
# Func not implemented.
195192
"annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions",

django_mongodb/query.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from django.core.exceptions import EmptyResultSet, FullResultSet
44
from django.db import DatabaseError, IntegrityError
5+
from django.db.models import Value
56
from django.db.models.sql.where import AND, XOR, WhereNode
67
from pymongo import ASCENDING, DESCENDING
78
from pymongo.errors import DuplicateKeyError, PyMongoError
@@ -91,7 +92,12 @@ def get_cursor(self):
9192
column = expr.target.column
9293
except AttributeError:
9394
# Generate the MQL for an annotation.
94-
fields[name] = expr.as_mql(self.compiler, self.connection)
95+
try:
96+
fields[name] = expr.as_mql(self.compiler, self.connection)
97+
except EmptyResultSet:
98+
fields[name] = Value(False).as_mql(self.compiler, self.connection)
99+
except FullResultSet:
100+
fields[name] = Value(True).as_mql(self.compiler, self.connection)
95101
else:
96102
# If name != column, then this is an annotatation referencing
97103
# another column.

0 commit comments

Comments
 (0)