diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 23989e0d..3f8090e9 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -74,6 +74,7 @@ jobs: bulk_create dates datetimes + db_functions.math empty defer defer_regress diff --git a/django_mongodb/features.py b/django_mongodb/features.py index ada7aafa..08c7cbe2 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -54,12 +54,20 @@ class DatabaseFeatures(BaseDatabaseFeatures): "lookup.tests.LookupTests.test_pattern_lookups_with_substr", # Querying ObjectID with string doesn't work. "lookup.tests.LookupTests.test_lookup_int_as_str", + # MongoDB gives the wrong result of log(number, base) when base is a + # fractional Decimal: https://jira.mongodb.org/browse/SERVER-91223 + "db_functions.math.test_log.LogTests.test_decimal", + # MongoDB gives ROUND(365, -1)=360 instead of 370 like other databases. + "db_functions.math.test_round.RoundTests.test_integer_with_negative_precision", } django_test_skips = { "Insert expressions aren't supported.": { "bulk_create.tests.BulkCreateTests.test_bulk_insert_now", "bulk_create.tests.BulkCreateTests.test_bulk_insert_expressions", + # PI() + "db_functions.math.test_round.RoundTests.test_decimal_with_precision", + "db_functions.math.test_round.RoundTests.test_float_with_precision", }, "Pattern lookups on UUIDField are not supported.": { "model_fields.test_uuid.TestQuerying.test_contains", @@ -292,4 +300,28 @@ class DatabaseFeatures(BaseDatabaseFeatures): "timezones.tests.NewDatabaseTests.test_aware_datetime_in_local_timezone_with_microsecond", "timezones.tests.NewDatabaseTests.test_naive_datetime_with_microsecond", }, + "Transform not supported.": { + "db_functions.math.test_abs.AbsTests.test_transform", + "db_functions.math.test_acos.ACosTests.test_transform", + "db_functions.math.test_asin.ASinTests.test_transform", + "db_functions.math.test_atan.ATanTests.test_transform", + "db_functions.math.test_ceil.CeilTests.test_transform", + "db_functions.math.test_cos.CosTests.test_transform", + "db_functions.math.test_cot.CotTests.test_transform", + "db_functions.math.test_degrees.DegreesTests.test_transform", + "db_functions.math.test_exp.ExpTests.test_transform", + "db_functions.math.test_floor.FloorTests.test_transform", + "db_functions.math.test_ln.LnTests.test_transform", + "db_functions.math.test_radians.RadiansTests.test_transform", + "db_functions.math.test_round.RoundTests.test_transform", + "db_functions.math.test_sin.SinTests.test_transform", + "db_functions.math.test_sqrt.SqrtTests.test_transform", + "db_functions.math.test_tan.TanTests.test_transform", + }, + "MongoDB does not support Sign.": { + "db_functions.math.test_sign.SignTests", + }, + "MongoDB can't annotate ($project) a function like PI().": { + "db_functions.math.test_pi.PiTests.test", + }, } diff --git a/django_mongodb/functions.py b/django_mongodb/functions.py index 2c188edf..303c21eb 100644 --- a/django_mongodb/functions.py +++ b/django_mongodb/functions.py @@ -1,8 +1,25 @@ from django.db import NotSupportedError +from django.db.models.expressions import Func from django.db.models.functions.datetime import Extract +from django.db.models.functions.math import Ceil, Cot, Degrees, Log, Power, Radians, Random, Round +from django.db.models.functions.text import Upper from .query_utils import process_lhs +MONGO_OPERATORS = { + Ceil: "ceil", + Degrees: "radiansToDegrees", + Power: "pow", + Radians: "degreesToRadians", + Random: "rand", + Upper: "toUpper", +} + + +def cot(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + return {"$divide": [1, {"$tan": lhs_mql}]} + def extract(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) @@ -17,5 +34,28 @@ def extract(self, compiler, connection): return {operator: lhs_mql} +def func(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) + return {f"${operator}": lhs_mql} + + +def log(self, compiler, connection): + # This function is usually log(base, num) but on MongoDB it's log(num, base). + clone = self.copy() + clone.set_source_expressions(self.get_source_expressions()[::-1]) + return func(clone, compiler, connection) + + +def round_(self, compiler, connection): + # Round needs its own function because it's a special case that inherits + # from Transform but has two arguments. + return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]} + + def register_functions(): + Cot.as_mql_agg = cot Extract.as_mql = extract + Func.as_mql_agg = func + Log.as_mql_agg = log + Round.as_mql_agg = round_ diff --git a/django_mongodb/operations.py b/django_mongodb/operations.py index 5bd71b93..4b210321 100644 --- a/django_mongodb/operations.py +++ b/django_mongodb/operations.py @@ -23,6 +23,8 @@ def adapt_datetimefield_value(self, value): def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None): """Store DecimalField as Decimal128.""" + if value is None: + return None return Decimal128(value) def adapt_timefield_value(self, value): diff --git a/django_mongodb/query_utils.py b/django_mongodb/query_utils.py index c6761571..204d076a 100644 --- a/django_mongodb/query_utils.py +++ b/django_mongodb/query_utils.py @@ -6,6 +6,10 @@ def is_direct_value(node): def process_lhs(node, compiler, connection, bare_column_ref=False): + if not hasattr(node, "lhs"): + # node is a Func or Expression, possibly with multiple source expressions. + return [expr.as_mql(compiler, connection) for expr in node.get_source_expressions()] + # node is a Transform with just one source expression, aliased as "lhs". if is_direct_value(node.lhs): return node mql = node.lhs.as_mql(compiler, connection)