diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a330132..188c8f3c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,4 +81,4 @@ repos: rev: "v2.2.6" hooks: - id: codespell - args: ["-L", "nin"] + args: ["-L", "nin", "-L", "searchin"] diff --git a/django_mongodb_backend/__init__.py b/django_mongodb_backend/__init__.py index 00700421..d21566d9 100644 --- a/django_mongodb_backend/__init__.py +++ b/django_mongodb_backend/__init__.py @@ -8,7 +8,7 @@ from .aggregates import register_aggregates # noqa: E402 from .checks import register_checks # noqa: E402 -from .expressions import register_expressions # noqa: E402 +from .expressions.builtins import register_expressions # noqa: E402 from .fields import register_fields # noqa: E402 from .functions import register_functions # noqa: E402 from .indexes import register_indexes # noqa: E402 diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 12da13a1..00e96cb0 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -17,6 +17,7 @@ from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING +from .expressions.builtins import SearchExpression from .query import MongoQuery, wrap_database_errors @@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs): # A list of OrderBy objects for this query. self.order_by_objs = None self.subqueries = [] + # Atlas search calls + self.search_pipeline = [] def _get_group_alias_column(self, expr, annotation_group_idx): """Generate a dummy field for use in the ids fields in $group.""" @@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias): column_target.set_attributes_from_name(alias) return Col(self.collection_name, column_target) + def _get_replace_expr(self, sub_expr, group, alias): + column_target = sub_expr.output_field.clone() + column_target.db_column = alias + column_target.set_attributes_from_name(alias) + inner_column = Col(self.collection_name, column_target) + if getattr(sub_expr, "distinct", False): + # If the expression should return distinct values, use + # $addToSet to deduplicate. + rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True) + group[alias] = {"$addToSet": rhs} + replacing_expr = sub_expr.copy() + replacing_expr.set_source_expressions([inner_column, None]) + else: + group[alias] = sub_expr.as_mql(self, self.connection) + replacing_expr = inner_column + # Count must return 0 rather than null. + if isinstance(sub_expr, Count): + replacing_expr = Coalesce(replacing_expr, 0) + # Variance = StdDev^2 + if isinstance(sub_expr, Variance): + replacing_expr = Power(replacing_expr, 2) + return replacing_expr + def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx): """ Prepare expressions for the aggregation pipeline. @@ -80,29 +106,41 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group alias = ( f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target ) - column_target = sub_expr.output_field.clone() - column_target.db_column = alias - column_target.set_attributes_from_name(alias) - inner_column = Col(self.collection_name, column_target) - if sub_expr.distinct: - # If the expression should return distinct values, use - # $addToSet to deduplicate. 
- rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True) - group[alias] = {"$addToSet": rhs} - replacing_expr = sub_expr.copy() - replacing_expr.set_source_expressions([inner_column, None]) - else: - group[alias] = sub_expr.as_mql(self, self.connection) - replacing_expr = inner_column - # Count must return 0 rather than null. - if isinstance(sub_expr, Count): - replacing_expr = Coalesce(replacing_expr, 0) - # Variance = StdDev^2 - if isinstance(sub_expr, Variance): - replacing_expr = Power(replacing_expr, 2) - replacements[sub_expr] = replacing_expr + replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias) return replacements, group + def _prepare_search_expressions_for_pipeline( + self, expression, target, search_idx, replacements + ): + searches = {} + for sub_expr in self._get_search_expressions(expression): + if sub_expr not in replacements: + alias = f"__search_expr.search{next(search_idx)}" + replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias) + return list(searches.values()) + + def _prepare_search_query_for_aggregation_pipeline(self, order_by): + replacements = {} + searches = [] + annotation_group_idx = itertools.count(start=1) + for target, expr in self.query.annotation_select.items(): + expr_searches = self._prepare_search_expressions_for_pipeline( + expr, target, annotation_group_idx, replacements + ) + searches += expr_searches + + for expr, _ in order_by: + expr_searches = self._prepare_search_expressions_for_pipeline( + expr, None, annotation_group_idx, replacements + ) + searches += expr_searches + + having_group = self._prepare_search_expressions_for_pipeline( + self.having, None, annotation_group_idx, replacements + ) + searches += having_group + return searches, replacements + def _prepare_annotations_for_aggregation_pipeline(self, order_by): """Prepare annotations for the aggregation pipeline.""" replacements = {} @@ -179,6 +217,9 @@ def _get_group_id_expressions(self, order_by): ids = self.get_project_fields(tuple(columns), force_expression=True) return ids, replacements + def _build_search_pipeline(self, search_queries): + pass + def _build_aggregation_pipeline(self, ids, group): """Build the aggregation pipeline for grouping.""" pipeline = [] @@ -207,9 +248,21 @@ def _build_aggregation_pipeline(self, ids, group): pipeline.append({"$unset": "_id"}) return pipeline + def _compound_searches_queries(self, searches): + if not searches: + return [] + if len(searches) > 1: + raise ValueError("Cannot perform more than one search operation.") + return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}}] + def pre_sql_setup(self, with_col_aliases=False): extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases) - group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by) + searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline( + order_by + ) + group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by) + all_replacements = {**search_replacements, **group_replacements} + self.search_pipeline = self._compound_searches_queries(searches) # query.group_by is either: # - None: no GROUP BY # - True: group by select fields @@ -557,10 +610,16 @@ def get_lookup_pipeline(self): return result def _get_aggregate_expressions(self, expr): + return self._get_all_expressions_of_type(expr, Aggregate) + + def _get_search_expressions(self, expr): + return 
self._get_all_expressions_of_type(expr, SearchExpression) + + def _get_all_expressions_of_type(self, expr, target_type): stack = [expr] while stack: expr = stack.pop() - if isinstance(expr, Aggregate): + if isinstance(expr, target_type): yield expr elif hasattr(expr, "get_source_expressions"): stack.extend(expr.get_source_expressions()) diff --git a/django_mongodb_backend/expressions.py b/django_mongodb_backend/expressions.py deleted file mode 100644 index b8fbebf5..00000000 --- a/django_mongodb_backend/expressions.py +++ /dev/null @@ -1,233 +0,0 @@ -import datetime -from decimal import Decimal -from uuid import UUID - -from bson import Decimal128 -from django.core.exceptions import EmptyResultSet, FullResultSet -from django.db import NotSupportedError -from django.db.models.expressions import ( - Case, - Col, - ColPairs, - CombinedExpression, - Exists, - ExpressionList, - ExpressionWrapper, - F, - NegatedExpression, - OrderBy, - RawSQL, - Ref, - ResolvedOuterRef, - Star, - Subquery, - Value, - When, -) -from django.db.models.sql import Query - -from .query_utils import process_lhs - - -def case(self, compiler, connection): - case_parts = [] - for case in self.cases: - case_mql = {} - try: - case_mql["case"] = case.as_mql(compiler, connection) - except EmptyResultSet: - continue - except FullResultSet: - default_mql = case.result.as_mql(compiler, connection) - break - case_mql["then"] = case.result.as_mql(compiler, connection) - case_parts.append(case_mql) - else: - default_mql = self.default.as_mql(compiler, connection) - if not case_parts: - return default_mql - return { - "$switch": { - "branches": case_parts, - "default": default_mql, - } - } - - -def col(self, compiler, connection): # noqa: ARG001 - # If the column is part of a subquery and belongs to one of the parent - # queries, it will be stored for reference using $let in a $lookup stage. - # If the query is built with `alias_cols=False`, treat the column as - # belonging to the current collection. - if self.alias is not None and ( - self.alias not in compiler.query.alias_refcount - or compiler.query.alias_refcount[self.alias] == 0 - ): - try: - index = compiler.column_indices[self] - except KeyError: - index = len(compiler.column_indices) - compiler.column_indices[self] = index - return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}" - # Add the column's collection's alias for columns in joined collections. - has_alias = self.alias and self.alias != compiler.collection_name - prefix = f"{self.alias}." 
if has_alias else "" - return f"${prefix}{self.target.column}" - - -def col_pairs(self, compiler, connection): - cols = self.get_cols() - if len(cols) > 1: - raise NotSupportedError("ColPairs is not supported.") - return cols[0].as_mql(compiler, connection) - - -def combined_expression(self, compiler, connection): - expressions = [ - self.lhs.as_mql(compiler, connection), - self.rhs.as_mql(compiler, connection), - ] - return connection.ops.combine_expression(self.connector, expressions) - - -def expression_wrapper(self, compiler, connection): - return self.expression.as_mql(compiler, connection) - - -def f(self, compiler, connection): # noqa: ARG001 - return f"${self.name}" - - -def negated_expression(self, compiler, connection): - return {"$not": expression_wrapper(self, compiler, connection)} - - -def order_by(self, compiler, connection): - return self.expression.as_mql(compiler, connection) - - -def query(self, compiler, connection, get_wrapping_pipeline=None): - subquery_compiler = self.get_compiler(connection=connection) - subquery_compiler.pre_sql_setup(with_col_aliases=False) - field_name, expr = subquery_compiler.columns[0] - subquery = subquery_compiler.build_query( - subquery_compiler.columns - if subquery_compiler.query.annotations or not subquery_compiler.query.default_cols - else None - ) - table_output = f"__subquery{len(compiler.subqueries)}" - from_table = next( - e.table_name for alias, e in self.alias_map.items() if self.alias_refcount[alias] - ) - # To perform a subquery, a $lookup stage that escapsulates the entire - # subquery pipeline is added. The "let" clause defines the variables - # needed to bridge the main collection with the subquery. - subquery.subquery_lookup = { - "as": table_output, - "from": from_table, - "let": { - compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection) - for col, i in subquery_compiler.column_indices.items() - }, - } - if get_wrapping_pipeline: - # The results from some lookups must be converted to a list of values. - # The output is compressed with an aggregation pipeline. - wrapping_result_pipeline = get_wrapping_pipeline( - subquery_compiler, connection, field_name, expr - ) - # If the subquery is a combinator, wrap the result at the end of the - # combinator pipeline... - if subquery.query.combinator: - subquery.combinator_pipeline.extend(wrapping_result_pipeline) - # ... otherwise put at the end of subquery's pipeline. - else: - if subquery.aggregation_pipeline is None: - subquery.aggregation_pipeline = [] - subquery.aggregation_pipeline.extend(wrapping_result_pipeline) - # Erase project_fields since the required value is projected above. - subquery.project_fields = None - compiler.subqueries.append(subquery) - return f"${table_output}.{field_name}" - - -def raw_sql(self, compiler, connection): # noqa: ARG001 - raise NotSupportedError("RawSQL is not supported on MongoDB.") - - -def ref(self, compiler, connection): # noqa: ARG001 - prefix = ( - f"{self.source.alias}." 
- if isinstance(self.source, Col) and self.source.alias != compiler.collection_name - else "" - ) - if hasattr(self, "ordinal"): - refs, _ = compiler.columns[self.ordinal - 1] - else: - refs = self.refs - return f"${prefix}{refs}" - - -def star(self, compiler, connection): # noqa: ARG001 - return {"$literal": True} - - -def subquery(self, compiler, connection, get_wrapping_pipeline=None): - return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) - - -def exists(self, compiler, connection, get_wrapping_pipeline=None): - try: - lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) - except EmptyResultSet: - return Value(False).as_mql(compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, False) - - -def when(self, compiler, connection): - return self.condition.as_mql(compiler, connection) - - -def value(self, compiler, connection): # noqa: ARG001 - value = self.value - if isinstance(value, int): - # Wrap numbers in $literal to prevent ambiguity when Value appears in - # $project. - return {"$literal": value} - if isinstance(value, Decimal): - return Decimal128(value) - if isinstance(value, datetime.datetime): - return value - if isinstance(value, datetime.date): - # Turn dates into datetimes since BSON doesn't support dates. - return datetime.datetime.combine(value, datetime.datetime.min.time()) - if isinstance(value, datetime.time): - # Turn times into datetimes since BSON doesn't support times. - return datetime.datetime.combine(datetime.datetime.min.date(), value) - if isinstance(value, datetime.timedelta): - # DurationField stores milliseconds rather than microseconds. - return value / datetime.timedelta(milliseconds=1) - if isinstance(value, UUID): - return value.hex - return value - - -def register_expressions(): - Case.as_mql = case - Col.as_mql = col - ColPairs.as_mql = col_pairs - CombinedExpression.as_mql = combined_expression - Exists.as_mql = exists - ExpressionList.as_mql = process_lhs - ExpressionWrapper.as_mql = expression_wrapper - F.as_mql = f - NegatedExpression.as_mql = negated_expression - OrderBy.as_mql = order_by - Query.as_mql = query - RawSQL.as_mql = raw_sql - Ref.as_mql = ref - ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql - Star.as_mql = star - Subquery.as_mql = subquery - When.as_mql = when - Value.as_mql = value diff --git a/django_mongodb_backend/expressions/__init__.py b/django_mongodb_backend/expressions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py new file mode 100644 index 00000000..6d5987cc --- /dev/null +++ b/django_mongodb_backend/expressions/builtins.py @@ -0,0 +1,546 @@ +import datetime +from decimal import Decimal +from uuid import UUID + +from bson import Decimal128 +from django.core.exceptions import EmptyResultSet, FullResultSet +from django.db import NotSupportedError +from django.db.models import Expression, FloatField +from django.db.models.expressions import ( + Case, + Col, + ColPairs, + CombinedExpression, + Exists, + ExpressionList, + ExpressionWrapper, + F, + NegatedExpression, + OrderBy, + RawSQL, + Ref, + ResolvedOuterRef, + Star, + Subquery, + Value, + When, +) +from django.db.models.sql import Query + +from ..query_utils import process_lhs + + +def case(self, compiler, connection): + case_parts = [] + for case in self.cases: + case_mql = {} + try: + case_mql["case"] = case.as_mql(compiler, connection) + except 
EmptyResultSet:
+            continue
+        except FullResultSet:
+            default_mql = case.result.as_mql(compiler, connection)
+            break
+        case_mql["then"] = case.result.as_mql(compiler, connection)
+        case_parts.append(case_mql)
+    else:
+        default_mql = self.default.as_mql(compiler, connection)
+    if not case_parts:
+        return default_mql
+    return {
+        "$switch": {
+            "branches": case_parts,
+            "default": default_mql,
+        }
+    }
+
+
+def col(self, compiler, connection):  # noqa: ARG001
+    # If the column is part of a subquery and belongs to one of the parent
+    # queries, it will be stored for reference using $let in a $lookup stage.
+    # If the query is built with `alias_cols=False`, treat the column as
+    # belonging to the current collection.
+    if self.alias is not None and (
+        self.alias not in compiler.query.alias_refcount
+        or compiler.query.alias_refcount[self.alias] == 0
+    ):
+        try:
+            index = compiler.column_indices[self]
+        except KeyError:
+            index = len(compiler.column_indices)
+            compiler.column_indices[self] = index
+        return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}"
+    # Add the column's collection's alias for columns in joined collections.
+    has_alias = self.alias and self.alias != compiler.collection_name
+    prefix = f"{self.alias}." if has_alias else ""
+    return f"${prefix}{self.target.column}"
+
+
+def col_pairs(self, compiler, connection):
+    cols = self.get_cols()
+    if len(cols) > 1:
+        raise NotSupportedError("ColPairs is not supported.")
+    return cols[0].as_mql(compiler, connection)
+
+
+def combined_expression(self, compiler, connection):
+    expressions = [
+        self.lhs.as_mql(compiler, connection),
+        self.rhs.as_mql(compiler, connection),
+    ]
+    return connection.ops.combine_expression(self.connector, expressions)
+
+
+def expression_wrapper(self, compiler, connection):
+    return self.expression.as_mql(compiler, connection)
+
+
+def f(self, compiler, connection):  # noqa: ARG001
+    return f"${self.name}"
+
+
+def negated_expression(self, compiler, connection):
+    return {"$not": expression_wrapper(self, compiler, connection)}
+
+
+def order_by(self, compiler, connection):
+    return self.expression.as_mql(compiler, connection)
+
+
+def query(self, compiler, connection, get_wrapping_pipeline=None):
+    subquery_compiler = self.get_compiler(connection=connection)
+    subquery_compiler.pre_sql_setup(with_col_aliases=False)
+    field_name, expr = subquery_compiler.columns[0]
+    subquery = subquery_compiler.build_query(
+        subquery_compiler.columns
+        if subquery_compiler.query.annotations or not subquery_compiler.query.default_cols
+        else None
+    )
+    table_output = f"__subquery{len(compiler.subqueries)}"
+    from_table = next(
+        e.table_name for alias, e in self.alias_map.items() if self.alias_refcount[alias]
+    )
+    # To perform a subquery, a $lookup stage that encapsulates the entire
+    # subquery pipeline is added. The "let" clause defines the variables
+    # needed to bridge the main collection with the subquery.
+    subquery.subquery_lookup = {
+        "as": table_output,
+        "from": from_table,
+        "let": {
+            compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection)
+            for col, i in subquery_compiler.column_indices.items()
+        },
+    }
+    if get_wrapping_pipeline:
+        # The results from some lookups must be converted to a list of values.
+        # The output is compressed with an aggregation pipeline.
+        wrapping_result_pipeline = get_wrapping_pipeline(
+            subquery_compiler, connection, field_name, expr
+        )
+        # If the subquery is a combinator, wrap the result at the end of the
+        # combinator pipeline...
+ if subquery.query.combinator: + subquery.combinator_pipeline.extend(wrapping_result_pipeline) + # ... otherwise put at the end of subquery's pipeline. + else: + if subquery.aggregation_pipeline is None: + subquery.aggregation_pipeline = [] + subquery.aggregation_pipeline.extend(wrapping_result_pipeline) + # Erase project_fields since the required value is projected above. + subquery.project_fields = None + compiler.subqueries.append(subquery) + return f"${table_output}.{field_name}" + + +def raw_sql(self, compiler, connection): # noqa: ARG001 + raise NotSupportedError("RawSQL is not supported on MongoDB.") + + +def ref(self, compiler, connection): # noqa: ARG001 + prefix = ( + f"{self.source.alias}." + if isinstance(self.source, Col) and self.source.alias != compiler.collection_name + else "" + ) + if hasattr(self, "ordinal"): + refs, _ = compiler.columns[self.ordinal - 1] + else: + refs = self.refs + return f"${prefix}{refs}" + + +def star(self, compiler, connection): # noqa: ARG001 + return {"$literal": True} + + +def subquery(self, compiler, connection, get_wrapping_pipeline=None): + return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) + + +def exists(self, compiler, connection, get_wrapping_pipeline=None): + try: + lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) + except EmptyResultSet: + return Value(False).as_mql(compiler, connection) + return connection.mongo_operators["isnull"](lhs_mql, False) + + +def when(self, compiler, connection): + return self.condition.as_mql(compiler, connection) + + +def value(self, compiler, connection): # noqa: ARG001 + value = self.value + if isinstance(value, int): + # Wrap numbers in $literal to prevent ambiguity when Value appears in + # $project. + return {"$literal": value} + if isinstance(value, Decimal): + return Decimal128(value) + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): + # Turn dates into datetimes since BSON doesn't support dates. + return datetime.datetime.combine(value, datetime.datetime.min.time()) + if isinstance(value, datetime.time): + # Turn times into datetimes since BSON doesn't support times. + return datetime.datetime.combine(datetime.datetime.min.date(), value) + if isinstance(value, datetime.timedelta): + # DurationField stores milliseconds rather than microseconds. 
+        return value / datetime.timedelta(milliseconds=1)
+    if isinstance(value, UUID):
+        return value.hex
+    return value
+
+
+class SearchExpression(Expression):
+    output_field = FloatField()
+    # Subclasses may set search_type to a human-readable operator name; fall
+    # back to the class name so __str__()/__repr__() never raise
+    # AttributeError for operators that don't define it.
+    search_type = None
+
+    def get_source_expressions(self):
+        return []
+
+    def __str__(self):
+        args = ", ".join(map(str, self.get_source_expressions()))
+        return f"{self.search_type or self.__class__.__name__}({args})"
+
+    def __repr__(self):
+        return str(self)
+
+    def as_sql(self, compiler, connection):
+        return "", []
+
+    def _get_query_index(self, fields, compiler):
+        fields = set(fields)
+        for search_indexes in compiler.collection.list_search_indexes():
+            mappings = search_indexes["latestDefinition"]["mappings"]
+            if mappings["dynamic"] or fields.issubset(set(mappings["fields"])):
+                return search_indexes["name"]
+        return "default"
+
+
+class SearchAutocomplete(SearchExpression):
+    def __init__(self, path, query, fuzzy=None, score=None):
+        self.path = path
+        self.query = query
+        self.fuzzy = fuzzy
+        self.score = score
+        super().__init__()
+
+    def as_mql(self, compiler, connection):
+        params = {
+            "path": self.path,
+            "query": self.query,
+        }
+        if self.score is not None:
+            params["score"] = self.score
+        if self.fuzzy is not None:
+            params["fuzzy"] = self.fuzzy
+        index = self._get_query_index([self.path], compiler)
+        return {"$search": {"autocomplete": params, "index": index}}
+
+
+class SearchEquals(SearchExpression):
+    def __init__(self, path, value, score=None):
+        self.path = path
+        self.value = value
+        self.score = score
+        super().__init__()
+
+    def as_mql(self, compiler, connection):
+        params = {
+            "path": self.path,
+            "value": self.value,
+        }
+        if self.score is not None:
+            params["score"] = self.score
+        index = self._get_query_index([self.path], compiler)
+        return {"$search": {"equals": params, "index": index}}
+
+
+class SearchExists(SearchExpression):
+    def __init__(self, path, score=None):
+        self.path = path
+        self.score = score
+        super().__init__()
+
+    def as_mql(self, compiler, connection):
+        params = {
+            "path": self.path,
+        }
+        if self.score is not None:
+            params["score"] = self.score
+        index = self._get_query_index([self.path], compiler)
+        return {"$search": {"exists": params, "index": index}}
+
+
+class SearchIn(SearchExpression):
+    def __init__(self, path, value, score=None):
+        self.path = path
+        self.value = value
+        self.score = score
+        super().__init__()
+
+    def as_mql(self, compiler, connection):
+        params = {
+            "path": self.path,
+            "value": self.value,
+        }
+        if self.score is not None:
+            params["score"] = self.score
+        index = self._get_query_index([self.path], compiler)
+        return {"$search": {"in": params, "index": index}}
+
+
+class SearchPhrase(SearchExpression):
+    def __init__(self, path, query, slop=None, synonyms=None, score=None):
+        self.path = path
+        self.query = query
+        self.score = score
+        self.slop = slop
+        self.synonyms = synonyms
+        super().__init__()
+
+    def as_mql(self, compiler, connection):
+        params = {
+            "path": self.path,
+            "query": self.query,
+        }
+        if self.score is not None:
+            params["score"] = self.score
+        if self.slop is not None:
+            params["slop"] = self.slop
+        if self.synonyms is not None:
+            params["synonyms"] = self.synonyms
+        index = self._get_query_index([self.path], compiler)
+        return {"$search": {"phrase": params, "index": index}}
+
+
+class SearchQueryString(SearchExpression):
+    def __init__(self, path, query, score=None):
+        self.path = path
+        self.query = query
+        self.score = score
+        super().__init__()
+
+    def as_mql(self, compiler, connection):
+        params = {
+            "defaultPath": self.path,
+            "query":
self.query, + } + if self.score is not None: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"queryString": params, "index": index}} + + +class SearchRange(SearchExpression): + def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): + self.path = path + self.lt = lt + self.lte = lte + self.gt = gt + self.gte = gte + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path, + } + if self.score is not None: + params["score"] = self.score + if self.lt is not None: + params["lt"] = self.lt + if self.lte is not None: + params["lte"] = self.lte + if self.gt is not None: + params["gt"] = self.gt + if self.gte is not None: + params["gte"] = self.gte + index = self._get_query_index([self.path], compiler) + return {"$search": {"range": params, "index": index}} + + +class SearchRegex(SearchExpression): + def __init__(self, path, query, allow_analyzed_field=None, score=None): + self.path = path + self.query = query + self.allow_analyzed_field = allow_analyzed_field + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path, + "query": self.query, + } + if self.score: + params["score"] = self.score + if self.allow_analyzed_field is not None: + params["allowAnalyzedField"] = self.allow_analyzed_field + index = self._get_query_index([self.path], compiler) + return {"$search": {"regex": params, "index": index}} + + +class SearchText(SearchExpression): + def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): + self.path = path + self.query = query + self.fuzzy = fuzzy + self.match_criteria = match_criteria + self.synonyms = synonyms + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path, + "query": self.query, + } + if self.score: + params["score"] = self.score + if self.fuzzy is not None: + params["fuzzy"] = self.fuzzy + if self.match_criteria is not None: + params["matchCriteria"] = self.match_criteria + if self.synonyms is not None: + params["synonyms"] = self.synonyms + index = self._get_query_index([self.path], compiler) + return {"$search": {"text": params, "index": index}} + + +class SearchWildcard(SearchExpression): + def __init__(self, path, query, allow_analyzed_field=None, score=None): + self.path = path + self.query = query + self.allow_analyzed_field = allow_analyzed_field + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path, + "query": self.query, + } + if self.score: + params["score"] = self.score + if self.allow_analyzed_field is not None: + params["allowAnalyzedField"] = self.allow_analyzed_field + index = self._get_query_index([self.path], compiler) + return {"$search": {"wildcard": params, "index": index}} + + +class SearchGeoShape(SearchExpression): + def __init__(self, path, relation, geometry, score=None): + self.path = path + self.relation = relation + self.geometry = geometry + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path, + "relation": self.relation, + "geometry": self.geometry, + } + if self.score: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"geoShape": params, "index": index}} + + +class SearchGeoWithin(SearchExpression): + def __init__(self, path, kind, geo_object, score=None): + 
self.path = path + self.kind = kind + self.geo_object = geo_object + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path, + self.kind: self.geo_object, + } + if self.score: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"geoWithin": params, "index": index}} + + +class SearchMoreLikeThis(SearchExpression): + search_type = "more_like_this" + + def __init__(self, documents, score=None): + self.documents = documents + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "like": self.documents, + } + if self.score: + params["score"] = self.score + needed_fields = [] + for doc in self.documents: + needed_fields += list(doc.keys()) + index = self._get_query_index(needed_fields, compiler) + return {"$search": {"moreLikeThis": params, "index": index}} + + +class SearchScoreOption: + """Class to mutate scoring on a search operation""" + + def __init__(self, definitions=None): + self.definitions = definitions + + +class CombinedSearchExpression(SearchExpression): + def __init__(self, lhs, connector, rhs, output_field=None): + super().__init__(output_field=output_field) + self.connector = connector + self.lhs = lhs + self.rhs = rhs + + +def register_expressions(): + Case.as_mql = case + Col.as_mql = col + ColPairs.as_mql = col_pairs + CombinedExpression.as_mql = combined_expression + Exists.as_mql = exists + ExpressionList.as_mql = process_lhs + ExpressionWrapper.as_mql = expression_wrapper + F.as_mql = f + NegatedExpression.as_mql = negated_expression + OrderBy.as_mql = order_by + Query.as_mql = query + RawSQL.as_mql = raw_sql + Ref.as_mql = ref + ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql + Star.as_mql = star + Subquery.as_mql = subquery + When.as_mql = when + Value.as_mql = value diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index 04977520..0e2bc9b2 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -49,6 +49,7 @@ def __init__(self, compiler): self.lookup_pipeline = None self.project_fields = None self.aggregation_pipeline = compiler.aggregation_pipeline + self.search_pipeline = compiler.search_pipeline self.extra_fields = None self.combinator_pipeline = None # $lookup stage that encapsulates the pipeline for performing a nested @@ -75,6 +76,8 @@ def get_cursor(self): def get_pipeline(self): pipeline = [] + if self.search_pipeline: + pipeline.extend(self.search_pipeline) if self.lookup_pipeline: pipeline.extend(self.lookup_pipeline) for query in self.subqueries or (): diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index 0bb29299..c03a0f7a 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -4,7 +4,7 @@ def is_direct_value(node): - return not hasattr(node, "as_sql") + return not hasattr(node, "as_sql") and not hasattr(node, "as_mql") def process_lhs(node, compiler, connection): diff --git a/tests/queries_/models.py b/tests/queries_/models.py index 01510224..fd70b395 100644 --- a/tests/queries_/models.py +++ b/tests/queries_/models.py @@ -53,3 +53,10 @@ class Meta: def __str__(self): return str(self.pk) + + +class Article(models.Model): + headline = models.CharField(max_length=100) + number = models.IntegerField() + body = models.TextField() + location = models.JSONField(null=True) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py new file mode 
100644 index 00000000..ddb7d33a --- /dev/null +++ b/tests/queries_/test_search.py @@ -0,0 +1,298 @@ +import time + +from django.db import connection +from django.test import TestCase +from pymongo.operations import SearchIndexModel + +from django_mongodb_backend.expressions.builtins import ( + SearchAutocomplete, + SearchEquals, + SearchExists, + SearchGeoShape, + SearchGeoWithin, + SearchIn, + SearchMoreLikeThis, + SearchPhrase, + SearchRange, + SearchRegex, + SearchText, + SearchWildcard, +) + +from .models import Article + + +class CreateIndexMixin: + def _get_collection(self, model): + return connection.database.get_collection(model._meta.db_table) + + def create_search_index(self, model, index_name, definition): + collection = self._get_collection(model) + idx = SearchIndexModel(definition=definition, name=index_name) + collection.create_search_index(idx) + + +class SearchEqualsTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "equals_headline_index", + {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, + ) + Article.objects.create(headline="cross", number=1, body="body") + time.sleep(1) + + def test_search_equals(self): + qs = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")) + self.assertEqual(qs.first().headline, "cross") + + +class SearchAutocompleteTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "autocomplete_headline_index", + { + "mappings": { + "dynamic": False, + "fields": { + "headline": { + "type": "autocomplete", + "analyzer": "lucene.standard", + "tokenization": "edgeGram", + "minGrams": 3, + "maxGrams": 5, + "foldDiacritics": False, + } + }, + } + }, + ) + Article.objects.create(headline="crossing and something", number=2, body="river") + + def test_search_autocomplete(self): + qs = Article.objects.annotate(score=SearchAutocomplete(path="headline", query="crossing")) + self.assertEqual(qs.first().headline, "crossing and something") + + +class SearchExistsTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "exists_body_index", + {"mappings": {"dynamic": False, "fields": {"body": {"type": "token"}}}}, + ) + Article.objects.create(headline="ignored", number=3, body="something") + + def test_search_exists(self): + qs = Article.objects.annotate(score=SearchExists(path="body")) + self.assertEqual(qs.count(), 1) + self.assertEqual(qs.first().body, "something") + + +class SearchInTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "in_headline_index", + {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, + ) + Article.objects.create(headline="cross", number=1, body="a") + Article.objects.create(headline="road", number=2, body="b") + time.sleep(1) + + def test_search_in(self): + qs = Article.objects.annotate(score=SearchIn(path="headline", value=["cross", "river"])) + self.assertEqual(qs.first().headline, "cross") + + +class SearchPhraseTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "phrase_body_index", + {"mappings": {"dynamic": False, "fields": {"body": {"type": "string"}}}}, + ) + Article.objects.create(headline="irrelevant", number=1, body="the quick brown fox") + time.sleep(1) + + def test_search_phrase(self): + qs = Article.objects.annotate(score=SearchPhrase(path="body", query="quick brown")) + self.assertIn("quick brown", qs.first().body) + + +class SearchRangeTest(TestCase, 
CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "range_number_index", + {"mappings": {"dynamic": False, "fields": {"number": {"type": "number"}}}}, + ) + Article.objects.create(headline="x", number=5, body="z") + Article.objects.create(headline="y", number=20, body="z") + time.sleep(1) + + def test_search_range(self): + qs = Article.objects.annotate(score=SearchRange(path="number", gte=10, lt=30)) + self.assertEqual(qs.first().number, 20) + + +class SearchRegexTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "regex_headline_index", + { + "mappings": { + "dynamic": False, + "fields": {"headline": {"type": "string", "analyzer": "lucene.keyword"}}, + } + }, + ) + Article.objects.create(headline="hello world", number=1, body="abc") + time.sleep(1) + + def test_search_regex(self): + qs = Article.objects.annotate( + score=SearchRegex(path="headline", query="hello.*", allow_analyzed_field=False) + ) + self.assertTrue(qs.first().headline.startswith("hello")) + + +class SearchTextTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "text_body_index", + {"mappings": {"dynamic": False, "fields": {"body": {"type": "string"}}}}, + ) + Article.objects.create(headline="ignored", number=1, body="The lazy dog sleeps") + time.sleep(1) + + def test_search_text(self): + qs = Article.objects.annotate(score=SearchText(path="body", query="lazy")) + self.assertIn("lazy", qs.first().body) + + def test_search_text_with_fuzzy_and_criteria(self): + qs = Article.objects.annotate( + score=SearchText( + path="body", query="lazzy", fuzzy={"maxEdits": 1}, match_criteria="all" + ) + ) + self.assertIn("lazy", qs.first().body) + + +class SearchWildcardTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "wildcard_headline_index", + { + "mappings": { + "dynamic": False, + "fields": {"headline": {"type": "string", "analyzer": "lucene.keyword"}}, + } + }, + ) + Article.objects.create(headline="dark-knight", number=1, body="") + time.sleep(1) + + def test_search_wildcard(self): + qs = Article.objects.annotate(score=SearchWildcard(path="headline", query="dark-*")) + self.assertIn("dark", qs.first().headline) + + +class SearchGeoShapeTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "geoshape_location_index", + { + "mappings": { + "dynamic": False, + "fields": {"location": {"type": "geo", "indexShapes": True}}, + } + }, + ) + Article.objects.create( + headline="any", number=1, body="", location={"type": "Point", "coordinates": [40, 5]} + ) + time.sleep(1) + + def test_search_geo_shape(self): + polygon = { + "type": "Polygon", + "coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]], + } + qs = Article.objects.annotate( + score=SearchGeoShape(path="location", relation="within", geometry=polygon) + ) + self.assertEqual(qs.first().number, 1) + + +class SearchGeoWithinTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "geowithin_location_index", + {"mappings": {"dynamic": False, "fields": {"location": {"type": "geo"}}}}, + ) + Article.objects.create( + headline="geo", number=2, body="", location={"type": "Point", "coordinates": [40, 5]} + ) + time.sleep(1) + + def test_search_geo_within(self): + polygon = { + "type": "Polygon", + "coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]], + } + qs = Article.objects.annotate( + score=SearchGeoWithin( + path="location", + 
kind="geometry", + geo_object=polygon, + ) + ) + self.assertEqual(qs.first().number, 2) + + +class SearchMoreLikeThisTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "mlt_index", + { + "mappings": { + "dynamic": False, + "fields": {"body": {"type": "string"}, "headline": {"type": "string"}}, + } + }, + ) + self.article1 = Article.objects.create( + headline="Space exploration", number=1, body="Webb telescope" + ) + self.article2 = Article.objects.create( + headline="The commodities fall", + number=2, + body="Commodities dropped sharply due to inflation concerns", + ) + Article.objects.create( + headline="irrelevant", + number=3, + body="This is a completely unrelated article about cooking", + ) + time.sleep(1) + + def test_search_more_like_this(self): + like_docs = [ + {"headline": self.article1.headline, "body": self.article1.body}, + {"headline": self.article2.headline, "body": self.article2.body}, + ] + like_docs = [{"body": "NASA launches new satellite to explore the galaxy"}] + qs = Article.objects.annotate(score=SearchMoreLikeThis(documents=like_docs)).order_by( + "score" + ) + self.assertQuerySetEqual( + qs, ["space exploration", "The commodities fall"], lambda a: a.headline + )