Skip to content

Commit 3b90847

Browse files
committed
First approach.
1 parent 31a729c commit 3b90847

File tree

5 files changed

+118
-27
lines changed

5 files changed

+118
-27
lines changed

django_mongodb_backend/compiler.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

20+
from .functions import SearchScore
2021
from .query import MongoQuery, wrap_database_errors
2122

2223

@@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
3435
# A list of OrderBy objects for this query.
3536
self.order_by_objs = None
3637
self.subqueries = []
38+
# Atlas search calls
39+
self.search_pipeline = []
3740

3841
def _get_group_alias_column(self, expr, annotation_group_idx):
3942
"""Generate a dummy field for use in the ids fields in $group."""
@@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
5760
column_target.set_attributes_from_name(alias)
5861
return Col(self.collection_name, column_target)
5962

63+
def _get_replace_expr(self, sub_expr, group, alias):
64+
column_target = sub_expr.output_field.clone()
65+
column_target.db_column = alias
66+
column_target.set_attributes_from_name(alias)
67+
inner_column = Col(self.collection_name, column_target)
68+
if getattr(sub_expr, "distinct", False):
69+
# If the expression should return distinct values, use
70+
# $addToSet to deduplicate.
71+
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
72+
group[alias] = {"$addToSet": rhs}
73+
replacing_expr = sub_expr.copy()
74+
replacing_expr.set_source_expressions([inner_column, None])
75+
else:
76+
group[alias] = sub_expr.as_mql(self, self.connection)
77+
replacing_expr = inner_column
78+
# Count must return 0 rather than null.
79+
if isinstance(sub_expr, Count):
80+
replacing_expr = Coalesce(replacing_expr, 0)
81+
# Variance = StdDev^2
82+
if isinstance(sub_expr, Variance):
83+
replacing_expr = Power(replacing_expr, 2)
84+
return replacing_expr
85+
6086
def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
6187
"""
6288
Prepare expressions for the aggregation pipeline.
@@ -80,29 +106,42 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
80106
alias = (
81107
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
82108
)
83-
column_target = sub_expr.output_field.clone()
84-
column_target.db_column = alias
85-
column_target.set_attributes_from_name(alias)
86-
inner_column = Col(self.collection_name, column_target)
87-
if sub_expr.distinct:
88-
# If the expression should return distinct values, use
89-
# $addToSet to deduplicate.
90-
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
91-
group[alias] = {"$addToSet": rhs}
92-
replacing_expr = sub_expr.copy()
93-
replacing_expr.set_source_expressions([inner_column, None])
94-
else:
95-
group[alias] = sub_expr.as_mql(self, self.connection)
96-
replacing_expr = inner_column
97-
# Count must return 0 rather than null.
98-
if isinstance(sub_expr, Count):
99-
replacing_expr = Coalesce(replacing_expr, 0)
100-
# Variance = StdDev^2
101-
if isinstance(sub_expr, Variance):
102-
replacing_expr = Power(replacing_expr, 2)
103-
replacements[sub_expr] = replacing_expr
109+
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
104110
return replacements, group
105111

112+
def _prepare_search_expressions_for_pipeline(self, expression, target, search_idx):
113+
searches = {}
114+
replacements = {}
115+
for sub_expr in self._get_search_expressions(expression):
116+
alias = f"__search_expr.search{next(search_idx)}"
117+
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
118+
return replacements, searches
119+
120+
def _prepare_search_query_for_aggregation_pipeline(self, order_by):
121+
replacements = {}
122+
searches = {}
123+
search_idx = itertools.count(start=1)
124+
for target, expr in self.query.annotation_select.items():
125+
new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline(
126+
expr, target, search_idx
127+
)
128+
replacements.update(new_replacements)
129+
searches.update(expr_searches)
130+
131+
for expr, _ in order_by:
132+
new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline(
133+
expr, None, search_idx
134+
)
135+
replacements.update(new_replacements)
136+
searches.update(expr_searches)
137+
138+
having_replacements, having_group = self._prepare_search_expressions_for_pipeline(
139+
self.having, None, search_idx
140+
)
141+
replacements.update(having_replacements)
142+
searches.update(having_group)
143+
return searches, replacements
144+
106145
def _prepare_annotations_for_aggregation_pipeline(self, order_by):
107146
"""Prepare annotations for the aggregation pipeline."""
108147
replacements = {}
@@ -179,6 +218,9 @@ def _get_group_id_expressions(self, order_by):
179218
ids = self.get_project_fields(tuple(columns), force_expression=True)
180219
return ids, replacements
181220

221+
def _build_search_pipeline(self, search_queries):
222+
pass
223+
182224
def _build_aggregation_pipeline(self, ids, group):
183225
"""Build the aggregation pipeline for grouping."""
184226
pipeline = []
@@ -209,7 +251,12 @@ def _build_aggregation_pipeline(self, ids, group):
209251

210252
def pre_sql_setup(self, with_col_aliases=False):
211253
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
212-
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
254+
searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline(
255+
order_by
256+
)
257+
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
258+
all_replacements = {**search_replacements, **group_replacements}
259+
self.search_pipeline = searches
213260
# query.group_by is either:
214261
# - None: no GROUP BY
215262
# - True: group by select fields
@@ -557,10 +604,16 @@ def get_lookup_pipeline(self):
557604
return result
558605

559606
def _get_aggregate_expressions(self, expr):
607+
return self._get_all_expressions_of_type(expr, Aggregate)
608+
609+
def _get_search_expressions(self, expr):
610+
return self._get_all_expressions_of_type(expr, SearchScore)
611+
612+
def _get_all_expressions_of_type(self, expr, target_type):
560613
stack = [expr]
561614
while stack:
562615
expr = stack.pop()
563-
if isinstance(expr, Aggregate):
616+
if isinstance(expr, target_type):
564617
yield expr
565618
elif hasattr(expr, "get_source_expressions"):
566619
stack.extend(expr.get_source_expressions())

django_mongodb_backend/functions.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from django.conf import settings
44
from django.db import NotSupportedError
5-
from django.db.models import DateField, DateTimeField, TimeField
6-
from django.db.models.expressions import Func
5+
from django.db.models import DateField, DateTimeField, Expression, FloatField, TimeField
6+
from django.db.models.expressions import F, Func, Value
77
from django.db.models.functions import JSONArray
88
from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
99
from django.db.models.functions.datetime import (
@@ -38,8 +38,9 @@
3838
Trim,
3939
Upper,
4040
)
41+
from django.utils.deconstruct import deconstructible
4142

42-
from .query_utils import process_lhs
43+
from .query_utils import process_lhs, process_rhs
4344

4445
MONGO_OPERATORS = {
4546
Ceil: "ceil",
@@ -268,6 +269,30 @@ def trunc_time(self, compiler, connection):
268269
}
269270

270271

272+
@deconstructible(path="django_mongodb_backend.functions.SearchScore")
273+
class SearchScore(Expression):
274+
def __init__(self, path, value, operation="equals", **kwargs):
275+
self.extra_params = kwargs
276+
self.lhs = path if hasattr(path, "resolve_expression") else F(path)
277+
if not isinstance(value, str):
278+
# TODO HANDLE VALUES LIKE Value("some string")
279+
raise ValueError("STRING NEEDED")
280+
self.rhs = Value(value)
281+
self.operation = operation
282+
super().__init__(output_field=FloatField())
283+
284+
def __repr__(self):
285+
return f"search {self.field} = {self.value} | {self.extra_params}"
286+
287+
def as_mql(self, compiler, connection):
288+
lhs = process_lhs(self, compiler, connection)
289+
rhs = process_rhs(self, compiler, connection)
290+
return {"$search": {self.operation: {"path": lhs[:1], "query": rhs, **self.extra_params}}}
291+
292+
def as_sql(self, compiler, connection):
293+
return "", []
294+
295+
271296
def register_functions():
272297
Cast.as_mql = cast
273298
Concat.as_mql = concat

django_mongodb_backend/query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, compiler):
4949
self.lookup_pipeline = None
5050
self.project_fields = None
5151
self.aggregation_pipeline = compiler.aggregation_pipeline
52+
self.search_pipeline = compiler.search_pipeline
5253
self.extra_fields = None
5354
self.combinator_pipeline = None
5455
# $lookup stage that encapsulates the pipeline for performing a nested
@@ -75,6 +76,8 @@ def get_cursor(self):
7576

7677
def get_pipeline(self):
7778
pipeline = []
79+
if self.search_pipeline:
80+
pipeline.extend(self.search_pipeline)
7881
if self.lookup_pipeline:
7982
pipeline.extend(self.lookup_pipeline)
8083
for query in self.subqueries or ():

django_mongodb_backend/query_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
def is_direct_value(node):
7-
return not hasattr(node, "as_sql")
7+
return not hasattr(node, "as_sql") and not hasattr(node, "as_mql")
88

99

1010
def process_lhs(node, compiler, connection):

tests/queries_/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from django.db import models
22

33
from django_mongodb_backend.fields import ObjectIdAutoField, ObjectIdField
4+
from django_mongodb_backend.indexes import SearchIndex
45

56

67
class Author(models.Model):
@@ -53,3 +54,12 @@ class Meta:
5354

5455
def __str__(self):
5556
return str(self.pk)
57+
58+
59+
class Article(models.Model):
60+
headline = models.CharField(max_length=100)
61+
number = models.IntegerField()
62+
body = models.TextField()
63+
64+
class Meta:
65+
indexes = [SearchIndex(fields=["headline"])]

0 commit comments

Comments
 (0)