Skip to content

[WIP] Atlas search lookups #325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ repos:
rev: "v2.2.6"
hooks:
- id: codespell
args: ["-L", "nin"]
args: ["-L", "nin", "-L", "searchin"]
2 changes: 1 addition & 1 deletion django_mongodb_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .aggregates import register_aggregates # noqa: E402
from .checks import register_checks # noqa: E402
from .expressions import register_expressions # noqa: E402
from .expressions.builtins import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .indexes import register_indexes # noqa: E402
Expand Down
106 changes: 83 additions & 23 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .functions import SearchExpression
from .query import MongoQuery, wrap_database_errors


Expand All @@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
# A list of OrderBy objects for this query.
self.order_by_objs = None
self.subqueries = []
# Atlas search calls
self.search_pipeline = []

def _get_group_alias_column(self, expr, annotation_group_idx):
"""Generate a dummy field for use in the ids fields in $group."""
Expand All @@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
column_target.set_attributes_from_name(alias)
return Col(self.collection_name, column_target)

def _get_replace_expr(self, sub_expr, group, alias):
    """
    Register ``sub_expr`` under ``alias`` in the ``group`` mapping (mutated
    in place) and return the expression that should replace ``sub_expr`` in
    the outer query: a column reference to the grouped result, wrapped where
    the raw result needs post-processing (Count, Variance).
    """
    # Build a column that refers back to the computed value by its alias.
    column_target = sub_expr.output_field.clone()
    column_target.db_column = alias
    column_target.set_attributes_from_name(alias)
    inner_column = Col(self.collection_name, column_target)
    # getattr() with a default because not every expression handled here
    # (e.g. search expressions) defines a ``distinct`` attribute.
    if getattr(sub_expr, "distinct", False):
        # If the expression should return distinct values, use
        # $addToSet to deduplicate.
        rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
        group[alias] = {"$addToSet": rhs}
        replacing_expr = sub_expr.copy()
        # The copied aggregate now operates on the deduplicated column.
        replacing_expr.set_source_expressions([inner_column, None])
    else:
        group[alias] = sub_expr.as_mql(self, self.connection)
        replacing_expr = inner_column
    # Count must return 0 rather than null.
    if isinstance(sub_expr, Count):
        replacing_expr = Coalesce(replacing_expr, 0)
    # Variance = StdDev^2
    if isinstance(sub_expr, Variance):
        replacing_expr = Power(replacing_expr, 2)
    return replacing_expr

def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
"""
Prepare expressions for the aggregation pipeline.
Expand All @@ -80,29 +106,42 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
alias = (
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
)
column_target = sub_expr.output_field.clone()
column_target.db_column = alias
column_target.set_attributes_from_name(alias)
inner_column = Col(self.collection_name, column_target)
if sub_expr.distinct:
# If the expression should return distinct values, use
# $addToSet to deduplicate.
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
group[alias] = {"$addToSet": rhs}
replacing_expr = sub_expr.copy()
replacing_expr.set_source_expressions([inner_column, None])
else:
group[alias] = sub_expr.as_mql(self, self.connection)
replacing_expr = inner_column
# Count must return 0 rather than null.
if isinstance(sub_expr, Count):
replacing_expr = Coalesce(replacing_expr, 0)
# Variance = StdDev^2
if isinstance(sub_expr, Variance):
replacing_expr = Power(replacing_expr, 2)
replacements[sub_expr] = replacing_expr
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
return replacements, group

def _prepare_search_expressions_for_pipeline(self, expression, target, search_idx):
    """
    Collect the Atlas search expressions nested in ``expression``.

    Return a ``(replacements, stages)`` pair: a mapping from each search
    expression to the column expression that replaces it in the outer
    query, and the list of search stage definitions to run.

    ``target`` is currently unused; ``search_idx`` is a shared counter so
    aliases stay unique across calls.
    """
    stage_map = {}
    replacements = {
        search_expr: self._get_replace_expr(
            search_expr, stage_map, f"__search_expr.search{next(search_idx)}"
        )
        for search_expr in self._get_search_expressions(expression)
    }
    return replacements, list(stage_map.values())

def _prepare_search_query_for_aggregation_pipeline(self, order_by):
    """
    Scan the query's annotations, the ORDER BY expressions, and the HAVING
    clause for Atlas search expressions.

    Return ``(searches, replacements)``: the search stage definitions to
    prepend to the pipeline, and a mapping from each search expression to
    the column expression that replaces it.
    """
    replacements = {}
    searches = []
    search_idx = itertools.count(start=1)

    def collect(expression, target):
        # Accumulate the replacements and stages found in one expression.
        found_replacements, stages = self._prepare_search_expressions_for_pipeline(
            expression, target, search_idx
        )
        replacements.update(found_replacements)
        searches.extend(stages)

    for target, expr in self.query.annotation_select.items():
        collect(expr, target)
    for expr, _ in order_by:
        collect(expr, None)
    collect(self.having, None)
    return searches, replacements

def _prepare_annotations_for_aggregation_pipeline(self, order_by):
"""Prepare annotations for the aggregation pipeline."""
replacements = {}
Expand Down Expand Up @@ -179,6 +218,9 @@ def _get_group_id_expressions(self, order_by):
ids = self.get_project_fields(tuple(columns), force_expression=True)
return ids, replacements

def _build_search_pipeline(self, search_queries):
    # TODO(review): WIP stub — presumably meant to translate the collected
    # search queries into pipeline stages; currently a no-op. Confirm
    # whether it should be removed or implemented before merging.
    pass

def _build_aggregation_pipeline(self, ids, group):
"""Build the aggregation pipeline for grouping."""
pipeline = []
Expand Down Expand Up @@ -207,9 +249,21 @@ def _build_aggregation_pipeline(self, ids, group):
pipeline.append({"$unset": "_id"})
return pipeline

def _compound_searches_queries(self, searches):
if not searches:
return []
if len(searches) > 1:
raise ValueError("Cannot perform more than one search operation.")
return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}}]

def pre_sql_setup(self, with_col_aliases=False):
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline(
order_by
)
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
all_replacements = {**search_replacements, **group_replacements}
self.search_pipeline = self._compound_searches_queries(searches)
# query.group_by is either:
# - None: no GROUP BY
# - True: group by select fields
Expand Down Expand Up @@ -557,10 +611,16 @@ def get_lookup_pipeline(self):
return result

def _get_aggregate_expressions(self, expr):
    """Yield every Aggregate expression nested anywhere inside ``expr``."""
    return self._get_all_expressions_of_type(expr, Aggregate)

def _get_search_expressions(self, expr):
    """Yield every SearchExpression nested anywhere inside ``expr``."""
    return self._get_all_expressions_of_type(expr, SearchExpression)

def _get_all_expressions_of_type(self, expr, target_type):
stack = [expr]
while stack:
expr = stack.pop()
if isinstance(expr, Aggregate):
if isinstance(expr, target_type):
yield expr
elif hasattr(expr, "get_source_expressions"):
stack.extend(expr.get_source_expressions())
Expand Down
Loading
Loading