diff --git a/README.md b/README.md index d0826df9..0be92b04 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,17 @@ Migrations for 'admin': - The `tzinfo` parameter of the `Trunc` database functions doesn't work properly because MongoDB converts the result back to UTC. +- When querying `JSONField`: + - There is no way to distinguish between a JSON "null" (represented by + `Value(None, JSONField())`) and a SQL null (queried using the `isnull` + lookup). Both of these queries return both of these nulls. + - Some queries with `Q` objects, e.g. `Q(value__foo="bar")`, don't work + properly, particularly with `QuerySet.exclude()`. + - Filtering for a `None` key, e.g. `QuerySet.filter(value__j=None)` + incorrectly returns objects where the key doesn't exist. + - You can study the skipped tests in `DatabaseFeatures.django_test_skips` for + more details on known issues. + ## Troubleshooting TODO diff --git a/django_mongodb/__init__.py b/django_mongodb/__init__.py index 5dea9ab2..a2dba98e 100644 --- a/django_mongodb/__init__.py +++ b/django_mongodb/__init__.py @@ -7,11 +7,13 @@ check_django_compatability() from .expressions import register_expressions # noqa: E402 +from .fields import register_fields # noqa: E402 from .functions import register_functions # noqa: E402 from .lookups import register_lookups # noqa: E402 from .query import register_nodes # noqa: E402 register_expressions() +register_fields() register_functions() register_lookups() register_nodes() diff --git a/django_mongodb/base.py b/django_mongodb/base.py index 9d3d4124..38c76c21 100644 --- a/django_mongodb/base.py +++ b/django_mongodb/base.py @@ -42,6 +42,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): "IntegerField": "int", "BigIntegerField": "long", "GenericIPAddressField": "string", + "JSONField": "object", "OneToOneField": "int", "PositiveBigIntegerField": "int", "PositiveIntegerField": "long", diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 4a60d169..bc083130 100644 --- 
a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -5,11 +5,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): greatest_least_ignores_nulls = True has_json_object_function = False + has_native_json_field = True supports_date_lookup_using_string = False supports_foreign_keys = False supports_ignore_conflicts = False - # Not implemented: https://github.com/mongodb-labs/django-mongodb/issues/8 - supports_json_field = False + supports_json_field_contains = False # BSON Date type doesn't support microsecond precision. supports_microsecond_precision = False # MongoDB stores datetimes in UTC. @@ -41,6 +41,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): # tuple index out of range in process_rhs() "lookup.tests.LookupTests.test_exact_sliced_queryset_limit_one", "lookup.tests.LookupTests.test_exact_sliced_queryset_limit_one_offset", + # Pattern lookups that use regexMatch don't work on JSONField: + # Unsupported conversion from array to string in $convert + "model_fields.test_jsonfield.TestQuerying.test_icontains", # MongoDB gives the wrong result of log(number, base) when base is a # fractional Decimal: https://jira.mongodb.org/browse/SERVER-91223 "db_functions.math.test_log.LogTests.test_decimal", @@ -53,8 +56,14 @@ class DatabaseFeatures(BaseDatabaseFeatures): # pk__in=queryset doesn't work because subqueries aren't a thing in # MongoDB. "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_and_alias_filter_in_subquery", + "model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery", # Length of null considered zero rather than null. "db_functions.text.test_length.LengthTests.test_basic", + # Key transforms are incorrectly treated as joins: + # Ordering can't span tables on MongoDB (value_custom__a). 
+ "model_fields.test_jsonfield.TestQuerying.test_order_grouping_custom_decoder", + "model_fields.test_jsonfield.TestQuerying.test_ordering_by_transform", + "model_fields.test_jsonfield.TestQuerying.test_ordering_grouping_by_key_transform", } # $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3. _django_test_expected_failures_bitwise = { @@ -221,6 +230,8 @@ def django_test_expected_failures(self): "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_subquery_with_parameters", "expressions_case.tests.CaseExpressionTests.test_in_subquery", "lookup.tests.LookupQueryingTests.test_filter_subquery_lhs", + "model_fields.test_jsonfield.TestQuerying.test_nested_key_transform_on_subquery", + "model_fields.test_jsonfield.TestQuerying.test_obj_subquery_lookup", # Invalid $project :: caused by :: Unknown expression $count, "annotations.tests.NonAggregateAnnotationTestCase.test_combined_expression_annotation_with_aggregation", "annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation", @@ -240,6 +251,7 @@ def django_test_expected_failures(self): "expressions.tests.NegatedExpressionTests.test_filter", "expressions_case.tests.CaseExpressionTests.test_annotate_values_not_in_order_by", "expressions_case.tests.CaseExpressionTests.test_order_by_conditional_implicit", + "model_fields.test_jsonfield.TestQuerying.test_ordering_grouping_by_count", # annotate().filter().count() gives incorrect results. 
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_extract_year_exact_lookup", }, @@ -324,6 +336,7 @@ def django_test_expected_failures(self): "lookup.tests.LookupTests.test_lookup_collision", "lookup.tests.LookupTests.test_lookup_rhs", "lookup.tests.LookupTests.test_isnull_non_boolean_value", + "model_fields.test_jsonfield.TestQuerying.test_join_key_transform_annotation_expression", "model_fields.test_manytomanyfield.ManyToManyFieldDBTests.test_value_from_object_instance_with_pk", "model_fields.test_uuid.TestAsPrimaryKey.test_two_level_foreign_keys", "timezones.tests.LegacyDatabaseTests.test_query_annotation", @@ -343,6 +356,9 @@ def django_test_expected_failures(self): }, "Test executes raw SQL.": { "annotations.tests.NonAggregateAnnotationTestCase.test_raw_sql_with_inherited_field", + "model_fields.test_jsonfield.TestQuerying.test_key_sql_injection_escape", + "model_fields.test_jsonfield.TestQuerying.test_key_transform_raw_expression", + "model_fields.test_jsonfield.TestQuerying.test_nested_key_transform_raw_expression", "timezones.tests.LegacyDatabaseTests.test_cursor_execute_accepts_naive_datetime", "timezones.tests.LegacyDatabaseTests.test_cursor_execute_returns_naive_datetime", "timezones.tests.LegacyDatabaseTests.test_raw_sql", @@ -401,6 +417,24 @@ def django_test_expected_failures(self): "db_functions.comparison.test_cast.CastTests.test_cast_from_python_to_datetime", "db_functions.comparison.test_cast.CastTests.test_cast_to_duration", }, + "Known issue querying JSONField.": { + # An ExpressionWrapper annotation with KeyTransform followed by + # .filter(expr__isnull=False) doesn't use KeyTransformIsNull as it + # needs to. + "model_fields.test_jsonfield.TestQuerying.test_expression_wrapper_key_transform", + # There is no way to distinguish between a JSON "null" (represented + # by Value(None, JSONField())) and a SQL null (queried using the + # isnull lookup). Both of these queries return both nulls. 
def contained_by(self, compiler, connection):  # noqa: ARG001
    """Reject the contained_by lookup: it always raises NotSupportedError."""
    message = "contained_by lookup is not supported on this database backend."
    raise NotSupportedError(message)


def data_contains(self, compiler, connection):  # noqa: ARG001
    """Reject the contains lookup: it always raises NotSupportedError."""
    message = "contains lookup is not supported on this database backend."
    raise NotSupportedError(message)


def _has_key_predicate(path, root_column, negated=False):
    """
    Build the MQL expression that tests whether `path` exists.

    `path` and `root_column` are MQL expressions. When `negated` is truthy
    the whole predicate is wrapped in `$not`.
    """
    # The path must exist (i.e. its $type is not "missing").
    path_exists = {"$ne": [{"$type": path}, "missing"]}
    # If the JSONField value is None, $type returns null instead of
    # "missing", so the root column must also be checked for not null.
    root_not_null = {"$ne": [root_column, None]}
    predicate = {"$and": [path_exists, root_not_null]}
    return {"$not": predicate} if negated else predicate
def has_key_lookup(self, compiler, connection):
    """
    Return MQL to check for the existence of one or more keys.

    The right-hand side may be a single key or a list/tuple of keys. Each
    key is either a KeyTransform or a "raw" key, which is wrapped in a
    KeyTransform rooted at this lookup's left-hand side so every key is
    compiled to MQL the same way.

    The per-key predicates are combined under `self.mongo_operator`; when it
    is None, only the first key's predicate is returned.
    """
    rhs = self.rhs
    lhs = process_lhs(self, compiler, connection)
    if not isinstance(rhs, list | tuple):
        rhs = [rhs]
    # Transform any "raw" keys into KeyTransforms to allow consistent handling
    # in the code that follows.
    transforms = [
        key if isinstance(key, KeyTransform) else KeyTransform(key, self.lhs) for key in rhs
    ]
    paths = [transform.as_mql(compiler, connection) for transform in transforms]
    keys = [_has_key_predicate(path, lhs) for path in paths]
    if self.mongo_operator is None:
        return keys[0]
    return {self.mongo_operator: keys}


# Keep a reference to the stock implementation so that connections whose
# vendor is not "mongodb" keep their original behavior after patching.
_process_rhs = JSONExact.process_rhs


def json_exact_process_rhs(self, compiler, connection):
    """
    Skip JSONExact.process_rhs()'s conversion of None to "null".

    For a "mongodb" connection, the implementation that follows JSONExact in
    the method resolution order is called instead; any other vendor uses the
    saved original JSONExact.process_rhs().
    """
    if connection.vendor == "mongodb":
        return super(JSONExact, self).process_rhs(compiler, connection)
    return _process_rhs(self, compiler, connection)


def key_transform(self, compiler, connection):
    """
    Return MQL for this KeyTransform (JSON path).

    JSON paths cannot always be represented simply as $var.key1.key2.key3 due
    to possible array types. Therefore, each key is resolved with `$getField`
    and, when the key looks like an array index, with `$arrayElemAt` guarded
    by a `$cond`/`$isArray` check of the value being indexed.
    """
    key_transforms = [self.key_name]
    previous = self.lhs
    # Walk the chain of nested KeyTransforms, collecting the keys so that the
    # list ends up ordered from the first key applied to the last.
    while isinstance(previous, KeyTransform):
        key_transforms.insert(0, previous.key_name)
        previous = previous.lhs
    # `previous` is now the first expression in the chain that is not a
    # KeyTransform; its MQL is the starting point of the path.
    result = previous.as_mql(compiler, connection)
    # Build the MQL path using the collected key transforms.
    for key in key_transforms:
        get_field = {"$getField": {"input": result, "field": key}}
        # A key is an array index only if it is made of plain ASCII digits and
        # round-trips through int() (so '001' is not an index). The isascii()
        # guard matters because str.isdigit() is also True for characters such
        # as "²", for which int() raises ValueError.
        if key.isascii() and key.isdigit() and str(int(key)) == key:
            result = {
                "$cond": {
                    "if": {"$isArray": result},
                    "then": {"$arrayElemAt": [result, int(key)]},
                    "else": get_field,
                }
            }
        else:
            result = get_field
    return result


def key_transform_in(self, compiler, connection):
    """
    Return MQL to check if a JSON path exists and that its values are in the
    set of specified values (rhs).
    """
    lhs_mql = process_lhs(self, compiler, connection)
    # Traverse to the root column.
    previous = self.lhs
    while isinstance(previous, KeyTransform):
        previous = previous.lhs
    root_column = previous.as_mql(compiler, connection)
    value = process_rhs(self, compiler, connection)
    # Construct the expression to check if lhs_mql values are in rhs values.
    expr = connection.mongo_operators[self.lookup_name](lhs_mql, value)
    # Both conditions must hold: the path exists and the lookup matches.
    return {"$and": [_has_key_predicate(lhs_mql, root_column), expr]}


def key_transform_is_null(self, compiler, connection):
    """
    Return MQL to check the nullability of a key.

    The processed right-hand side is passed as the `negated` flag of the
    key-existence predicate: a truthy value (`isnull=True`) negates it, so
    the query matches objects where the key is missing or the root column is
    null; a falsy value (`isnull=False`) matches objects where the key
    exists.

    Reference: https://code.djangoproject.com/ticket/32252
    """
    lhs_mql = process_lhs(self, compiler, connection)
    rhs_mql = process_rhs(self, compiler, connection)
    # Get the root column.
    previous = self.lhs
    while isinstance(previous, KeyTransform):
        previous = previous.lhs
    root_column = previous.as_mql(compiler, connection)
    return _has_key_predicate(lhs_mql, root_column, negated=rhs_mql)
def key_transform_numeric_lookup_mixin(self, compiler, connection):
    """
    Return MQL that applies this numeric lookup and additionally requires the
    left-hand side to exist, i.e. its $type is neither "missing" nor "null".
    """
    lookup_mql = builtin_lookup(self, compiler, connection)
    lhs_mql = process_lhs(self, compiler, connection)
    # $type values that mean the field is absent or holds a null.
    absent_types = ["missing", "null"]
    is_present = {"$not": {"$in": [{"$type": lhs_mql}, absent_types]}}
    return {"$and": [lookup_mql, is_present]}


def register_json_field():
    """
    Patch Django's JSONField lookups and transforms with MQL support.

    Sets `as_mql` on the lookup/transform classes, the `mongo_operator` used
    to combine per-key predicates on the has-key lookups, and replaces
    `JSONExact.process_rhs`.
    """
    as_mql_implementations = (
        (ContainedBy, contained_by),
        (DataContains, data_contains),
        (HasKeyLookup, has_key_lookup),
        (KeyTransform, key_transform),
        (KeyTransformIn, key_transform_in),
        (KeyTransformIsNull, key_transform_is_null),
        (KeyTransformNumericLookupMixin, key_transform_numeric_lookup_mixin),
    )
    for target_cls, implementation in as_mql_implementations:
        target_cls.as_mql = implementation
    # None means "single key": has_key_lookup() returns the lone predicate.
    combining_operators = ((HasAnyKeys, "$or"), (HasKey, None), (HasKeys, "$and"))
    for target_cls, operator in combining_operators:
        target_cls.mongo_operator = operator
    JSONExact.process_rhs = json_exact_process_rhs
# Saved so the replacement below can delegate to the stock implementation.
_field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter


def field_resolve_expression_parameter(self, compiler, connection, sql, param):
    """
    For MongoDB, this method must call as_mql() instead of as_sql().

    The stock implementation is always called first and its `sql` is kept.
    For any vendor other than "mongodb" its params are returned unchanged.
    For "mongodb" the params are rebuilt: `[param]` by default, or the MQL of
    the (resolved) expression when it provides as_mql().
    """
    sql, sql_params = _field_resolve_expression_parameter(self, compiler, connection, sql, param)
    if connection.vendor != "mongodb":
        return sql, sql_params
    # Default to the parameter as given, before any expression resolution.
    mql_params = [param]
    if hasattr(param, "resolve_expression"):
        resolved = param.resolve_expression(compiler.query)
    else:
        resolved = param
    if hasattr(resolved, "as_mql"):
        mql_params = [resolved.as_mql(compiler, connection)]
    return sql, mql_params
def adapt_json_value(self, value, encoder):
    """
    Prepare a JSONField value for storage.

    Without an encoder the value is returned untouched. With an encoder the
    value is serialized using that encoder class and parsed back, so the
    result is whatever json.loads() produces for the encoded text. A
    JSONDecodeError raised along the way is re-raised as DataError.
    """
    if encoder is not None:
        try:
            encoded = json.dumps(value, cls=encoder)
            return json.loads(encoded)
        except json.decoder.JSONDecodeError as exc:
            raise DataError from exc
    return value


def convert_jsonfield_value(self, value, expression, connection):
    """
    Serialize the stored value to a JSON string so that
    JSONField.from_db_value() can decode it using json.loads().
    """
    as_text = json.dumps(value)
    return as_text