diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 9f4e222c..ddfc1b44 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -80,6 +80,7 @@ jobs: lookup model_fields or_lookups + queries.tests.Ticket12807Tests.test_ticket_12807 sessions_tests timezones update diff --git a/django_mongodb/__init__.py b/django_mongodb/__init__.py index 49c15fc1..5dea9ab2 100644 --- a/django_mongodb/__init__.py +++ b/django_mongodb/__init__.py @@ -5,3 +5,13 @@ from .utils import check_django_compatability check_django_compatability() + +from .expressions import register_expressions # noqa: E402 +from .functions import register_functions # noqa: E402 +from .lookups import register_lookups # noqa: E402 +from .query import register_nodes # noqa: E402 + +register_expressions() +register_functions() +register_lookups() +register_nodes() diff --git a/django_mongodb/base.py b/django_mongodb/base.py index 276be248..ffaabd75 100644 --- a/django_mongodb/base.py +++ b/django_mongodb/base.py @@ -1,3 +1,5 @@ +import re + from django.core.exceptions import ImproperlyConfigured from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.signals import connection_created @@ -10,6 +12,7 @@ from .features import DatabaseFeatures from .introspection import DatabaseIntrospection from .operations import DatabaseOperations +from .query_utils import safe_regex from .schema import DatabaseSchemaEditor from .utils import CollectionDebugWrapper @@ -52,11 +55,23 @@ class DatabaseWrapper(BaseDatabaseWrapper): "UUIDField": "string", } operators = { - "exact": "= %s", - "gt": "> %s", - "gte": ">= %s", - "lt": "< %s", - "lte": "<= %s", + "exact": lambda val: val, + "gt": lambda val: {"$gt": val}, + "gte": lambda val: {"$gte": val}, + "lt": lambda val: {"$lt": val}, + "lte": lambda val: {"$lte": val}, + "in": lambda val: {"$in": val}, + "range": lambda val: {"$gte": val[0], "$lte": val[1]}, + "isnull": lambda val: None 
if val else {"$ne": None}, + "iexact": safe_regex("^%s$", re.IGNORECASE), + "startswith": safe_regex("^%s"), + "istartswith": safe_regex("^%s", re.IGNORECASE), + "endswith": safe_regex("%s$"), + "iendswith": safe_regex("%s$", re.IGNORECASE), + "contains": safe_regex("%s"), + "icontains": safe_regex("%s", re.IGNORECASE), + "regex": lambda val: re.compile(val), + "iregex": lambda val: re.compile(val, re.IGNORECASE), } display_name = "MongoDB" diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 3dcc8035..97d620c2 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -1,4 +1,4 @@ -from django.core.exceptions import EmptyResultSet +from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import DatabaseError, IntegrityError, NotSupportedError from django.db.models import NOT_PROVIDED, Count, Expression, Value from django.db.models.aggregates import Aggregate @@ -136,7 +136,10 @@ def build_query(self, columns=None): self.check_query() self.setup_query() query = self.query_class(self, columns) - query.add_filters(self.query.where) + try: + query.mongo_query = self.query.where.as_mql(self, self.connection) + except FullResultSet: + query.mongo_query = {} query.order_by(self._get_ordering()) return query diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py new file mode 100644 index 00000000..57139927 --- /dev/null +++ b/django_mongodb/expressions.py @@ -0,0 +1,9 @@ +from django.db.models.expressions import Col + + +def col(self, compiler, connection): # noqa: ARG001 + return self.target.column + + +def register_expressions(): + Col.as_mql = col diff --git a/django_mongodb/features.py b/django_mongodb/features.py index deeb37e3..1c5226bd 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -16,11 +16,10 @@ class DatabaseFeatures(BaseDatabaseFeatures): # cannot encode object: 1: - raise DatabaseError("Nested ORs are not supported.") - - if filters.connector == OR 
def where_node(self, compiler, connection):
    """
    Compile this WhereNode into a MongoDB query document.

    Each child is compiled with ``child.as_mql(compiler, connection)`` and
    the non-empty results are combined with ``$and`` or ``$or`` depending
    on ``self.connector``. A single surviving child is returned without a
    wrapping operator.

    Raises:
        NotImplementedError: if the connector is XOR.
        EmptyResultSet: if the node can never match.
        FullResultSet: if the node matches everything (including when no
            child produced any MQL).
    """
    if self.connector == XOR:
        # https://github.com/mongodb-labs/django-mongodb/issues/27
        raise NotImplementedError("XOR is not yet supported.")

    # full_needed / empty_needed count how many "matches everything" /
    # "matches nothing" children it takes to decide the whole node:
    # AND is empty after one empty child and full only if all are full;
    # OR is the mirror image.
    if self.connector == AND:
        operator = "$and"
        full_needed, empty_needed = len(self.children), 1
    else:
        operator = "$or"
        full_needed, empty_needed = 1, len(self.children)

    children_mql = []
    for child in self.children:
        try:
            mql = child.as_mql(compiler, connection)
        except EmptyResultSet:
            empty_needed -= 1
        except FullResultSet:
            full_needed -= 1
        else:
            if mql:
                children_mql.append(mql)
            else:
                # A child that compiles to nothing restricts nothing.
                full_needed -= 1

        # Checked after every child so the node short-circuits as soon as
        # the outcome is decided. Negation swaps the two outcomes.
        if empty_needed == 0:
            raise (FullResultSet if self.negated else EmptyResultSet)
        if full_needed == 0:
            raise (EmptyResultSet if self.negated else FullResultSet)

    if not children_mql:
        raise FullResultSet

    mql = children_mql[0] if len(children_mql) == 1 else {operator: children_mql}

    if self.negated:
        # "$not" is only valid around a field's operator expression (or a
        # regex), so it cannot wrap a top-level "$and"/"$or" list or a bare
        # equality value. A single-clause "$nor" negates any query document.
        mql = {"$nor": [mql]}

    return mql
def register_nodes():
    """Install ``where_node`` as ``WhereNode.as_mql``."""
    WhereNode.as_mql = where_node


# --- django_mongodb/query_utils.py (new file) ---
import re


def is_direct_value(node):
    """Return True if ``node`` has no ``as_sql`` attribute."""
    return not hasattr(node, "as_sql")


def process_lhs(node, compiler, connection):
    """
    Resolve the left-hand side of ``node``.

    If ``node.lhs`` is a direct value, ``node`` is returned as is;
    otherwise the result of ``node.lhs.as_mql(compiler, connection)`` is
    returned.
    """
    # NOTE(review): the direct-value branch returns the node itself, not
    # node.lhs -- confirm that callers expect this.
    if is_direct_value(node.lhs):
        return node
    return node.lhs.as_mql(compiler, connection)


def process_rhs(node, compiler, connection):
    """
    Return the right-hand side value of ``node`` prepared for MongoDB.

    The params returned by ``node.process_rhs()`` are unwrapped and
    stripped of LIKE wildcards, then passed through
    ``connection.ops.prep_lookup_value()`` together with the lhs output
    field and the lookup name.
    """
    _, value = node.process_rhs(compiler, connection)
    lookup_name = node.lookup_name
    # Undo Lookup.get_db_prep_lookup() putting params in a list.
    if lookup_name not in ("in", "range"):
        value = value[0]
    # Remove percent signs added by PatternLookup.process_rhs() for LIKE
    # queries.
    if lookup_name in ("startswith", "istartswith"):
        value = value[:-1]
    elif lookup_name in ("endswith", "iendswith"):
        value = value[1:]
    elif lookup_name in ("contains", "icontains"):
        value = value[1:-1]

    return connection.ops.prep_lookup_value(value, node.lhs.output_field, lookup_name)


def safe_regex(regex, *re_args, **re_kwargs):
    """
    Build a function that compiles ``regex`` around an escaped value.

    ``regex`` is a %-format template with one placeholder. The returned
    callable escapes its argument with ``re.escape()``, substitutes it into
    the template and compiles the result, forwarding ``re_args`` and
    ``re_kwargs`` to ``re.compile()``.
    """

    def wrapper(value):
        return re.compile(regex % re.escape(value), *re_args, **re_kwargs)

    # Give each generated function a name that shows its template.
    wrapper.__name__ = "safe_regex (%r)" % regex
    return wrapper