Skip to content

Commit 21ed108

Browse files
committed
add support for Case/When expressions
1 parent d4bee63 commit 21ed108

File tree

3 files changed

+97
-2
lines changed

3 files changed

+97
-2
lines changed

.github/workflows/test-python.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ jobs:
7878
empty
7979
expressions.tests.ExpressionOperatorTests
8080
expressions.tests.NegatedExpressionTests
81+
expressions_case
8182
defer
8283
defer_regress
8384
from_db_value

django_mongodb/expressions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,40 @@
1+
from django.core.exceptions import EmptyResultSet, FullResultSet
12
from django.db.models.expressions import (
3+
Case,
24
Col,
35
CombinedExpression,
46
ExpressionWrapper,
57
NegatedExpression,
68
Value,
9+
When,
710
)
811

912

13+
def case(self, compiler, connection):
14+
case_parts = []
15+
for case in self.cases:
16+
case_mql = {}
17+
try:
18+
case_mql["case"] = case.as_mql(compiler, connection)
19+
except EmptyResultSet:
20+
continue
21+
except FullResultSet:
22+
default_mql = case.result.as_mql(compiler, connection)
23+
break
24+
case_mql["then"] = case.result.as_mql(compiler, connection)
25+
case_parts.append(case_mql)
26+
else:
27+
default_mql = self.default.as_mql(compiler, connection)
28+
if not case_parts:
29+
return default_mql
30+
return {
31+
"$switch": {
32+
"branches": case_parts,
33+
"default": default_mql,
34+
}
35+
}
36+
37+
1038
def col(self, compiler, connection): # noqa: ARG001
1139
return f"${self.target.column}"
1240

@@ -27,13 +55,19 @@ def negated_expression(self, compiler, connection):
2755
return {"$not": expression_wrapper(self, compiler, connection)}
2856

2957

58+
def when(self, compiler, connection):
59+
return self.condition.as_mql(compiler, connection)
60+
61+
3062
def value(self, compiler, connection): # noqa: ARG001
3163
return {"$literal": self.value}
3264

3365

3466
def register_expressions():
67+
Case.as_mql = case
3568
Col.as_mql = col
3669
CombinedExpression.as_mql = combined_expression
3770
ExpressionWrapper.as_mql = expression_wrapper
3871
NegatedExpression.as_mql = negated_expression
72+
When.as_mql = when
3973
Value.as_mql = value

django_mongodb/features.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
3434
"db_functions.tests.FunctionTests.test_nested_function_ordering",
3535
"db_functions.text.test_length.LengthTests.test_ordering",
3636
"db_functions.text.test_strindex.StrIndexTests.test_order_by",
37+
"expressions_case.tests.CaseExpressionTests.test_order_by_conditional_explicit",
3738
"lookup.tests.LookupQueryingTests.test_lookup_in_order_by",
3839
# annotate() after values() doesn't raise NotSupportedError.
3940
"lookup.tests.LookupTests.test_exact_query_rhs_with_selected_columns",
@@ -60,6 +61,12 @@ class DatabaseFeatures(BaseDatabaseFeatures):
6061
"annotations.tests.NonAggregateAnnotationTestCase.test_annotation_and_alias_filter_in_subquery",
6162
# Length of null considered zero rather than null.
6263
"db_functions.text.test_length.LengthTests.test_basic",
64+
# annotating with Decimal() crashes: bson.errors.InvalidDocument:
65+
# cannot encode object: Decimal('1'), of type: <class 'decimal.Decimal'>
66+
"expressions_case.tests.CaseExpressionTests.test_annotate_filter_decimal",
67+
# Case(..., then=datetime.date()) crashes: bson.errors.InvalidDocument:
68+
# cannot encode object: datetime.date(2024, 5, 14), of type: <class 'datetime.date'>
69+
"expressions_case.tests.CaseDocumentationExamples.test_filter_example",
6370
}
6471
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
6572
_django_test_expected_failures_bitwise = {
@@ -107,6 +114,36 @@ def django_test_expected_failures(self):
107114
"db_functions.text.test_replace.ReplaceTests.test_update",
108115
"db_functions.text.test_substr.SubstrTests.test_basic",
109116
"db_functions.text.test_upper.UpperTests.test_basic",
117+
"expressions_case.tests.CaseDocumentationExamples.test_conditional_update_example",
118+
"expressions_case.tests.CaseExpressionTests.test_update",
119+
"expressions_case.tests.CaseExpressionTests.test_update_big_integer",
120+
"expressions_case.tests.CaseExpressionTests.test_update_binary",
121+
"expressions_case.tests.CaseExpressionTests.test_update_boolean",
122+
"expressions_case.tests.CaseExpressionTests.test_update_date",
123+
"expressions_case.tests.CaseExpressionTests.test_update_date_time",
124+
"expressions_case.tests.CaseExpressionTests.test_update_decimal",
125+
"expressions_case.tests.CaseExpressionTests.test_update_duration",
126+
"expressions_case.tests.CaseExpressionTests.test_update_email",
127+
"expressions_case.tests.CaseExpressionTests.test_update_file",
128+
"expressions_case.tests.CaseExpressionTests.test_update_file_path",
129+
"expressions_case.tests.CaseExpressionTests.test_update_fk",
130+
"expressions_case.tests.CaseExpressionTests.test_update_float",
131+
"expressions_case.tests.CaseExpressionTests.test_update_generic_ip_address",
132+
"expressions_case.tests.CaseExpressionTests.test_update_image",
133+
"expressions_case.tests.CaseExpressionTests.test_update_null_boolean",
134+
"expressions_case.tests.CaseExpressionTests.test_update_positive_big_integer",
135+
"expressions_case.tests.CaseExpressionTests.test_update_positive_integer",
136+
"expressions_case.tests.CaseExpressionTests.test_update_positive_small_integer",
137+
"expressions_case.tests.CaseExpressionTests.test_update_slug",
138+
"expressions_case.tests.CaseExpressionTests.test_update_small_integer",
139+
"expressions_case.tests.CaseExpressionTests.test_update_string",
140+
"expressions_case.tests.CaseExpressionTests.test_update_text",
141+
"expressions_case.tests.CaseExpressionTests.test_update_time",
142+
"expressions_case.tests.CaseExpressionTests.test_update_url",
143+
"expressions_case.tests.CaseExpressionTests.test_update_uuid",
144+
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_condition",
145+
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_value",
146+
"expressions_case.tests.CaseExpressionTests.test_update_without_default",
110147
"model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values",
111148
"timezones.tests.NewDatabaseTests.test_update_with_timedelta",
112149
"update.tests.AdvancedTests.test_update_annotated_queryset",
@@ -169,6 +206,10 @@ def django_test_expected_failures(self):
169206
"annotations.tests.NonAggregateAnnotationTestCase.test_annotation_in_f_grouped_by_annotation",
170207
"annotations.tests.NonAggregateAnnotationTestCase.test_annotation_subquery_and_aggregate_values_chaining",
171208
"annotations.tests.NonAggregateAnnotationTestCase.test_filter_agg_with_double_f",
209+
"expressions_case.tests.CaseExpressionTests.test_aggregate",
210+
"expressions_case.tests.CaseExpressionTests.test_aggregate_with_expression_as_condition",
211+
"expressions_case.tests.CaseExpressionTests.test_aggregate_with_expression_as_value",
212+
"expressions_case.tests.CaseExpressionTests.test_aggregation_empty_cases",
172213
"lookup.tests.LookupQueryingTests.test_aggregate_combined_lookup",
173214
"from_db_value.tests.FromDBValueTest.test_aggregation",
174215
"timezones.tests.LegacyDatabaseTests.test_query_aggregation",
@@ -185,20 +226,20 @@ def django_test_expected_failures(self):
185226
"lookup.tests.LookupQueryingTests.test_combined_annotated_lookups_in_filter",
186227
"lookup.tests.LookupQueryingTests.test_combined_annotated_lookups_in_filter_false",
187228
"lookup.tests.LookupQueryingTests.test_combined_lookups",
188-
# Case not supported.
189-
"lookup.tests.LookupQueryingTests.test_conditional_expression",
190229
# Subquery not supported.
191230
"annotations.tests.NonAggregateAnnotationTestCase.test_empty_queryset_annotation",
192231
"db_functions.comparison.test_coalesce.CoalesceTests.test_empty_queryset",
193232
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_extract_outerref",
194233
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_subquery_with_parameters",
234+
"expressions_case.tests.CaseExpressionTests.test_in_subquery",
195235
"lookup.tests.LookupQueryingTests.test_filter_subquery_lhs",
196236
# Invalid $project :: caused by :: Unknown expression $count,
197237
"annotations.tests.NonAggregateAnnotationTestCase.test_combined_expression_annotation_with_aggregation",
198238
"annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation",
199239
"annotations.tests.NonAggregateAnnotationTestCase.test_full_expression_annotation_with_aggregation",
200240
"annotations.tests.NonAggregateAnnotationTestCase.test_grouping_by_q_expression_annotation",
201241
"annotations.tests.NonAggregateAnnotationTestCase.test_q_expression_annotation_with_aggregation",
242+
"expressions_case.tests.CaseDocumentationExamples.test_conditional_aggregation_example",
202243
# Func not implemented.
203244
"annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions",
204245
"annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions",
@@ -209,6 +250,8 @@ def django_test_expected_failures(self):
209250
"annotations.tests.NonAggregateAnnotationTestCase.test_order_by_aggregate",
210251
"annotations.tests.NonAggregateAnnotationTestCase.test_order_by_annotation",
211252
"expressions.tests.NegatedExpressionTests.test_filter",
253+
"expressions_case.tests.CaseExpressionTests.test_annotate_values_not_in_order_by",
254+
"expressions_case.tests.CaseExpressionTests.test_order_by_conditional_implicit",
212255
# annotate().filter().count() gives incorrect results.
213256
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_extract_year_exact_lookup",
214257
},
@@ -270,6 +313,23 @@ def django_test_expected_failures(self):
270313
"defer.tests.DeferTests.test_only_baseclass_when_subclass_has_no_added_fields",
271314
"defer.tests.TestDefer2.test_defer_inheritance_pk_chaining",
272315
"defer_regress.tests.DeferRegressionTest.test_ticket_16409",
316+
"expressions_case.tests.CaseExpressionTests.test_annotate_with_aggregation_in_condition",
317+
"expressions_case.tests.CaseExpressionTests.test_annotate_with_aggregation_in_predicate",
318+
"expressions_case.tests.CaseExpressionTests.test_annotate_with_aggregation_in_value",
319+
"expressions_case.tests.CaseExpressionTests.test_annotate_with_in_clause",
320+
"expressions_case.tests.CaseExpressionTests.test_annotate_with_join_in_condition",
321+
"expressions_case.tests.CaseExpressionTests.test_annotate_with_join_in_predicate",
322+
"expressions_case.tests.CaseExpressionTests.test_annotate_with_join_in_value",
323+
"expressions_case.tests.CaseExpressionTests.test_filter_with_aggregation_in_condition",
324+
"expressions_case.tests.CaseExpressionTests.test_filter_with_aggregation_in_predicate",
325+
"expressions_case.tests.CaseExpressionTests.test_filter_with_aggregation_in_value",
326+
"expressions_case.tests.CaseExpressionTests.test_filter_with_join_in_condition",
327+
"expressions_case.tests.CaseExpressionTests.test_filter_with_join_in_predicate",
328+
"expressions_case.tests.CaseExpressionTests.test_filter_with_join_in_value",
329+
"expressions_case.tests.CaseExpressionTests.test_join_promotion",
330+
"expressions_case.tests.CaseExpressionTests.test_join_promotion_multiple_annotations",
331+
"expressions_case.tests.CaseExpressionTests.test_m2m_exclude",
332+
"expressions_case.tests.CaseExpressionTests.test_m2m_reuse",
273333
"lookup.test_decimalfield.DecimalFieldLookupTests",
274334
"lookup.tests.LookupQueryingTests.test_multivalued_join_reuse",
275335
"lookup.tests.LookupTests.test_filter_by_reverse_related_field_transform",

0 commit comments

Comments
 (0)