diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index bdf98a46..78373fd0 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -1,14 +1,19 @@ -from django.db.models.expressions import Col, Value +from django.db.models.expressions import Col, ExpressionWrapper, Value def col(self, compiler, connection): # noqa: ARG001 return f"${self.target.column}" +def expression_wrapper(self, compiler, connection): + return self.expression.as_mql(compiler, connection) + + def value(self, compiler, connection): # noqa: ARG001 return {"$literal": self.value} def register_expressions(): Col.as_mql = col + ExpressionWrapper.as_mql = expression_wrapper Value.as_mql = value diff --git a/django_mongodb/features.py b/django_mongodb/features.py index e36049df..3b214757 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -175,21 +175,18 @@ class DatabaseFeatures(BaseDatabaseFeatures): "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_extract_outerref", "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_subquery_with_parameters", "lookup.tests.LookupQueryingTests.test_filter_subquery_lhs", - # ExpressionWrapper not supported. - "annotations.tests.NonAggregateAnnotationTestCase.test_combined_expression_annotation_with_aggregation", - "annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation", - "annotations.tests.NonAggregateAnnotationTestCase.test_empty_expression_annotation", - "annotations.tests.NonAggregateAnnotationTestCase.test_full_expression_annotation", + # Invalid $project :: caused by :: Unknown expression $count, "annotations.tests.NonAggregateAnnotationTestCase.test_full_expression_annotation_with_aggregation", "annotations.tests.NonAggregateAnnotationTestCase.test_grouping_by_q_expression_annotation", - "annotations.tests.NonAggregateAnnotationTestCase.test_mixed_type_annotation_numbers", "annotations.tests.NonAggregateAnnotationTestCase.test_q_expression_annotation_with_aggregation", - "lookup.tests.LookupQueryingTests.test_filter_wrapped_lookup_lhs", # CombinedExpression not implemented. "annotations.tests.NonAggregateAnnotationTestCase.test_combined_annotation_commutative", + "annotations.tests.NonAggregateAnnotationTestCase.test_combined_expression_annotation_with_aggregation", + "annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation", "annotations.tests.NonAggregateAnnotationTestCase.test_decimal_annotation", "annotations.tests.NonAggregateAnnotationTestCase.test_defer_annotation", "annotations.tests.NonAggregateAnnotationTestCase.test_filter_decimal_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_mixed_type_annotation_numbers", "annotations.tests.NonAggregateAnnotationTestCase.test_values_annotation", # Func not implemented. "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions", diff --git a/django_mongodb/query.py b/django_mongodb/query.py index 2d714858..2f2922dd 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -2,6 +2,7 @@ from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import DatabaseError, IntegrityError +from django.db.models import Value from django.db.models.sql.where import AND, XOR, WhereNode from pymongo import ASCENDING, DESCENDING from pymongo.errors import DuplicateKeyError, PyMongoError @@ -91,7 +92,12 @@ def get_cursor(self): column = expr.target.column except AttributeError: # Generate the MQL for an annotation. - fields[name] = expr.as_mql(self.compiler, self.connection) + try: + fields[name] = expr.as_mql(self.compiler, self.connection) + except EmptyResultSet: + fields[name] = Value(False).as_mql(self.compiler, self.connection) + except FullResultSet: + fields[name] = Value(True).as_mql(self.compiler, self.connection) else: # If name != column, then this is an annotatation referencing # another column.