@@ -17,6 +17,7 @@
 from django.utils.functional import cached_property
 from pymongo import ASCENDING, DESCENDING
 
+from .functions import SearchScore
 from .query import MongoQuery, wrap_database_errors
 
 
@@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
         # A list of OrderBy objects for this query.
         self.order_by_objs = None
         self.subqueries = []
+        # Atlas search calls
+        self.search_pipeline = []
 
     def _get_group_alias_column(self, expr, annotation_group_idx):
         """Generate a dummy field for use in the ids fields in $group."""
@@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
         column_target.set_attributes_from_name(alias)
         return Col(self.collection_name, column_target)
 
+    def _get_replace_expr(self, sub_expr, group, alias):
+        column_target = sub_expr.output_field.clone()
+        column_target.db_column = alias
+        column_target.set_attributes_from_name(alias)
+        inner_column = Col(self.collection_name, column_target)
+        if getattr(sub_expr, "distinct", False):
+            # If the expression should return distinct values, use
+            # $addToSet to deduplicate.
+            rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
+            group[alias] = {"$addToSet": rhs}
+            replacing_expr = sub_expr.copy()
+            replacing_expr.set_source_expressions([inner_column, None])
+        else:
+            group[alias] = sub_expr.as_mql(self, self.connection)
+            replacing_expr = inner_column
+        # Count must return 0 rather than null.
+        if isinstance(sub_expr, Count):
+            replacing_expr = Coalesce(replacing_expr, 0)
+        # Variance = StdDev^2
+        if isinstance(sub_expr, Variance):
+            replacing_expr = Power(replacing_expr, 2)
+        return replacing_expr
+
     def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
         """
         Prepare expressions for the aggregation pipeline.
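The new `_get_replace_expr` helper factors the replacement logic out of `_prepare_expressions_for_pipeline` so the same code path can serve search expressions. A minimal standalone sketch of the `$group` entry it fills in for the distinct case, assuming a hypothetical distinct aggregate over a `genre` field (the alias and field name are illustrative, not taken from the patch):

# Sketch only: shape of the $group entry produced on the distinct=True path.
# "__aggregation1" and "$genre" are illustrative assumptions.
group = {}
alias = "__aggregation1"
group[alias] = {"$addToSet": "$genre"}  # deduplicate values before aggregating
# Non-distinct expressions instead store sub_expr.as_mql(...) directly, and the
# caller replaces the original expression with a Col pointing at the alias
# (wrapped in Coalesce for Count, Power(..., 2) for Variance).
print(group)  # {'__aggregation1': {'$addToSet': '$genre'}}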
@@ -80,29 +106,42 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
             alias = (
                 f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
             )
-            column_target = sub_expr.output_field.clone()
-            column_target.db_column = alias
-            column_target.set_attributes_from_name(alias)
-            inner_column = Col(self.collection_name, column_target)
-            if sub_expr.distinct:
-                # If the expression should return distinct values, use
-                # $addToSet to deduplicate.
-                rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
-                group[alias] = {"$addToSet": rhs}
-                replacing_expr = sub_expr.copy()
-                replacing_expr.set_source_expressions([inner_column, None])
-            else:
-                group[alias] = sub_expr.as_mql(self, self.connection)
-                replacing_expr = inner_column
-            # Count must return 0 rather than null.
-            if isinstance(sub_expr, Count):
-                replacing_expr = Coalesce(replacing_expr, 0)
-            # Variance = StdDev^2
-            if isinstance(sub_expr, Variance):
-                replacing_expr = Power(replacing_expr, 2)
-            replacements[sub_expr] = replacing_expr
+            replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
         return replacements, group
 
+    def _prepare_search_expressions_for_pipeline(self, expression, target, search_idx):
+        searches = {}
+        replacements = {}
+        for sub_expr in self._get_search_expressions(expression):
+            alias = f"__search_expr.search{next(search_idx)}"
+            replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
+        return replacements, searches
+
+    def _prepare_search_query_for_aggregation_pipeline(self, order_by):
+        replacements = {}
+        searches = {}
+        search_idx = itertools.count(start=1)
+        for target, expr in self.query.annotation_select.items():
+            new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline(
+                expr, target, search_idx
+            )
+            replacements.update(new_replacements)
+            searches.update(expr_searches)
+
+        for expr, _ in order_by:
+            new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline(
+                expr, None, search_idx
+            )
+            replacements.update(new_replacements)
+            searches.update(expr_searches)
+
+        having_replacements, having_group = self._prepare_search_expressions_for_pipeline(
+            self.having, None, search_idx
+        )
+        replacements.update(having_replacements)
+        searches.update(having_group)
+        return searches, replacements
+
     def _prepare_annotations_for_aggregation_pipeline(self, order_by):
         """Prepare annotations for the aggregation pipeline."""
         replacements = {}
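`_prepare_search_query_for_aggregation_pipeline` scans the annotation select, ORDER BY, and HAVING clauses for `SearchScore` expressions and replaces each one with a column aliased `__search_expr.searchN`. A small runnable sketch of just the alias numbering, assuming three search expressions are found:

import itertools

# Sketch of the alias scheme used while collecting search expressions;
# the counter starts at 1 and increments once per SearchScore found.
search_idx = itertools.count(start=1)
aliases = [f"__search_expr.search{next(search_idx)}" for _ in range(3)]
print(aliases)
# ['__search_expr.search1', '__search_expr.search2', '__search_expr.search3']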
@@ -179,6 +218,9 @@ def _get_group_id_expressions(self, order_by):
             ids = self.get_project_fields(tuple(columns), force_expression=True)
         return ids, replacements
 
+    def _build_search_pipeline(self, search_queries):
+        pass
+
     def _build_aggregation_pipeline(self, ids, group):
         """Build the aggregation pipeline for grouping."""
         pipeline = []
@@ -209,7 +251,12 @@ def _build_aggregation_pipeline(self, ids, group):
 
     def pre_sql_setup(self, with_col_aliases=False):
         extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
-        group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
+        searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline(
+            order_by
+        )
+        group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
+        all_replacements = {**search_replacements, **group_replacements}
+        self.search_pipeline = searches
         # query.group_by is either:
         # - None: no GROUP BY
         # - True: group by select fields
@@ -557,10 +604,16 @@ def get_lookup_pipeline(self):
         return result
 
     def _get_aggregate_expressions(self, expr):
+        return self._get_all_expressions_of_type(expr, Aggregate)
+
+    def _get_search_expressions(self, expr):
+        return self._get_all_expressions_of_type(expr, SearchScore)
+
+    def _get_all_expressions_of_type(self, expr, target_type):
         stack = [expr]
         while stack:
             expr = stack.pop()
-            if isinstance(expr, Aggregate):
+            if isinstance(expr, target_type):
                 yield expr
             elif hasattr(expr, "get_source_expressions"):
                 stack.extend(expr.get_source_expressions())
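`_get_all_expressions_of_type` generalizes the previously aggregate-only traversal: it walks the expression tree iteratively and yields every node of the requested type, without descending into a matching node's own source expressions. A self-contained sketch of the same traversal over a toy node class (the class names are illustrative, not Django's):

class Node:
    def __init__(self, *children):
        self.children = list(children)

    def get_source_expressions(self):
        return self.children

class Score(Node):
    pass

def get_all_expressions_of_type(expr, target_type):
    # Same iterative depth-first walk as the compiler method.
    stack = [expr]
    while stack:
        expr = stack.pop()
        if isinstance(expr, target_type):
            yield expr
        elif hasattr(expr, "get_source_expressions"):
            stack.extend(expr.get_source_expressions())

tree = Node(Score(), Node(Score(), Node()))
print(len(list(get_all_expressions_of_type(tree, Score))))  # 2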