diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index c2646ad8..23989e0d 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -67,6 +67,7 @@ jobs: - name: Run tests run: > python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2 + annotations auth_tests.test_models.UserManagerTestCase backends.base.test_base.DatabaseWrapperTests basic diff --git a/README.md b/README.md index 9f0ccc65..49b9053e 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,6 @@ DATABASES = { ## Known issues and limitations - The following `QuerySet` methods aren't supported: - - `annotate()` - `aggregate()` - `dates()` - `datetimes()` diff --git a/django_mongodb/base.py b/django_mongodb/base.py index ffaabd75..92c85cb1 100644 --- a/django_mongodb/base.py +++ b/django_mongodb/base.py @@ -54,7 +54,25 @@ class DatabaseWrapper(BaseDatabaseWrapper): "TimeField": "date", "UUIDField": "string", } + # Django uses these operators to generate SQL queries before it generates + # MQL queries. operators = { + "exact": "= %s", + "iexact": "= UPPER(%s)", + "contains": "LIKE %s", + "icontains": "LIKE UPPER(%s)", + "regex": "~ %s", + "iregex": "~* %s", + "gt": "> %s", + "gte": ">= %s", + "lt": "< %s", + "lte": "<= %s", + "startswith": "LIKE %s", + "endswith": "LIKE %s", + "istartswith": "LIKE UPPER(%s)", + "iendswith": "LIKE UPPER(%s)", + } + mongo_operators = { "exact": lambda val: val, "gt": lambda val: {"$gt": val}, "gte": lambda val: {"$gte": val}, @@ -73,6 +91,13 @@ class DatabaseWrapper(BaseDatabaseWrapper): "regex": lambda val: re.compile(val), "iregex": lambda val: re.compile(val, re.IGNORECASE), } + mongo_aggregations = { + "exact": lambda a, b: {"$eq": [a, b]}, + "gt": lambda a, b: {"$gt": [a, b]}, + "gte": lambda a, b: {"$gte": [a, b]}, + "lt": lambda a, b: {"$lt": [a, b]}, + "lte": lambda a, b: {"$lte": [a, b]}, + } display_name = "MongoDB" vendor = "mongodb" diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 97d620c2..5159c22c 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -1,6 +1,6 @@ from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import DatabaseError, IntegrityError, NotSupportedError -from django.db.models import NOT_PROVIDED, Count, Expression, Value +from django.db.models import NOT_PROVIDED, Count, Expression from django.db.models.aggregates import Aggregate from django.db.models.constants import LOOKUP_SEP from django.db.models.sql import compiler @@ -22,8 +22,11 @@ def execute_sql( # QuerySet.count() if self.query.annotations == {"__count": Count("*")}: return [self.get_count()] + # Specify columns if there are any annotations so that annotations are + # computed via $project. + columns = self.get_columns() if self.query.annotations else None try: - query = self.build_query() + query = self.build_query(columns) except EmptyResultSet: return None return query.fetch() @@ -55,13 +58,18 @@ def results_iter( def has_results(self): return bool(self.get_count(check_exists=True)) - def get_converters(self, columns): + def get_converters(self, expressions): converters = {} - for column in columns: - backend_converters = self.connection.ops.get_db_converters(column) - field_converters = column.field.get_db_converters(self.connection) + for name_expr in expressions: + try: + name, expr = name_expr + except TypeError: + # e.g., Count("*") + continue + backend_converters = self.connection.ops.get_db_converters(expr) + field_converters = expr.get_db_converters(self.connection) if backend_converters or field_converters: - converters[column.target.column] = backend_converters + field_converters + converters[name] = backend_converters + field_converters return converters def _make_result(self, entity, columns, converters, tuple_expected=False): @@ -72,15 +80,14 @@ def _make_result(self, entity, columns, converters, tuple_expected=False): names as keys. """ result = [] - for col in columns: + for name, col in columns: field = col.field - column = col.target.column - value = entity.get(column, NOT_PROVIDED) + value = entity.get(name, NOT_PROVIDED) if value is NOT_PROVIDED: value = field.get_default() elif converters: # Decode values using Django's database converters API. - for converter in converters.get(column, ()): + for converter in converters.get(name, ()): value = converter(value, col, self.connection) result.append(value) if tuple_expected: @@ -91,12 +98,6 @@ def check_query(self): """Check if the current query is supported by the database.""" if self.query.is_empty(): raise EmptyResultSet() - # Supported annotations are Exists() and Count(). - if self.query.annotations and self.query.annotations not in ( - {"a": Value(1)}, - {"__count": Count("*")}, - ): - raise NotSupportedError("QuerySet.annotate() is not supported on MongoDB.") if self.query.distinct: # This is a heuristic to detect QuerySet.datetimes() and dates(). # "datetimefield" and "datefield" are the names of the annotations @@ -144,11 +145,17 @@ def build_query(self, columns=None): return query def get_columns(self): - """Return columns which should be loaded by the query.""" + """ + Return a tuple of (name, expression) with the columns and annotations + which should be loaded by the query. + """ select_mask = self.query.get_select_mask() - return ( + columns = ( self.get_default_columns(select_mask) if self.query.default_cols else self.query.select ) + return tuple((column.target.column, column) for column in columns) + tuple( + self.query.annotations.items() + ) def _get_ordering(self): """ diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index 57139927..99151d40 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -1,9 +1,19 @@ -from django.db.models.expressions import Col +from django.db.models.expressions import Col, Value def col(self, compiler, connection): # noqa: ARG001 - return self.target.column + return f"${self.target.column}" + + +def value(self, compiler, connection): # noqa: ARG001 + return self.value + + +def value_agg(self, compiler, connection): # noqa: ARG001 + return {"$literal": self.value} def register_expressions(): Col.as_mql = col + Value.as_mql = value + Value.as_mql_agg = value_agg diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 1c5226bd..ada7aafa 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -33,13 +33,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): "lookup.tests.LookupTests.test_exact_none_transform", # "Save with update_fields did not affect any rows." "basic.tests.SelectOnSaveTests.test_select_on_save_lying_update", - # filtering on large decimalfield, see https://code.djangoproject.com/ticket/34590 - # for some background. - "model_fields.test_decimalfield.DecimalFieldTests.test_lookup_decimal_larger_than_max_digits", - "model_fields.test_decimalfield.DecimalFieldTests.test_lookup_really_big_value", # 'TruncDate' object has no attribute 'as_mql'. "model_fields.test_datetimefield.DateTimeFieldTests.test_lookup_date_with_use_tz", "model_fields.test_datetimefield.DateTimeFieldTests.test_lookup_date_without_use_tz", + # BaseDatabaseOperations.date_extract_sql() not implemented. + "annotations.tests.AliasTests.test_basic_alias_f_transform_annotation", # Slicing with QuerySet.count() doesn't work. "lookup.tests.LookupTests.test_count", # Lookup in order_by() not supported: @@ -74,6 +72,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): "model_fields.test_uuid.TestQuerying.test_startswith", }, "QuerySet.update() with expression not supported.": { + "annotations.tests.AliasTests.test_update_with_alias", + "annotations.tests.NonAggregateAnnotationTestCase.test_update_with_annotation", "model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values", "timezones.tests.NewDatabaseTests.test_update_with_timedelta", "update.tests.AdvancedTests.test_update_annotated_queryset", @@ -89,6 +89,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): "model_fields.test_autofield.SmallAutoFieldTests", }, "QuerySet.select_related() not supported.": { + "annotations.tests.AliasTests.test_joined_alias_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_joined_annotation", "defer.tests.DeferTests.test_defer_foreign_keys_are_deferred_and_not_traversed", "defer.tests.DeferTests.test_defer_with_select_related", "defer.tests.DeferTests.test_only_with_select_related", @@ -126,39 +128,87 @@ class DatabaseFeatures(BaseDatabaseFeatures): }, # https://github.com/mongodb-labs/django-mongodb/issues/12 "QuerySet.aggregate() not supported.": { + "annotations.tests.AliasTests.test_filter_alias_agg_with_double_f", + "annotations.tests.NonAggregateAnnotationTestCase.test_aggregate_over_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_aggregate_over_full_expression_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_exists_aggregate_values_chaining", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_in_f_grouped_by_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_subquery_and_aggregate_values_chaining", + "annotations.tests.NonAggregateAnnotationTestCase.test_filter_agg_with_double_f", "lookup.tests.LookupQueryingTests.test_aggregate_combined_lookup", "from_db_value.tests.FromDBValueTest.test_aggregation", "timezones.tests.LegacyDatabaseTests.test_query_aggregation", "timezones.tests.NewDatabaseTests.test_query_aggregation", }, - "QuerySet.annotate() not supported.": { - "lookup.test_decimalfield.DecimalFieldLookupTests", + "QuerySet.annotate() has some limitations.": { + # Exists not supported. + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_exists_none_query", "lookup.tests.LookupTests.test_exact_exists", "lookup.tests.LookupTests.test_nested_outerref_lhs", + "lookup.tests.LookupQueryingTests.test_filter_exists_lhs", + # QuerySet.alias() doesn't work. + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_and_alias_filter_in_subquery", "lookup.tests.LookupQueryingTests.test_alias", - "lookup.tests.LookupQueryingTests.test_annotate", - "lookup.tests.LookupQueryingTests.test_annotate_field_greater_than_field", - "lookup.tests.LookupQueryingTests.test_annotate_field_greater_than_literal", - "lookup.tests.LookupQueryingTests.test_annotate_field_greater_than_value", - "lookup.tests.LookupQueryingTests.test_annotate_greater_than_or_equal", - "lookup.tests.LookupQueryingTests.test_annotate_greater_than_or_equal_float", - "lookup.tests.LookupQueryingTests.test_annotate_less_than_float", - "lookup.tests.LookupQueryingTests.test_annotate_literal_greater_than_field", - "lookup.tests.LookupQueryingTests.test_annotate_value_greater_than_value", + # annotate() with combined expressions doesn't work: + # 'WhereNode' object has no attribute 'field' "lookup.tests.LookupQueryingTests.test_combined_annotated_lookups_in_filter", "lookup.tests.LookupQueryingTests.test_combined_annotated_lookups_in_filter_false", "lookup.tests.LookupQueryingTests.test_combined_lookups", + # Case not supported. "lookup.tests.LookupQueryingTests.test_conditional_expression", - "lookup.tests.LookupQueryingTests.test_filter_exists_lhs", + # Using expression in filter() doesn't work. "lookup.tests.LookupQueryingTests.test_filter_lookup_lhs", + # Subquery not supported. + "annotations.tests.NonAggregateAnnotationTestCase.test_empty_queryset_annotation", "lookup.tests.LookupQueryingTests.test_filter_subquery_lhs", + # ExpressionWrapper not supported. + "annotations.tests.NonAggregateAnnotationTestCase.test_combined_expression_annotation_with_aggregation", + "annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation", + "annotations.tests.NonAggregateAnnotationTestCase.test_empty_expression_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_full_expression_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_full_expression_annotation_with_aggregation", + "annotations.tests.NonAggregateAnnotationTestCase.test_grouping_by_q_expression_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_mixed_type_annotation_numbers", + "annotations.tests.NonAggregateAnnotationTestCase.test_q_expression_annotation_with_aggregation", "lookup.tests.LookupQueryingTests.test_filter_wrapped_lookup_lhs", + # Length not implemented. + "annotations.tests.NonAggregateAnnotationTestCase.test_chaining_transforms", + # CombinedExpression not implemented. + "annotations.tests.NonAggregateAnnotationTestCase.test_combined_annotation_commutative", + "annotations.tests.NonAggregateAnnotationTestCase.test_decimal_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_defer_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_filter_decimal_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_values_annotation", + # Func not implemented. + "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions", + "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions", + # Floor not implemented. + "annotations.tests.NonAggregateAnnotationTestCase.test_custom_transform_annotation", + # Coalesce not implemented. + "annotations.tests.AliasTests.test_alias_annotation_expression", + "annotations.tests.NonAggregateAnnotationTestCase.test_full_expression_wrapped_annotation", + # BaseDatabaseOperations may require a datetime_extract_sql(). + "annotations.tests.NonAggregateAnnotationTestCase.test_joined_transformed_annotation", + # BaseDatabaseOperations may require a format_for_duration_arithmetic(). + "annotations.tests.NonAggregateAnnotationTestCase.test_mixed_type_annotation_date_interval", + # FieldDoesNotExist with ordering. + "annotations.tests.AliasTests.test_order_by_alias", + "annotations.tests.NonAggregateAnnotationTestCase.test_order_by_aggregate", + "annotations.tests.NonAggregateAnnotationTestCase.test_order_by_annotation", + }, + "Count doesn't work in QuerySet.annotate()": { + "annotations.tests.AliasTests.test_alias_annotate_with_aggregation", + "annotations.tests.AliasTests.test_order_by_alias_aggregate", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotate_exists", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotate_with_aggregation", }, "QuerySet.dates() is not supported on MongoDB.": { + "annotations.tests.AliasTests.test_dates_alias", "dates.tests.DatesTests.test_dates_trunc_datetime_fields", "dates.tests.DatesTests.test_related_model_traverse", }, "QuerySet.datetimes() is not supported on MongoDB.": { + "annotations.tests.AliasTests.test_datetimes_alias", "datetimes.tests.DateTimesTests.test_21432", "datetimes.tests.DateTimesTests.test_datetimes_has_lazy_iterator", "datetimes.tests.DateTimesTests.test_datetimes_returns_available_dates_for_given_scope_and_given_field", @@ -171,6 +221,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): "update.tests.AdvancedTests.test_update_all", }, "QuerySet.extra() is not supported.": { + "annotations.tests.NonAggregateAnnotationTestCase.test_column_field_ordering", + "annotations.tests.NonAggregateAnnotationTestCase.test_column_field_ordering_with_deferred", "basic.tests.ModelTest.test_extra_method_select_argument_with_dashes", "basic.tests.ModelTest.test_extra_method_select_argument_with_dashes_and_values", "defer.tests.DeferTests.test_defer_extra", @@ -178,6 +230,16 @@ class DatabaseFeatures(BaseDatabaseFeatures): "lookup.tests.LookupTests.test_values_list", }, "Queries with multiple tables are not supported.": { + "annotations.tests.AliasTests.test_alias_default_alias_expression", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_and_alias_filter_related_in_subquery", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_filter_with_subquery", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_reverse_m2m", + "annotations.tests.NonAggregateAnnotationTestCase.test_mti_annotations", + "annotations.tests.NonAggregateAnnotationTestCase.test_values_with_pk_annotation", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_subquery_outerref_transform", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_with_m2m", + "annotations.tests.NonAggregateAnnotationTestCase.test_chaining_annotation_filter_with_m2m", "defer.tests.BigChildDeferTests.test_defer_baseclass_when_subclass_has_added_field", "defer.tests.BigChildDeferTests.test_defer_subclass", "defer.tests.BigChildDeferTests.test_defer_subclass_both", @@ -188,6 +250,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): "defer.tests.DeferTests.test_only_baseclass_when_subclass_has_no_added_fields", "defer.tests.TestDefer2.test_defer_inheritance_pk_chaining", "defer_regress.tests.DeferRegressionTest.test_ticket_16409", + "lookup.test_decimalfield.DecimalFieldLookupTests", "lookup.tests.LookupQueryingTests.test_multivalued_join_reuse", "lookup.tests.LookupTests.test_filter_by_reverse_related_field_transform", "lookup.tests.LookupTests.test_lookup_collision", @@ -211,6 +274,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): "lookup.tests.LookupTests.test_textfield_exact_null", }, "Test executes raw SQL.": { + "annotations.tests.NonAggregateAnnotationTestCase.test_raw_sql_with_inherited_field", "timezones.tests.LegacyDatabaseTests.test_cursor_execute_accepts_naive_datetime", "timezones.tests.LegacyDatabaseTests.test_cursor_execute_returns_naive_datetime", "timezones.tests.LegacyDatabaseTests.test_raw_sql", diff --git a/django_mongodb/functions.py b/django_mongodb/functions.py index f9595257..2c188edf 100644 --- a/django_mongodb/functions.py +++ b/django_mongodb/functions.py @@ -1,5 +1,4 @@ from django.db import NotSupportedError -from django.db.models.expressions import Col from django.db.models.functions.datetime import Extract from .query_utils import process_lhs @@ -15,8 +14,6 @@ def extract(self, compiler, connection): operator = "$year" else: raise NotSupportedError("%s is not supported." % self.__class__.__name__) - if isinstance(self.lhs, Col): - lhs_mql = f"${lhs_mql}" return {operator: lhs_mql} diff --git a/django_mongodb/lookups.py b/django_mongodb/lookups.py index a7cab6a8..33cc6b26 100644 --- a/django_mongodb/lookups.py +++ b/django_mongodb/lookups.py @@ -1,5 +1,4 @@ from django.db import NotSupportedError -from django.db.models.expressions import Col from django.db.models.fields.related_lookups import In, MultiColSource, RelatedIn from django.db.models.lookups import BuiltinLookup, Exact, IsNull, UUIDTextMixin @@ -7,17 +6,21 @@ def builtin_lookup(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True) value = process_rhs(self, compiler, connection) - rhs_mql = connection.operators[self.lookup_name](value) + rhs_mql = connection.mongo_operators[self.lookup_name](value) return {lhs_mql: rhs_mql} +def builtin_lookup_agg(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + return connection.mongo_aggregations[self.lookup_name](lhs_mql, value) + + def exact(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) value = process_rhs(self, compiler, connection) - if isinstance(self.lhs, Col): - lhs_mql = f"${lhs_mql}" return {"$expr": {"$eq": [lhs_mql, value]}} @@ -30,8 +33,8 @@ def in_(self, compiler, connection): def is_null(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") - lhs_mql = process_lhs(self, compiler, connection) - rhs_mql = connection.operators["isnull"](self.rhs) + lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True) + rhs_mql = connection.mongo_operators["isnull"](self.rhs) return {lhs_mql: rhs_mql} @@ -41,6 +44,7 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001 def register_lookups(): BuiltinLookup.as_mql = builtin_lookup + BuiltinLookup.as_mql_agg = builtin_lookup_agg Exact.as_mql = exact In.as_mql = RelatedIn.as_mql = in_ IsNull.as_mql = is_null diff --git a/django_mongodb/operations.py b/django_mongodb/operations.py index 3768584f..5bd71b93 100644 --- a/django_mongodb/operations.py +++ b/django_mongodb/operations.py @@ -1,7 +1,7 @@ import datetime -import decimal import uuid +from bson.decimal128 import Decimal128 from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations from django.utils import timezone @@ -21,6 +21,10 @@ def adapt_datetimefield_value(self, value): value = timezone.make_aware(value) return value + def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None): + """Store DecimalField as Decimal128.""" + return Decimal128(value) + def adapt_timefield_value(self, value): """Store TimeField as datetime.""" if value is None: @@ -56,7 +60,8 @@ def convert_datetimefield_value(self, value, expression, connection): def convert_decimalfield_value(self, value, expression, connection): if value is not None: - value = decimal.Decimal(value) + # from Decimal128 to decimal.Decimal() + value = value.to_decimal() return value def convert_timefield_value(self, value, expression, connection): diff --git a/django_mongodb/query.py b/django_mongodb/query.py index 84d38d28..5b18c5b4 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -85,7 +85,17 @@ def delete(self): def get_cursor(self): if self.query.low_mark == self.query.high_mark: return [] - fields = {col.target.column: 1 for col in self.columns} if self.columns else None + fields = {} + for name, expr in self.columns or []: + try: + column = expr.target.column + except AttributeError: + # Generate the MQL for an annotation. + fields[name] = expr.as_mql_agg(self.compiler, self.connection) + else: + # If name != column, then this is an annotatation referencing + # another column. + fields[name] = 1 if name == column else f"${column}" pipeline = [] if self.mongo_query: pipeline.append({"$match": self.mongo_query}) diff --git a/django_mongodb/query_utils.py b/django_mongodb/query_utils.py index ba1dcc68..c6761571 100644 --- a/django_mongodb/query_utils.py +++ b/django_mongodb/query_utils.py @@ -5,13 +5,20 @@ def is_direct_value(node): return not hasattr(node, "as_sql") -def process_lhs(node, compiler, connection): +def process_lhs(node, compiler, connection, bare_column_ref=False): if is_direct_value(node.lhs): return node - return node.lhs.as_mql(compiler, connection) + mql = node.lhs.as_mql(compiler, connection) + # Remove the unneeded $ from column references. + if bare_column_ref and mql.startswith("$"): + mql = mql[1:] + return mql def process_rhs(node, compiler, connection): + rhs = node.rhs + if hasattr(rhs, "as_mql"): + return rhs.as_mql(compiler, connection) _, value = node.process_rhs(compiler, connection) lookup_name = node.lookup_name # Undo Lookup.get_db_prep_lookup() putting params in a list.