diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 457ecad4..d36d0597 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -1,5 +1,8 @@ +import re + +from bson.objectid import ObjectId from django.conf import settings -from django.core.exceptions import EmptyResultSet +from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import ( DatabaseError, IntegrityError, @@ -9,13 +12,31 @@ from django.db.models import NOT_PROVIDED, Count, Expression from django.db.models.aggregates import Aggregate from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import Col +from django.db.models.fields.related_lookups import In, MultiColSource, RelatedIn +from django.db.models.functions.datetime import Extract +from django.db.models.lookups import ( + Contains, + Exact, # , IExact + IsNull, + StartsWith, +) from django.db.models.sql import compiler from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI +from django.db.models.sql.where import AND, XOR, WhereNode from .base import Cursor from .query import MongoQuery, wrap_database_errors +def safe_regex(regex, *re_args, **re_kwargs): + def wrapper(value): + return re.compile(regex % re.escape(value), *re_args, **re_kwargs) + + wrapper.__name__ = "safe_regex (%r)" % regex + return wrapper + + class SQLCompiler(compiler.SQLCompiler): """Base class for all Mongo compilers.""" @@ -123,12 +144,93 @@ def get_count(self, check_exists=False): except EmptyResultSet: return 0 + def _compile_where_node(self, node): + if node.connector == AND: + operator = "$and" + elif node.connector == XOR: + operator = "$ne" + else: + operator = "$or" + + children_mql = [] + for child in node.children: + try: + mql = self._compile(child) + except (EmptyResultSet, FullResultSet): + pass + else: + if mql: + children_mql.append(mql) + + mql = {operator: children_mql} if children_mql else {} + + if node.negated and mql: + mql = {"$not": mql} + + return mql + + def _compile_leaf_node(self, node): + return node + + def _compile_exact(self, node): + lhs_mql = self._compile(node.lhs) + rhs_mql = self._compile(node.rhs) + if isinstance(node.lhs, Extract): + return {"$expr": {"$eq": [lhs_mql, rhs_mql]}} + return {lhs_mql: rhs_mql} + + def _compile_col(self, node): + return node.target.column + + def _compile(self, node): + result = None + if isinstance(node, WhereNode): + result = self._compile_where_node(node) + elif isinstance(node, str | bool | list | int | ObjectId): + result = self._compile_leaf_node(node) + elif isinstance(node, Exact): + result = self._compile_exact(node) + elif isinstance(node, Col): + result = node.target.column + elif isinstance(node, RelatedIn | In): + if isinstance(node.lhs, MultiColSource): + raise NotImplementedError("It will be there, I promise! :D") + lhs_mql = self._compile(node.lhs) + rhs_mql = self._compile(node.rhs) + result = {lhs_mql: {"$in": rhs_mql}} + elif isinstance(node, Extract): + lhs_mql = self._compile(node.lhs) + if node.lookup_name == "week": + operator = "$week" + elif node.lookup_name == "month": + operator = "$month" + elif node.lookup_name == "year": + operator = "$year" + else: + raise NotSupportedError(f"Node of type {type(node)} is not supported") + # check if it is an expression or a column, now I take as it's column + result = {operator: f"${lhs_mql}"} + elif isinstance(node, StartsWith): + lhs_mql = self._compile(node.lhs) + rhs_mql = self._compile(node.rhs) + result = {lhs_mql: safe_regex("^%s")(rhs_mql)} + elif isinstance(node, Contains): + lhs_mql = self._compile(node.lhs) + rhs_mql = self._compile(node.rhs) + result = {lhs_mql: safe_regex("%s")(rhs_mql)} + elif isinstance(node, IsNull): + lhs_mql = self._compile(node.lhs) + result = {lhs_mql: None} if node.rhs is True else {lhs_mql: {"$ne": None}} + else: + raise NotSupportedError(f"Node of type {type(node)} is not supported") + return result + def build_query(self, columns=None): """Check if the query is supported and prepare a MongoQuery.""" self.check_query() self.setup_query() query = self.query_class(self, columns) - query.add_filters(self.query.where) + query.mongo_query = self._compile(self.query.where) query.order_by(self._get_ordering()) # This at least satisfies the most basic unit tests.