99from ray .data ._internal .logical .operators .map_operator import Project
1010from ray .data ._internal .planner .plan_expression .expression_visitors import (
1111 _ColumnReferenceCollector ,
12- _ColumnRewriter ,
12+ _ColumnRefRebindingVisitor ,
13+ _is_col_expr ,
1314)
1415from ray .data .expressions import (
1516 AliasExpr ,
@@ -45,17 +46,20 @@ def _extract_simple_rename(expr: Expr) -> Optional[Tuple[str, str]]:
4546 """
4647 Check if an expression is a simple column rename.
4748
48- Returns (source_name, dest_name ) if the expression is of form:
49+ Returns (source_name, target_name) if the expression is of the form:
4950 col("source").alias("target")
50- where source != dest.
5151
5252 Returns None for other expression types.
5353 """
54- if isinstance (expr , AliasExpr ) and isinstance (expr .expr , ColumnExpr ):
55- dest_name = expr .name
54+ if (
55+ isinstance (expr , AliasExpr )
56+ and isinstance (expr .expr , ColumnExpr )
57+ and expr ._is_rename
58+ ):
59+ target_name = expr .name
5660 source_name = expr .expr .name
57- if source_name != dest_name :
58- return source_name , dest_name
61+ return source_name , target_name
62+
5963 return None
6064
6165
@@ -67,27 +71,28 @@ def _analyze_upstream_project(
6771
6872 Example: Upstream exprs [col("x").alias("y")] → removed_by_renames = {"x"} if "x" not in output
6973 """
70- output_columns = {
74+ output_column_names = {
7175 expr .name for expr in upstream_project .exprs if not isinstance (expr , StarExpr )
7276 }
73- column_definitions = {
74- expr .name : expr
75- for expr in upstream_project .exprs
76- if not isinstance (expr , StarExpr )
77+
78+ # Compose column definitions in the form of a mapping of
79+ # - Target column name
80+ # - Target expression
81+ output_column_defs = {
82+ expr .name : expr for expr in _filter_out_star (upstream_project .exprs )
7783 }
7884
79- # Identify columns removed by renames (source not in output)
80- removed_by_renames : Set [str ] = set ()
81- for expr in upstream_project .exprs :
82- if isinstance (expr , StarExpr ):
83- continue
84- rename_pair = _extract_simple_rename (expr )
85- if rename_pair is not None :
86- source_name , _ = rename_pair
87- if source_name not in output_columns :
88- removed_by_renames .add (source_name )
85+ # Identify upstream input columns removed by renaming (ie not propagated into
86+ # its output)
87+ upstream_column_renaming_map = _extract_input_columns_renaming_mapping (
88+ upstream_project .exprs
89+ )
8990
90- return output_columns , column_definitions , removed_by_renames
91+ return (
92+ output_column_names ,
93+ output_column_defs ,
94+ set (upstream_column_renaming_map .keys ()),
95+ )
9196
9297
9398def _validate_fusion (
@@ -142,98 +147,124 @@ def _validate_fusion(
142147 return is_valid , missing_columns
143148
144149
145- def _compose_projects (
146- upstream_project : Project ,
147- downstream_project : Project ,
148- upstream_has_star : bool ,
149- ) -> List [Expr ]:
150- """
151- Compose two Projects when the downstream has star().
152-
153- Strategy:
154- - Emit a single star() only if the upstream had star() as well.
155- - Evaluate upstream non-star expressions first, then downstream non-star expressions.
156- With sequential projection evaluation, downstream expressions can reference
157- upstream outputs without explicit rewriting.
158- - Rename-of-computed columns will be dropped from final output by the evaluator
159- when there's no later explicit mention of the source name.
160- """
161- fused_exprs : List [Expr ] = []
162-
163- # Include star only if upstream had star; otherwise, don't reintroduce dropped cols.
164- if upstream_has_star :
165- fused_exprs .append (StarExpr ())
166-
167- # Then upstream non-star expressions in order.
168- for expr in upstream_project .exprs :
169- if not isinstance (expr , StarExpr ):
170- fused_exprs .append (expr )
171-
172- # Then downstream non-star expressions in order.
173- for expr in downstream_project .exprs :
174- if not isinstance (expr , StarExpr ):
175- fused_exprs .append (expr )
176-
177- return fused_exprs
178-
179-
180- def _try_fuse_consecutive_projects (
181- upstream_project : Project , downstream_project : Project
182- ) -> Project :
150+ def _try_fuse (upstream_project : Project , downstream_project : Project ) -> Project :
183151 """
184152 Attempt to merge two consecutive Project operations into one.
185153
186154 Example: Upstream: [star(), col("x").alias("y")], Downstream: [star(), (col("y") + 1).alias("z")] → Fused: [star(), col("x").alias("y"), (col("x") + 1).alias("z")]
187155 """
188156 upstream_has_star : bool = upstream_project .has_star_expr ()
189- downstream_has_star : bool = downstream_project .has_star_expr ()
157+
158+ # TODO add validations that
159+ # - exprs only depend on input attrs (ie no dep on output of other exprs)
190160
191161 # Analyze upstream
192162 (
193- upstream_output_columns ,
194- upstream_column_definitions ,
195- removed_by_renames ,
163+ upstream_output_cols ,
164+ upstream_column_defs ,
165+ upstream_input_cols_removed ,
196166 ) = _analyze_upstream_project (upstream_project )
197167
198168 # Validate fusion possibility
199169 is_valid , missing_columns = _validate_fusion (
200170 downstream_project ,
201171 upstream_has_star ,
202- upstream_output_columns ,
203- removed_by_renames ,
172+ upstream_output_cols ,
173+ upstream_input_cols_removed ,
204174 )
205175
206176 if not is_valid :
207177 # Raise KeyError to match expected error type in tests
208178 raise KeyError (
209179 f"Column(s) { sorted (missing_columns )} not found. "
210- f"Available columns: { sorted (upstream_output_columns ) if not upstream_has_star else 'all columns (has star)' } "
180+ f"Available columns: { sorted (upstream_output_cols ) if not upstream_has_star else 'all columns (has star)' } "
211181 )
212182
213- rewritten_exprs : List [Expr ] = []
214- # Intersection case: This is when downstream is a selection (no star), and we need to recursively rewrite the downstream expressions into the upstream column definitions.
215- # Example: Upstream: [col("a").alias("b")], Downstream: [col("b").alias("c")] → Rewritten: [col("a").alias("c")]
216- if not downstream_has_star :
217- for expr in downstream_project .exprs :
218- rewritten = _ColumnRewriter (upstream_column_definitions ).visit (expr )
219- rewritten_exprs .append (rewritten )
220- else :
221- # Composition case: downstream has star(), and we need to merge both upstream and downstream expressions.
183+ # Following invariants are upheld for each ``Project`` logical op:
184+ #
185+ # 1. ``Project``s list of expressions are bound to op's input columns **only**
186+ # (ie there could be no inter-dependency b/w expressions themselves)
187+ #
188+ # 2. Each of the expressions on the ``Project``s list constitutes an output
189+ # column definition, where column's name is derived from ``expr.name`` and
190+ # column itself is derived by executing that expression against the op's
191+ # input block.
192+ #
193+ # Therefore to abide by and satisfy aforementioned invariants, when fusing
194+ # 2 ``Project`` operators, following scenarios are considered:
195+ #
196+ # 1. Composition: downstream including (and potentially renaming) upstream
197+ # output columns (this is the case when downstream holds ``StarExpr``).
198+ #
199+ # 2. Projection: downstream projecting upstream output columns (for example, by
200+ # only selecting & transforming some of the upstream output columns).
201+ #
202+
203+ # Upstream output column refs inside downstream expressions need to be bound
204+ # to upstream output column definitions to satisfy invariant #1 (common for both
205+ # composition/projection cases)
206+ v = _ColumnRefRebindingVisitor (upstream_column_defs )
207+
208+ rebound_downstream_exprs = [
209+ v .visit (e ) for e in _filter_out_star (downstream_project .exprs )
210+ ]
211+
212+ if not downstream_project .has_star_expr ():
213+ # Projection case: this is when downstream is a *selection* (ie, not including
214+ # the upstream columns with ``StarExpr``)
215+ #
222216 # Example:
223- # Upstream: [star(), col("a").alias("b")], Downstream: [star(), col("b").alias("c")] → Rewritten: [star(), col("a").alias("b"), col("b").alias("c")]
224- rewritten_exprs = _compose_projects (
225- upstream_project ,
226- downstream_project ,
227- upstream_has_star ,
217+ # Upstream: Project([col("a").alias("b")])
218+ # Downstream: Project([col("b").alias("c")])
219+ #
220+ # Result: Project([col("a").alias("c")])
221+ new_exprs = rebound_downstream_exprs
222+ else :
223+ # Composition case: downstream has ``StarExpr`` (entailing that downstream
224+ # output will be including all of the upstream output columns)
225+ #
226+ # Example 1:
227+ # Upstream: [star(), col("a").alias("b")],
228+ # Downstream: [star(), col("b").alias("c")]
229+ #
230+ # Result: [star(), col("a").alias("b"), col("a").alias("c")]
231+ #
232+ # Example 2:
233+ # Input (columns): ["a", "b"]
234+ # Upstream: [star({"b": "z"}), col("a").alias("x")],
235+ # Downstream: [star({"x": "y"}), col("z")]
236+ #
237+ # Result: [star(), col("a").alias("y"), col("b").alias("z")]
238+
239+ # Extract downstream's input column rename map (downstream inputs are
240+ # upstream's outputs)
241+ downstream_input_column_rename_map = _extract_input_columns_renaming_mapping (
242+ downstream_project .exprs
228243 )
244+ # Collect upstream output column expression "projected" to become
245+ # downstream expressions
246+ projected_upstream_output_col_exprs = []
247+
248+ # When fusing 2 projections, carry over upstream expressions into the fused list
249+ for e in upstream_project .exprs :
250+ # NOTE: We have to filter out upstream output columns that are
251+ # being *renamed* by downstream expression
252+ if e .name not in downstream_input_column_rename_map :
253+ projected_upstream_output_col_exprs .append (e )
254+
255+ new_exprs = projected_upstream_output_col_exprs + rebound_downstream_exprs
229256
230257 return Project (
231258 upstream_project .input_dependency ,
232- exprs = rewritten_exprs ,
259+ exprs = new_exprs ,
233260 ray_remote_args = downstream_project ._ray_remote_args ,
234261 )
235262
236263
264+ def _filter_out_star (exprs : List [Expr ]) -> List [Expr ]:
265+ return [e for e in exprs if not isinstance (e , StarExpr )]
266+
267+
237268class ProjectionPushdown (Rule ):
238269 """
239270 Optimization rule that pushes projections (column selections) down the query plan.
@@ -270,7 +301,10 @@ def _try_fuse_projects(cls, op: LogicalOperator) -> LogicalOperator:
270301 return op
271302
272303 upstream_project : Project = current_project .input_dependency # type: ignore[assignment]
273- return _try_fuse_consecutive_projects (upstream_project , current_project )
304+
305+ fused = _try_fuse (upstream_project , current_project )
306+
307+ return fused
274308
275309 @classmethod
276310 def _push_projection_into_read_op (cls , op : LogicalOperator ) -> LogicalOperator :
@@ -287,28 +321,27 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator:
287321 and input_op .supports_projection_pushdown ()
288322 ):
289323 if current_project .has_star_expr ():
290- # If project has a star, than no projection is feasible
324+ # If project has a star, then projection is not feasible
291325 required_columns = None
292326 else :
293- # Otherwise, collect required column for projection
327+ # Otherwise, collect required columns to push projection down
328+ # into the reader
294329 required_columns = _collect_referenced_columns (current_project .exprs )
295330
296331 # Check if it's a simple projection that could be pushed into
297332 # read as a whole
298- is_simple_projection = all (
299- _is_col_expr (expr )
300- for expr in current_project .exprs
301- if not isinstance (expr , StarExpr )
333+ is_projection = all (
334+ _is_col_expr (expr ) for expr in _filter_out_star (current_project .exprs )
302335 )
303336
304- if is_simple_projection :
337+ if is_projection :
305338 # NOTE: We can only rename output columns when it's a simple
306339 # projection and Project operator is discarded (otherwise
307340 # it might be holding expressions referencing attributes
308341 # by their original names prior to renaming)
309342 #
310343 # TODO fix by instead rewriting exprs
311- output_column_rename_map = _collect_output_column_rename_map (
344+ output_column_rename_map = _extract_input_columns_renaming_mapping (
312345 current_project .exprs
313346 )
314347
@@ -330,18 +363,35 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator:
330363 return current_project
331364
332365
333- def _is_col_expr (expr : Expr ) -> bool :
334- return isinstance (expr , ColumnExpr ) or (
335- isinstance (expr , AliasExpr ) and isinstance (expr .expr , ColumnExpr )
366+ def _extract_input_columns_renaming_mapping (
367+ projection_exprs : List [Expr ],
368+ ) -> Dict [str , str ]:
369+ """Fetches the renaming mapping of all input column names being renamed (replaced).
370+ Format is source column name -> new column name.
371+ """
372+
373+ return dict (
374+ [
375+ _get_renaming_mapping (expr )
376+ for expr in _filter_out_star (projection_exprs )
377+ if _is_renaming_expr (expr )
378+ ]
336379 )
337380
338381
339- def _collect_output_column_rename_map (exprs : List [Expr ]) -> Dict [str , str ]:
340- # First, extract all potential rename pairs
341- rename_map = {
342- expr .expr .name : expr .name
343- for expr in exprs
344- if isinstance (expr , AliasExpr ) and isinstance (expr .expr , ColumnExpr )
345- }
382+ def _get_renaming_mapping (expr : Expr ) -> Tuple [str , str ]:
383+ assert _is_renaming_expr (expr )
384+
385+ alias : AliasExpr = expr
386+
387+ return alias .expr .name , alias .name
388+
389+
390+ def _is_renaming_expr (expr : Expr ) -> bool :
391+ is_renaming = isinstance (expr , AliasExpr ) and expr ._is_rename
392+
393+ assert not is_renaming or isinstance (
394+ expr .expr , ColumnExpr
395+ ), f"Renaming expression expected to be of the shape alias(col('source'), 'target') (got { expr } )"
346396
347- return rename_map
397+ return is_renaming
0 commit comments