Skip to content

Commit 85c4206

Browse files
committed
add support for QuerySet.annotate()
1 parent d57986b commit 85c4206

File tree

6 files changed

+71
-28
lines changed

6 files changed

+71
-28
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ DATABASES = {
4747
## Known issues and limitations
4848

4949
- The following `QuerySet` methods aren't supported:
50-
- `annotate()`
5150
- `aggregate()`
5251
- `dates()`
5352
- `datetimes()`

django_mongodb/base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,25 @@ class DatabaseWrapper(BaseDatabaseWrapper):
5454
"TimeField": "date",
5555
"UUIDField": "string",
5656
}
57+
# Django uses these operators to generate SQL queries before it generates
58+
# MQL queries.
5759
operators = {
60+
"exact": "= %s",
61+
"iexact": "= UPPER(%s)",
62+
"contains": "LIKE %s",
63+
"icontains": "LIKE UPPER(%s)",
64+
"regex": "~ %s",
65+
"iregex": "~* %s",
66+
"gt": "> %s",
67+
"gte": ">= %s",
68+
"lt": "< %s",
69+
"lte": "<= %s",
70+
"startswith": "LIKE %s",
71+
"endswith": "LIKE %s",
72+
"istartswith": "LIKE UPPER(%s)",
73+
"iendswith": "LIKE UPPER(%s)",
74+
}
75+
mongo_operators = {
5876
"exact": lambda val: val,
5977
"gt": lambda val: {"$gt": val},
6078
"gte": lambda val: {"$gte": val},
@@ -73,6 +91,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
7391
"regex": lambda val: re.compile(val),
7492
"iregex": lambda val: re.compile(val, re.IGNORECASE),
7593
}
94+
mongo_aggregates = {
95+
"exact": lambda val1, val2: {"$eq": [val1, val2]},
96+
"gt": lambda val1, val2: {"$gt": [val1, val2]},
97+
"gte": lambda val1, val2: {"$gte": [val1, val2]},
98+
"lt": lambda val1, val2: {"$lt": [val1, val2]},
99+
"lte": lambda val1, val2: {"$lte": [val1, val2]},
100+
}
76101

77102
display_name = "MongoDB"
78103
vendor = "mongodb"

django_mongodb/compiler.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.core.exceptions import EmptyResultSet, FullResultSet
22
from django.db import DatabaseError, IntegrityError, NotSupportedError
3-
from django.db.models import NOT_PROVIDED, Count, Expression, Value
3+
from django.db.models import NOT_PROVIDED, Count, Expression
44
from django.db.models.aggregates import Aggregate
55
from django.db.models.constants import LOOKUP_SEP
66
from django.db.models.sql import compiler
@@ -9,6 +9,8 @@
99
from .base import Cursor
1010
from .query import MongoQuery, wrap_database_errors
1111

12+
COUNT_STAR = {"__count": Count("*")}
13+
1214

1315
class SQLCompiler(compiler.SQLCompiler):
1416
"""Base class for all Mongo compilers."""
@@ -20,7 +22,7 @@ def execute_sql(
2022
):
2123
self.pre_sql_setup()
2224
# QuerySet.count()
23-
if self.query.annotations == {"__count": Count("*")}:
25+
if self.query.annotations == COUNT_STAR:
2426
return [self.get_count()]
2527
try:
2628
query = self.build_query()
@@ -57,11 +59,16 @@ def has_results(self):
5759

5860
def get_converters(self, columns):
5961
converters = {}
60-
for column in columns:
62+
for name_column in columns:
63+
try:
64+
name, column = name_column
65+
except TypeError:
66+
# e.g., Count("*")
67+
continue
6168
backend_converters = self.connection.ops.get_db_converters(column)
6269
field_converters = column.field.get_db_converters(self.connection)
6370
if backend_converters or field_converters:
64-
converters[column.target.column] = backend_converters + field_converters
71+
converters[name] = backend_converters + field_converters
6572
return converters
6673

6774
def _make_result(self, entity, columns, converters, tuple_expected=False):
@@ -72,15 +79,14 @@ def _make_result(self, entity, columns, converters, tuple_expected=False):
7279
names as keys.
7380
"""
7481
result = []
75-
for col in columns:
82+
for name, col in columns:
7683
field = col.field
77-
column = col.target.column
78-
value = entity.get(column, NOT_PROVIDED)
84+
value = entity.get(name, NOT_PROVIDED)
7985
if value is NOT_PROVIDED:
8086
value = field.get_default()
8187
elif converters:
8288
# Decode values using Django's database converters API.
83-
for converter in converters.get(column, ()):
89+
for converter in converters.get(name, ()):
8490
value = converter(value, col, self.connection)
8591
result.append(value)
8692
if tuple_expected:
@@ -91,12 +97,6 @@ def check_query(self):
9197
"""Check if the current query is supported by the database."""
9298
if self.query.is_empty():
9399
raise EmptyResultSet()
94-
# Supported annotations are Exists() and Count().
95-
if self.query.annotations and self.query.annotations not in (
96-
{"a": Value(1)},
97-
{"__count": Count("*")},
98-
):
99-
raise NotSupportedError("QuerySet.annotate() is not supported on MongoDB.")
100100
if self.query.distinct:
101101
# This is a heuristic to detect QuerySet.datetimes() and dates().
102102
# "datetimefield" and "datefield" are the names of the annotations
@@ -146,9 +146,12 @@ def build_query(self, columns=None):
146146
def get_columns(self):
147147
"""Return columns which should be loaded by the query."""
148148
select_mask = self.query.get_select_mask()
149-
return (
149+
columns = (
150150
self.get_default_columns(select_mask) if self.query.default_cols else self.query.select
151151
)
152+
return tuple((column.target.column, column) for column in columns) + tuple(
153+
self.query.annotations.items() # if self.query.annotations.items() != COUNT_STAR else ()
154+
)
152155

153156
def _get_ordering(self):
154157
"""

django_mongodb/features.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,27 +131,30 @@ class DatabaseFeatures(BaseDatabaseFeatures):
131131
"timezones.tests.LegacyDatabaseTests.test_query_aggregation",
132132
"timezones.tests.NewDatabaseTests.test_query_aggregation",
133133
},
134-
"QuerySet.annotate() not supported.": {
135-
"lookup.test_decimalfield.DecimalFieldLookupTests",
134+
"QuerySet.annotate() has some limitations.": {
135+
# Exists not supported.
136136
"lookup.tests.LookupTests.test_exact_exists",
137137
"lookup.tests.LookupTests.test_nested_outerref_lhs",
138+
"lookup.tests.LookupQueryingTests.test_filter_exists_lhs",
139+
# QuerySet.alias() doesn't work.
138140
"lookup.tests.LookupQueryingTests.test_alias",
139-
"lookup.tests.LookupQueryingTests.test_annotate",
141+
# Comparing two fields doesn't work.
140142
"lookup.tests.LookupQueryingTests.test_annotate_field_greater_than_field",
141-
"lookup.tests.LookupQueryingTests.test_annotate_field_greater_than_literal",
142-
"lookup.tests.LookupQueryingTests.test_annotate_field_greater_than_value",
143-
"lookup.tests.LookupQueryingTests.test_annotate_greater_than_or_equal",
144-
"lookup.tests.LookupQueryingTests.test_annotate_greater_than_or_equal_float",
145-
"lookup.tests.LookupQueryingTests.test_annotate_less_than_float",
143+
# Value() not supported.
146144
"lookup.tests.LookupQueryingTests.test_annotate_literal_greater_than_field",
147145
"lookup.tests.LookupQueryingTests.test_annotate_value_greater_than_value",
146+
# annotate() with combined expressions doesn't work:
147+
# 'WhereNode' object has no attribute 'field'
148148
"lookup.tests.LookupQueryingTests.test_combined_annotated_lookups_in_filter",
149149
"lookup.tests.LookupQueryingTests.test_combined_annotated_lookups_in_filter_false",
150150
"lookup.tests.LookupQueryingTests.test_combined_lookups",
151+
# Case not supported.
151152
"lookup.tests.LookupQueryingTests.test_conditional_expression",
152-
"lookup.tests.LookupQueryingTests.test_filter_exists_lhs",
153+
# Using expression in filter() doesn't work.
153154
"lookup.tests.LookupQueryingTests.test_filter_lookup_lhs",
155+
# Subquery not supported.
154156
"lookup.tests.LookupQueryingTests.test_filter_subquery_lhs",
157+
# ExpressionWrapper not supported.
155158
"lookup.tests.LookupQueryingTests.test_filter_wrapped_lookup_lhs",
156159
},
157160
"QuerySet.dates() is not supported on MongoDB.": {
@@ -188,6 +191,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
188191
"defer.tests.DeferTests.test_only_baseclass_when_subclass_has_no_added_fields",
189192
"defer.tests.TestDefer2.test_defer_inheritance_pk_chaining",
190193
"defer_regress.tests.DeferRegressionTest.test_ticket_16409",
194+
"lookup.test_decimalfield.DecimalFieldLookupTests",
191195
"lookup.tests.LookupQueryingTests.test_multivalued_join_reuse",
192196
"lookup.tests.LookupTests.test_filter_by_reverse_related_field_transform",
193197
"lookup.tests.LookupTests.test_lookup_collision",

django_mongodb/lookups.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@
88
def builtin_lookup(self, compiler, connection):
99
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
1010
value = process_rhs(self, compiler, connection)
11-
rhs_mql = connection.operators[self.lookup_name](value)
11+
rhs_mql = connection.mongo_operators[self.lookup_name](value)
1212
return {lhs_mql: rhs_mql}
1313

1414

15+
def builtin_lookup_agg(self, compiler, connection):
16+
lhs_mql = process_lhs(self, compiler, connection)
17+
value = process_rhs(self, compiler, connection)
18+
return connection.mongo_aggregates[self.lookup_name](lhs_mql, value)
19+
20+
1521
def exact(self, compiler, connection):
1622
lhs_mql = process_lhs(self, compiler, connection)
1723
value = process_rhs(self, compiler, connection)
@@ -28,7 +34,7 @@ def is_null(self, compiler, connection):
2834
if not isinstance(self.rhs, bool):
2935
raise ValueError("The QuerySet value for an isnull lookup must be True or False.")
3036
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
31-
rhs_mql = connection.operators["isnull"](self.rhs)
37+
rhs_mql = connection.mongo_operators["isnull"](self.rhs)
3238
return {lhs_mql: rhs_mql}
3339

3440

@@ -38,6 +44,7 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001
3844

3945
def register_lookups():
4046
BuiltinLookup.as_mql = builtin_lookup
47+
BuiltinLookup.as_mql_agg = builtin_lookup_agg
4148
Exact.as_mql = exact
4249
In.as_mql = RelatedIn.as_mql = in_
4350
IsNull.as_mql = is_null

django_mongodb/query.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,12 @@ def delete(self):
8585
def get_cursor(self):
8686
if self.query.low_mark == self.query.high_mark:
8787
return []
88-
fields = [col.target.column for col in self.columns] if self.columns else None
88+
fields = {}
89+
for name, expr in self.columns or []:
90+
try:
91+
fields[expr.target.column] = 1
92+
except AttributeError:
93+
fields[name] = expr.as_mql_agg(self.compiler, self.connection)
8994
cursor = self.collection.find(self.mongo_query, fields)
9095
if self.ordering:
9196
cursor.sort(self.ordering)

0 commit comments

Comments
 (0)