Skip to content

Commit d351d83

Browse files
committed
add support for QuerySet.annotate()
1 parent b3f1f85 commit d351d83

File tree

6 files changed

+73
-28
lines changed

6 files changed

+73
-28
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ DATABASES = {
4747
## Known issues and limitations
4848

4949
- The following `QuerySet` methods aren't supported:
50-
- `annotate()`
5150
- `aggregate()`
5251
- `dates()`
5352
- `datetimes()`

django_mongodb/base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,25 @@ class DatabaseWrapper(BaseDatabaseWrapper):
5454
"TimeField": "date",
5555
"UUIDField": "string",
5656
}
57+
# Django uses these operators to generate SQL queries before it generates
58+
# MQL queries.
5759
operators = {
60+
"exact": "= %s",
61+
"iexact": "= UPPER(%s)",
62+
"contains": "LIKE %s",
63+
"icontains": "LIKE UPPER(%s)",
64+
"regex": "~ %s",
65+
"iregex": "~* %s",
66+
"gt": "> %s",
67+
"gte": ">= %s",
68+
"lt": "< %s",
69+
"lte": "<= %s",
70+
"startswith": "LIKE %s",
71+
"endswith": "LIKE %s",
72+
"istartswith": "LIKE UPPER(%s)",
73+
"iendswith": "LIKE UPPER(%s)",
74+
}
75+
mongo_operators = {
5876
"exact": lambda val: val,
5977
"gt": lambda val: {"$gt": val},
6078
"gte": lambda val: {"$gte": val},
@@ -73,6 +91,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
7391
"regex": lambda val: re.compile(val),
7492
"iregex": lambda val: re.compile(val, re.IGNORECASE),
7593
}
94+
mongo_aggregates = {
95+
"exact": lambda a, b: {"$eq": [a, b]},
96+
"gt": lambda a, b: {"$gt": [a, b]},
97+
"gte": lambda a, b: {"$gte": [a, b]},
98+
"lt": lambda a, b: {"$lt": [a, b]},
99+
"lte": lambda a, b: {"$lte": [a, b]},
100+
}
76101

77102
display_name = "MongoDB"
78103
vendor = "mongodb"

django_mongodb/compiler.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.core.exceptions import EmptyResultSet, FullResultSet
22
from django.db import DatabaseError, IntegrityError, NotSupportedError
3-
from django.db.models import NOT_PROVIDED, Count, Expression, Value
3+
from django.db.models import NOT_PROVIDED, Count, Expression
44
from django.db.models.aggregates import Aggregate
55
from django.db.models.constants import LOOKUP_SEP
66
from django.db.models.sql import compiler
@@ -57,11 +57,16 @@ def has_results(self):
5757

5858
def get_converters(self, columns):
5959
converters = {}
60-
for column in columns:
60+
for name_column in columns:
61+
try:
62+
name, column = name_column
63+
except TypeError:
64+
# e.g., Count("*")
65+
continue
6166
backend_converters = self.connection.ops.get_db_converters(column)
6267
field_converters = column.field.get_db_converters(self.connection)
6368
if backend_converters or field_converters:
64-
converters[column.target.column] = backend_converters + field_converters
69+
converters[name] = backend_converters + field_converters
6570
return converters
6671

6772
def _make_result(self, entity, columns, converters, tuple_expected=False):
@@ -72,15 +77,14 @@ def _make_result(self, entity, columns, converters, tuple_expected=False):
7277
names as keys.
7378
"""
7479
result = []
75-
for col in columns:
80+
for name, col in columns:
7681
field = col.field
77-
column = col.target.column
78-
value = entity.get(column, NOT_PROVIDED)
82+
value = entity.get(name, NOT_PROVIDED)
7983
if value is NOT_PROVIDED:
8084
value = field.get_default()
8185
elif converters:
8286
# Decode values using Django's database converters API.
83-
for converter in converters.get(column, ()):
87+
for converter in converters.get(name, ()):
8488
value = converter(value, col, self.connection)
8589
result.append(value)
8690
if tuple_expected:
@@ -91,12 +95,6 @@ def check_query(self):
9195
"""Check if the current query is supported by the database."""
9296
if self.query.is_empty():
9397
raise EmptyResultSet()
94-
# Supported annotations are Exists() and Count().
95-
if self.query.annotations and self.query.annotations not in (
96-
{"a": Value(1)},
97-
{"__count": Count("*")},
98-
):
99-
raise NotSupportedError("QuerySet.annotate() is not supported on MongoDB.")
10098
if self.query.distinct:
10199
# This is a heuristic to detect QuerySet.datetimes() and dates().
102100
# "datetimefield" and "datefield" are the names of the annotations
@@ -144,11 +142,17 @@ def build_query(self, columns=None):
144142
return query
145143

146144
def get_columns(self):
147-
"""Return columns which should be loaded by the query."""
145+
"""
146+
Return a tuple of (name, expression) with the columns and annotations
147+
which should be loaded by the query.
148+
"""
148149
select_mask = self.query.get_select_mask()
149-
return (
150+
columns = (
150151
self.get_default_columns(select_mask) if self.query.default_cols else self.query.select
151152
)
153+
return tuple((column.target.column, column) for column in columns) + tuple(
154+
self.query.annotations.items()
155+
)
152156

153157
def _get_ordering(self):
154158
"""

django_mongodb/features.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,27 +133,30 @@ class DatabaseFeatures(BaseDatabaseFeatures):
133133
"timezones.tests.LegacyDatabaseTests.test_query_aggregation",
134134
"timezones.tests.NewDatabaseTests.test_query_aggregation",
135135
},
136-
"QuerySet.annotate() not supported.": {
137-
"lookup.test_decimalfield.DecimalFieldLookupTests",
136+
"QuerySet.annotate() has some limitations.": {
137+
# Exists not supported.
138138
"lookup.tests.LookupTests.test_exact_exists",
139139
"lookup.tests.LookupTests.test_nested_outerref_lhs",
140+
"lookup.tests.LookupQueryingTests.test_filter_exists_lhs",
141+
# QuerySet.alias() doesn't work.
140142
"lookup.tests.LookupQueryingTests.test_alias",
141-
"lookup.tests.LookupQueryingTests.test_annotate",
143+
# Comparing two fields doesn't work.
142144
"lookup.tests.LookupQueryingTests.test_annotate_field_greater_than_field",
143-
"lookup.tests.LookupQueryingTests.test_annotate_field_greater_than_literal",
144-
"lookup.tests.LookupQueryingTests.test_annotate_field_greater_than_value",
145-
"lookup.tests.LookupQueryingTests.test_annotate_greater_than_or_equal",
146-
"lookup.tests.LookupQueryingTests.test_annotate_greater_than_or_equal_float",
147-
"lookup.tests.LookupQueryingTests.test_annotate_less_than_float",
145+
# Value() not supported.
148146
"lookup.tests.LookupQueryingTests.test_annotate_literal_greater_than_field",
149147
"lookup.tests.LookupQueryingTests.test_annotate_value_greater_than_value",
148+
# annotate() with combined expressions doesn't work:
149+
# 'WhereNode' object has no attribute 'field'
150150
"lookup.tests.LookupQueryingTests.test_combined_annotated_lookups_in_filter",
151151
"lookup.tests.LookupQueryingTests.test_combined_annotated_lookups_in_filter_false",
152152
"lookup.tests.LookupQueryingTests.test_combined_lookups",
153+
# Case not supported.
153154
"lookup.tests.LookupQueryingTests.test_conditional_expression",
154-
"lookup.tests.LookupQueryingTests.test_filter_exists_lhs",
155+
# Using expression in filter() doesn't work.
155156
"lookup.tests.LookupQueryingTests.test_filter_lookup_lhs",
157+
# Subquery not supported.
156158
"lookup.tests.LookupQueryingTests.test_filter_subquery_lhs",
159+
# ExpressionWrapper not supported.
157160
"lookup.tests.LookupQueryingTests.test_filter_wrapped_lookup_lhs",
158161
},
159162
"QuerySet.dates() is not supported on MongoDB.": {
@@ -190,6 +193,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
190193
"defer.tests.DeferTests.test_only_baseclass_when_subclass_has_no_added_fields",
191194
"defer.tests.TestDefer2.test_defer_inheritance_pk_chaining",
192195
"defer_regress.tests.DeferRegressionTest.test_ticket_16409",
196+
"lookup.test_decimalfield.DecimalFieldLookupTests",
193197
"lookup.tests.LookupQueryingTests.test_multivalued_join_reuse",
194198
"lookup.tests.LookupTests.test_filter_by_reverse_related_field_transform",
195199
"lookup.tests.LookupTests.test_lookup_collision",

django_mongodb/lookups.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@
88
def builtin_lookup(self, compiler, connection):
99
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
1010
value = process_rhs(self, compiler, connection)
11-
rhs_mql = connection.operators[self.lookup_name](value)
11+
rhs_mql = connection.mongo_operators[self.lookup_name](value)
1212
return {lhs_mql: rhs_mql}
1313

1414

15+
def builtin_lookup_agg(self, compiler, connection):
16+
lhs_mql = process_lhs(self, compiler, connection)
17+
value = process_rhs(self, compiler, connection)
18+
return connection.mongo_aggregates[self.lookup_name](lhs_mql, value)
19+
20+
1521
def exact(self, compiler, connection):
1622
lhs_mql = process_lhs(self, compiler, connection)
1723
value = process_rhs(self, compiler, connection)
@@ -28,7 +34,7 @@ def is_null(self, compiler, connection):
2834
if not isinstance(self.rhs, bool):
2935
raise ValueError("The QuerySet value for an isnull lookup must be True or False.")
3036
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
31-
rhs_mql = connection.operators["isnull"](self.rhs)
37+
rhs_mql = connection.mongo_operators["isnull"](self.rhs)
3238
return {lhs_mql: rhs_mql}
3339

3440

@@ -38,6 +44,7 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001
3844

3945
def register_lookups():
4046
BuiltinLookup.as_mql = builtin_lookup
47+
BuiltinLookup.as_mql_agg = builtin_lookup_agg
4148
Exact.as_mql = exact
4249
In.as_mql = RelatedIn.as_mql = in_
4350
IsNull.as_mql = is_null

django_mongodb/query.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,13 @@ def delete(self):
8585
def get_cursor(self):
8686
if self.query.low_mark == self.query.high_mark:
8787
return []
88-
fields = [col.target.column for col in self.columns] if self.columns else None
88+
fields = {}
89+
for name, expr in self.columns or []:
90+
try:
91+
fields[expr.target.column] = 1
92+
except AttributeError:
93+
# Generate the MQL for an annotation.
94+
fields[name] = expr.as_mql_agg(self.compiler, self.connection)
8995
cursor = self.collection.find(self.mongo_query, fields)
9096
if self.ordering:
9197
cursor.sort(self.ordering)

0 commit comments

Comments
 (0)