Skip to content

Commit 9877c83

Browse files
committed
Refactor jsonfield as $expr.
1 parent e8d257a commit 9877c83

File tree

10 files changed

+181
-162
lines changed

10 files changed

+181
-162
lines changed

django_mongodb/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
check_django_compatability()
88

99
from .expressions import register_expressions # noqa: E402
10-
from .fields import load_fields # noqa: E402
10+
from .fields import register_fields # noqa: E402
1111
from .functions import register_functions # noqa: E402
1212
from .lookups import register_lookups # noqa: E402
1313
from .query import register_nodes # noqa: E402
1414

1515
register_expressions()
16+
register_fields()
1617
register_functions()
1718
register_lookups()
1819
register_nodes()
19-
load_fields()

django_mongodb/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
4444
"IntegerField": "int",
4545
"BigIntegerField": "long",
4646
"GenericIPAddressField": "string",
47+
"JSONField": "object",
4748
"OneToOneField": "int",
4849
"PositiveBigIntegerField": "int",
4950
"PositiveIntegerField": "long",

django_mongodb/expressions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1-
from django.db.models.expressions import Col, Value
1+
from django.db.models.expressions import Col, ExpressionWrapper, Value
22

33

44
def col(self, compiler, connection): # noqa: ARG001
55
return f"${self.target.column}"
66

77

8+
def expression_wrapper(self, compiler, connection):
9+
return self.expression.as_mql(compiler, connection)
10+
11+
812
def value(self, compiler, connection): # noqa: ARG001
913
return {"$literal": self.value}
1014

1115

1216
def register_expressions():
1317
Col.as_mql = col
18+
ExpressionWrapper.as_mql = expression_wrapper
1419
Value.as_mql = value

django_mongodb/features.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,11 +386,12 @@ class DatabaseFeatures(BaseDatabaseFeatures):
386386
"model_fields.test_jsonfield.JSONFieldTests.test_db_check_constraints",
387387
},
388388
"Mongodb's Null behaviour is different from sql's": {
389-
"model_fields.test_jsonfield.TestQuerying.test_none_key_and_exact_lookup",
389+
"model_fields.test_jsonfield.TestQuerying.test_expression_wrapper_key_transform",
390390
"model_fields.test_jsonfield.TestSaveLoad.test_json_null_different_from_sql_null",
391-
"model_fields.test_jsonfield.TestQuerying.test_none_key",
392391
"model_fields.test_jsonfield.TestQuerying.test_lookup_exclude",
393392
"model_fields.test_jsonfield.TestQuerying.test_lookup_exclude_nonexistent_key",
393+
"model_fields.test_jsonfield.TestQuerying.test_none_key",
394+
"model_fields.test_jsonfield.TestQuerying.test_none_key_exclude",
394395
},
395396
"Pipeline filtering": {"model_fields.test_jsonfield.TestQuerying.test_icontains"},
396397
}

django_mongodb/fields/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .auto import MongoAutoField
2+
from .json_field import register_fields
23

3-
__all__ = ["MongoAutoField"]
4+
__all__ = ["MongoAutoField", "register_fields"]

django_mongodb/fields/auto.py

Lines changed: 0 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,8 @@
11
from bson import ObjectId, errors
22
from django.core import exceptions
3-
from django.db import NotSupportedError
4-
from django.db.models import JSONField
53
from django.db.models.fields import AutoField, Field
6-
from django.db.models.fields.json import (
7-
ContainedBy,
8-
DataContains,
9-
HasAnyKeys,
10-
HasKey,
11-
HasKeyLookup,
12-
HasKeys,
13-
JSONExact,
14-
KeyTransform,
15-
KeyTransformIn,
16-
KeyTransformIsNull,
17-
)
184
from django.utils.translation import gettext_lazy as _
195

20-
from .base import DatabaseWrapper
21-
from .query_utils import process_lhs
22-
236

247
class MongoAutoField(AutoField):
258
default_error_messages = {
@@ -49,130 +32,3 @@ def to_python(self, value):
4932
code="invalid",
5033
params={"value": value},
5134
) from None
52-
53-
54-
_from_db_value = JSONField.from_db_value
55-
56-
57-
def from_db_value(self, value, expression, connection):
58-
return (
59-
value
60-
if isinstance(connection, DatabaseWrapper)
61-
else _from_db_value(self, value, expression, connection)
62-
)
63-
64-
65-
def json_process_rhs(node, compiler, connection):
66-
rhs = node.rhs
67-
if hasattr(rhs, "as_mql"):
68-
return rhs.as_mql(compiler, connection)
69-
_, value = node.process_rhs(compiler, connection)
70-
lookup_name = node.lookup_name
71-
if lookup_name not in ("in", "range"):
72-
value = value[0] if len(value) > 0 else []
73-
74-
return value
75-
76-
77-
def key_transform(self, compiler, connection):
78-
key_transforms = [self.key_name]
79-
previous = self.lhs
80-
while isinstance(previous, KeyTransform):
81-
key_transforms.insert(0, previous.key_name)
82-
previous = previous.lhs
83-
lhs_mql = previous.as_mql(compiler, connection)
84-
return ".".join([lhs_mql, *key_transforms])
85-
86-
87-
def data_contains(self, compiler, connection): # noqa: ARG001
88-
raise NotSupportedError("contains lookup is not supported on this database backend.")
89-
90-
91-
def contained_by(self, compiler, connection): # noqa: ARG001
92-
raise NotSupportedError("contained_by lookup is not supported on this database backend.")
93-
94-
95-
def json_exact(self, compiler, connection):
96-
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
97-
rhs_mql = json_process_rhs(self, compiler, connection)
98-
if rhs_mql == "null":
99-
return {"$or": [{lhs_mql: {"$eq": None}}, {lhs_mql: {"$exists": False}}]}
100-
return {lhs_mql: {"$eq": rhs_mql, "$exists": True}}
101-
102-
103-
def key_transform_isnull(self, compiler, connection):
104-
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
105-
rhs_mql = json_process_rhs(self, compiler, connection)
106-
107-
# https://code.djangoproject.com/ticket/32252
108-
return {lhs_mql: {"$exists": not rhs_mql}}
109-
110-
111-
def key_transform_in(self, compiler, connection):
112-
lhs_mql = key_transform_agg(self.lhs, compiler, connection)
113-
bare_lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
114-
value = json_process_rhs(self, compiler, connection)
115-
expr = connection.mongo_aggregations[self.lookup_name](lhs_mql, value)
116-
return {"$expr": expr, bare_lhs_mql: {"$exists": True}}
117-
118-
119-
def has_key_lookup(self, compiler, connection):
120-
lhs = process_lhs(self, compiler, connection, bare_column_ref=True)
121-
rhs = self.rhs
122-
if not isinstance(rhs, (list | tuple)):
123-
rhs = [rhs]
124-
paths = []
125-
for key in rhs:
126-
if isinstance(key, KeyTransform):
127-
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
128-
rhs_key_transforms = ".".join(rhs_key_transforms)
129-
else:
130-
rhs_key_transforms = str(key)
131-
rhs_json_path = f"{lhs}.{rhs_key_transforms}"
132-
paths.append(rhs_json_path)
133-
134-
keys = []
135-
for path in paths:
136-
keys.append({path: {"$exists": True}})
137-
if self.mongo_operator is None:
138-
assert len(keys) == 1
139-
return keys[0]
140-
return {self.mongo_operator: keys}
141-
142-
143-
def key_transform_agg(self, compiler, connection):
144-
key_transforms = [self.key_name]
145-
previous = self.lhs
146-
while isinstance(previous, KeyTransform):
147-
key_transforms.insert(0, previous.key_name)
148-
previous = previous.lhs
149-
lhs_mql = previous.as_mql(compiler, connection)
150-
result = f"{lhs_mql}"
151-
for key in key_transforms:
152-
get_field = {"$getField": {"input": result, "field": key}}
153-
if key.isdigit():
154-
result = {
155-
"$cond": {
156-
"if": {"$isArray": result},
157-
"then": {"$arrayElemAt": [result, int(key)]},
158-
"else": get_field,
159-
}
160-
}
161-
else:
162-
result = get_field
163-
return result
164-
165-
166-
def load_fields():
167-
JSONField.from_db_value = from_db_value
168-
DataContains.as_mql = data_contains
169-
KeyTransform.as_mql = key_transform
170-
KeyTransform.as_mql_agg = key_transform_agg
171-
JSONExact.as_mql = json_exact
172-
ContainedBy.as_mql = contained_by
173-
HasKeyLookup.as_mql = has_key_lookup
174-
HasAnyKeys.mongo_operator = "$or"
175-
HasKey.mongo_operator = None
176-
HasKeys.mongo_operator = "$and"
177-
KeyTransformIsNull.as_mql = key_transform_isnull
178-
KeyTransformIn.as_mql = key_transform_in

django_mongodb/fields/json_field.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from django.db import NotSupportedError
2+
from django.db.models import JSONField
3+
from django.db.models.fields.json import (
4+
ContainedBy,
5+
DataContains,
6+
HasAnyKeys,
7+
HasKey,
8+
HasKeyLookup,
9+
HasKeys,
10+
JSONExact,
11+
KeyTransform,
12+
KeyTransformIn,
13+
KeyTransformIsNull,
14+
KeyTransformNumericLookupMixin,
15+
)
16+
17+
from django_mongodb.lookups import builtin_lookup
18+
from django_mongodb.query_utils import process_lhs, process_rhs
19+
20+
21+
def _key_transform_root(self, compiler, connection):
22+
previous = self.lhs
23+
while isinstance(previous, KeyTransform):
24+
previous = previous.lhs
25+
return previous.as_mql(compiler, connection)
26+
27+
28+
def contained_by(self, compiler, connection): # noqa: ARG001
29+
raise NotSupportedError("contained_by lookup is not supported on this database backend.")
30+
31+
32+
def data_contains(self, compiler, connection): # noqa: ARG001
33+
raise NotSupportedError("contains lookup is not supported on this database backend.")
34+
35+
36+
_from_db_value = JSONField.from_db_value
37+
38+
39+
def from_db_value(self, value, expression, connection):
40+
"""
41+
Mongodb does not need to change the json value. It is store as it is.
42+
"""
43+
return (
44+
value
45+
if connection.vendor == "mongodb"
46+
else _from_db_value(self, value, expression, connection)
47+
)
48+
49+
50+
def has_key_lookup(self, compiler, connection):
51+
rhs = self.rhs
52+
lhs = process_lhs(self, compiler, connection)
53+
if not isinstance(rhs, (list | tuple)):
54+
rhs = [rhs]
55+
paths = []
56+
for key in rhs:
57+
if isinstance(key, KeyTransform):
58+
rhs_json_path = key.as_mql(compiler, connection)
59+
else:
60+
rhs_json_path = KeyTransform(key, self.lhs).as_mql(compiler, connection)
61+
paths.append(rhs_json_path)
62+
63+
keys = []
64+
for path in paths:
65+
keys.append({"$and": [{"$ne": [{"$type": path}, "missing"]}, {"$ne": [lhs, None]}]})
66+
if self.mongo_operator is None:
67+
assert len(keys) == 1
68+
return keys[0]
69+
return {self.mongo_operator: keys}
70+
71+
72+
__process_rhs = JSONExact.process_rhs
73+
74+
75+
def json_process_rhs(self, compiler, connection):
76+
"""
77+
Django does some transformation over the parameters if it is null.
78+
to avoid ambiguity, we keep using the parent process_rhs.
79+
"""
80+
return (
81+
super(JSONExact, self).process_rhs(compiler, connection)
82+
if connection.vendor == "mongodb"
83+
else __process_rhs(self, compiler, connection)
84+
)
85+
86+
87+
def key_transform(self, compiler, connection):
88+
key_transforms = [self.key_name]
89+
previous = self.lhs
90+
while isinstance(previous, KeyTransform):
91+
key_transforms.insert(0, previous.key_name)
92+
previous = previous.lhs
93+
lhs_mql = previous.as_mql(compiler, connection)
94+
result = lhs_mql
95+
for key in key_transforms:
96+
get_field = {"$getField": {"input": result, "field": key}}
97+
if key.isdigit() and str(int(key)) == key:
98+
result = {
99+
"$cond": {
100+
"if": {"$isArray": result},
101+
"then": {"$arrayElemAt": [result, int(key)]},
102+
"else": get_field,
103+
}
104+
}
105+
else:
106+
result = get_field
107+
return result
108+
109+
110+
def key_transform_in(self, compiler, connection):
111+
lhs_mql = process_lhs(self, compiler, connection)
112+
value = process_rhs(self, compiler, connection)
113+
expr = connection.mongo_operators[self.lookup_name](lhs_mql, value)
114+
return {"$and": [expr, {"$not": {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}}]}
115+
116+
117+
def key_transform_isnull(self, compiler, connection):
118+
"""
119+
The KeyTransformIsNull lookup borrows the logic from HasKey for isnull=False.
120+
If isnull=True, the query should only match objects that do not have the key.
121+
# https://code.djangoproject.com/ticket/32252
122+
"""
123+
lhs_mql = process_lhs(self, compiler, connection)
124+
rhs_mql = process_rhs(self, compiler, connection)
125+
root_column = _key_transform_root(self, compiler, connection)
126+
127+
result = {"$or": [{"$in": [{"$type": lhs_mql}, ["missing"]]}, {"$eq": [root_column, None]}]}
128+
if not rhs_mql:
129+
result = {"$not": result}
130+
return result
131+
132+
133+
def key_transform_numeric_lookup_mixin(self, compiler, connection):
134+
expr = builtin_lookup(self, compiler, connection)
135+
lhs = process_lhs(self, compiler, connection)
136+
return {"$and": [expr, {"$not": {"$in": [{"$type": lhs}, ["missing", "null"]]}}]}
137+
138+
139+
def register_fields():
140+
ContainedBy.as_mql = contained_by
141+
DataContains.as_mql = data_contains
142+
HasAnyKeys.mongo_operator = "$or"
143+
HasKey.mongo_operator = None
144+
HasKeyLookup.as_mql = has_key_lookup
145+
HasKeys.mongo_operator = "$and"
146+
JSONExact.process_rhs = json_process_rhs
147+
JSONField.from_db_value = from_db_value
148+
KeyTransform.as_mql = key_transform
149+
KeyTransformIn.as_mql = key_transform_in
150+
KeyTransformIsNull.as_mql = key_transform_isnull
151+
KeyTransformNumericLookupMixin.as_mql = key_transform_numeric_lookup_mixin

django_mongodb/functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def cast(self, compiler, connection):
6464
lhs_mql = process_lhs(self, compiler, connection)[0]
6565
if max_length := self.output_field.max_length:
6666
lhs_mql = {"$substrCP": [lhs_mql, 0, max_length]}
67-
lhs_mql = {"$convert": {"input": lhs_mql, "to": output_type}}
67+
if output_type == "object":
68+
lhs_mql = {"$convert": {"input": lhs_mql, "to": output_type, "onError": lhs_mql}}
69+
else:
70+
lhs_mql = {"$convert": {"input": lhs_mql, "to": output_type}}
6871
if decimal_places := getattr(self.output_field, "decimal_places", None):
6972
lhs_mql = {"$trunc": [lhs_mql, decimal_places]}
7073
return lhs_mql

django_mongodb/lookups.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from django.db import NotSupportedError
22
from django.db.models.fields.related_lookups import In, MultiColSource, RelatedIn
3-
from django.db.models.lookups import BuiltinLookup, IsNull, UUIDTextMixin
3+
from django.db.models.lookups import (
4+
BuiltinLookup,
5+
FieldGetDbPrepValueIterableMixin,
6+
IsNull,
7+
UUIDTextMixin,
8+
)
49

510
from .query_utils import process_lhs, process_rhs
611

0 commit comments

Comments
 (0)