Skip to content

Commit ccd7c5f

Browse files
committed
add support for math database functions
1 parent db3b2d0 commit ccd7c5f

File tree

5 files changed

+79
-0
lines changed

5 files changed

+79
-0
lines changed

.github/workflows/test-python.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ jobs:
7474
bulk_create
7575
dates
7676
datetimes
77+
db_functions.math
7778
empty
7879
defer
7980
defer_regress

django_mongodb/features.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,20 @@ class DatabaseFeatures(BaseDatabaseFeatures):
5454
"lookup.tests.LookupTests.test_pattern_lookups_with_substr",
5555
# Querying ObjectID with string doesn't work.
5656
"lookup.tests.LookupTests.test_lookup_int_as_str",
57+
# MongoDB gives the wrong result of log(number, base) when base is a
58+
# fractional Decimal: https://jira.mongodb.org/browse/SERVER-91223
59+
"db_functions.math.test_log.LogTests.test_decimal",
60+
# MongoDB gives ROUND(365, -1)=360 instead of 370 like other databases.
61+
"db_functions.math.test_round.RoundTests.test_integer_with_negative_precision",
5762
}
5863

5964
django_test_skips = {
6065
"Insert expressions aren't supported.": {
6166
"bulk_create.tests.BulkCreateTests.test_bulk_insert_now",
6267
"bulk_create.tests.BulkCreateTests.test_bulk_insert_expressions",
68+
# PI()
69+
"db_functions.math.test_round.RoundTests.test_decimal_with_precision",
70+
"db_functions.math.test_round.RoundTests.test_float_with_precision",
6371
},
6472
"Pattern lookups on UUIDField are not supported.": {
6573
"model_fields.test_uuid.TestQuerying.test_contains",
@@ -292,4 +300,28 @@ class DatabaseFeatures(BaseDatabaseFeatures):
292300
"timezones.tests.NewDatabaseTests.test_aware_datetime_in_local_timezone_with_microsecond",
293301
"timezones.tests.NewDatabaseTests.test_naive_datetime_with_microsecond",
294302
},
303+
"Transform not supported.": {
304+
"db_functions.math.test_abs.AbsTests.test_transform",
305+
"db_functions.math.test_acos.ACosTests.test_transform",
306+
"db_functions.math.test_asin.ASinTests.test_transform",
307+
"db_functions.math.test_atan.ATanTests.test_transform",
308+
"db_functions.math.test_ceil.CeilTests.test_transform",
309+
"db_functions.math.test_cos.CosTests.test_transform",
310+
"db_functions.math.test_cot.CotTests.test_transform",
311+
"db_functions.math.test_degrees.DegreesTests.test_transform",
312+
"db_functions.math.test_exp.ExpTests.test_transform",
313+
"db_functions.math.test_floor.FloorTests.test_transform",
314+
"db_functions.math.test_ln.LnTests.test_transform",
315+
"db_functions.math.test_radians.RadiansTests.test_transform",
316+
"db_functions.math.test_round.RoundTests.test_transform",
317+
"db_functions.math.test_sin.SinTests.test_transform",
318+
"db_functions.math.test_sqrt.SqrtTests.test_transform",
319+
"db_functions.math.test_tan.TanTests.test_transform",
320+
},
321+
"MongoDB does not support Sign.": {
322+
"db_functions.math.test_sign.SignTests",
323+
},
324+
"MongoDB can't annotate ($project) a function like PI().": {
325+
"db_functions.math.test_pi.PiTests.test",
326+
},
295327
}

django_mongodb/functions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
11
from django.db import NotSupportedError
2+
from django.db.models.expressions import Func
23
from django.db.models.functions.datetime import Extract
4+
from django.db.models.functions.math import Ceil, Cot, Degrees, Log, Power, Radians, Random, Round
5+
from django.db.models.functions.text import Upper
36

47
from .query_utils import process_lhs
58

9+
MONGO_OPERATORS = {
10+
Ceil: "ceil",
11+
Degrees: "radiansToDegrees",
12+
Power: "pow",
13+
Radians: "degreesToRadians",
14+
Random: "rand",
15+
Upper: "toUpper",
16+
}
17+
18+
19+
def cot(self, compiler, connection):
20+
lhs_mql = process_lhs(self, compiler, connection)
21+
return {"$divide": [1, {"$tan": lhs_mql}]}
22+
623

724
def extract(self, compiler, connection):
825
lhs_mql = process_lhs(self, compiler, connection)
@@ -17,5 +34,28 @@ def extract(self, compiler, connection):
1734
return {operator: lhs_mql}
1835

1936

37+
def func(self, compiler, connection):
38+
lhs_mql = process_lhs(self, compiler, connection)
39+
operator = MONGO_OPERATORS.get(self.__class__, self.function.lower())
40+
return {f"${operator}": lhs_mql}
41+
42+
43+
def log(self, compiler, connection):
44+
# This function is usually log(base, num) but on MongoDB it's log(num, base).
45+
clone = self.copy()
46+
clone.set_source_expressions(self.get_source_expressions()[::-1])
47+
return func(clone, compiler, connection)
48+
49+
50+
def round_(self, compiler, connection):
51+
# Round needs its own function because it's a special case that inherits
52+
# from Transform but has two arguments.
53+
return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]}
54+
55+
2056
def register_functions():
57+
Cot.as_mql_agg = cot
2158
Extract.as_mql = extract
59+
Func.as_mql_agg = func
60+
Log.as_mql_agg = log
61+
Round.as_mql_agg = round_

django_mongodb/operations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def adapt_datetimefield_value(self, value):
2323

2424
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
2525
"""Store DecimalField as Decimal128."""
26+
if value is None:
27+
return None
2628
return Decimal128(value)
2729

2830
def adapt_timefield_value(self, value):

django_mongodb/query_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ def is_direct_value(node):
66

77

88
def process_lhs(node, compiler, connection, bare_column_ref=False):
9+
if not hasattr(node, "lhs"):
10+
# node is a Func or Expression, possibly with multiple source expressions.
11+
return [expr.as_mql(compiler, connection) for expr in node.get_source_expressions()]
12+
# node is a Transform with just one source expression, aliased as "lhs".
913
if is_direct_value(node.lhs):
1014
return node
1115
mql = node.lhs.as_mql(compiler, connection)

0 commit comments

Comments
 (0)