Skip to content

Commit c60e6c1

Browse files
[Data] Fix renamed columns to be appropriately dropped from the ouput (#58040)
## Description This change addresses the issues that currently upon column renaming we're not removing original columns. ## Related issues > Link related issues: "Fixes #1234", "Closes #1234", or "Related to #1234". ## Additional information > Optional: Add implementation details, API changes, usage examples, screenshots, etc. --------- Signed-off-by: Alexey Kudinkin <[email protected]>
1 parent d921f01 commit c60e6c1

File tree

10 files changed

+327
-252
lines changed

10 files changed

+327
-252
lines changed

python/ray/data/_internal/arrow_block.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,9 @@ def take(
361361
"""
362362
return transform_pyarrow.take_table(self._table, indices)
363363

364+
def drop(self, columns: List[str]) -> Block:
365+
return self._table.drop(columns)
366+
364367
def select(self, columns: List[str]) -> "pyarrow.Table":
365368
if not all(isinstance(col, str) for col in columns):
366369
raise ValueError(

python/ray/data/_internal/logical/operators/map_operator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,15 @@ def __init__(
296296
)
297297

298298
def has_star_expr(self) -> bool:
299+
return self.get_star_expr() is not None
300+
301+
def get_star_expr(self) -> Optional[StarExpr]:
299302
"""Check if this projection contains a star() expression."""
300-
return any(isinstance(expr, StarExpr) for expr in self._exprs)
303+
for expr in self._exprs:
304+
if isinstance(expr, StarExpr):
305+
return expr
306+
307+
return None
301308

302309
@property
303310
def exprs(self) -> List["Expr"]:

python/ray/data/_internal/logical/rules/projection_pushdown.py

Lines changed: 153 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from ray.data._internal.logical.operators.map_operator import Project
1010
from ray.data._internal.planner.plan_expression.expression_visitors import (
1111
_ColumnReferenceCollector,
12-
_ColumnRewriter,
12+
_ColumnRefRebindingVisitor,
13+
_is_col_expr,
1314
)
1415
from ray.data.expressions import (
1516
AliasExpr,
@@ -45,17 +46,20 @@ def _extract_simple_rename(expr: Expr) -> Optional[Tuple[str, str]]:
4546
"""
4647
Check if an expression is a simple column rename.
4748
48-
Returns (source_name, dest_name) if the expression is of form:
49+
Returns (source_name, target_name) if the expression is of form:
4950
col("source").alias("dest")
50-
where source != dest.
5151
5252
Returns None for other expression types.
5353
"""
54-
if isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr):
55-
dest_name = expr.name
54+
if (
55+
isinstance(expr, AliasExpr)
56+
and isinstance(expr.expr, ColumnExpr)
57+
and expr._is_rename
58+
):
59+
target_name = expr.name
5660
source_name = expr.expr.name
57-
if source_name != dest_name:
58-
return source_name, dest_name
61+
return source_name, target_name
62+
5963
return None
6064

6165

@@ -67,27 +71,28 @@ def _analyze_upstream_project(
6771
6872
Example: Upstream exprs [col("x").alias("y")] → removed_by_renames = {"x"} if "x" not in output
6973
"""
70-
output_columns = {
74+
output_column_names = {
7175
expr.name for expr in upstream_project.exprs if not isinstance(expr, StarExpr)
7276
}
73-
column_definitions = {
74-
expr.name: expr
75-
for expr in upstream_project.exprs
76-
if not isinstance(expr, StarExpr)
77+
78+
# Compose column definitions in the form of a mapping of
79+
# - Target column name
80+
# - Target expression
81+
output_column_defs = {
82+
expr.name: expr for expr in _filter_out_star(upstream_project.exprs)
7783
}
7884

79-
# Identify columns removed by renames (source not in output)
80-
removed_by_renames: Set[str] = set()
81-
for expr in upstream_project.exprs:
82-
if isinstance(expr, StarExpr):
83-
continue
84-
rename_pair = _extract_simple_rename(expr)
85-
if rename_pair is not None:
86-
source_name, _ = rename_pair
87-
if source_name not in output_columns:
88-
removed_by_renames.add(source_name)
85+
# Identify upstream input columns removed by renaming (ie not propagated into
86+
# its output)
87+
upstream_column_renaming_map = _extract_input_columns_renaming_mapping(
88+
upstream_project.exprs
89+
)
8990

90-
return output_columns, column_definitions, removed_by_renames
91+
return (
92+
output_column_names,
93+
output_column_defs,
94+
set(upstream_column_renaming_map.keys()),
95+
)
9196

9297

9398
def _validate_fusion(
@@ -142,98 +147,124 @@ def _validate_fusion(
142147
return is_valid, missing_columns
143148

144149

145-
def _compose_projects(
146-
upstream_project: Project,
147-
downstream_project: Project,
148-
upstream_has_star: bool,
149-
) -> List[Expr]:
150-
"""
151-
Compose two Projects when the downstream has star().
152-
153-
Strategy:
154-
- Emit a single star() only if the upstream had star() as well.
155-
- Evaluate upstream non-star expressions first, then downstream non-star expressions.
156-
With sequential projection evaluation, downstream expressions can reference
157-
upstream outputs without explicit rewriting.
158-
- Rename-of-computed columns will be dropped from final output by the evaluator
159-
when there's no later explicit mention of the source name.
160-
"""
161-
fused_exprs: List[Expr] = []
162-
163-
# Include star only if upstream had star; otherwise, don't reintroduce dropped cols.
164-
if upstream_has_star:
165-
fused_exprs.append(StarExpr())
166-
167-
# Then upstream non-star expressions in order.
168-
for expr in upstream_project.exprs:
169-
if not isinstance(expr, StarExpr):
170-
fused_exprs.append(expr)
171-
172-
# Then downstream non-star expressions in order.
173-
for expr in downstream_project.exprs:
174-
if not isinstance(expr, StarExpr):
175-
fused_exprs.append(expr)
176-
177-
return fused_exprs
178-
179-
180-
def _try_fuse_consecutive_projects(
181-
upstream_project: Project, downstream_project: Project
182-
) -> Project:
150+
def _try_fuse(upstream_project: Project, downstream_project: Project) -> Project:
183151
"""
184152
Attempt to merge two consecutive Project operations into one.
185153
186154
Example: Upstream: [star(), col("x").alias("y")], Downstream: [star(), (col("y") + 1).alias("z")] → Fused: [star(), (col("x") + 1).alias("z")]
187155
"""
188156
upstream_has_star: bool = upstream_project.has_star_expr()
189-
downstream_has_star: bool = downstream_project.has_star_expr()
157+
158+
# TODO add validations that
159+
# - exprs only depend on input attrs (ie no dep on output of other exprs)
190160

191161
# Analyze upstream
192162
(
193-
upstream_output_columns,
194-
upstream_column_definitions,
195-
removed_by_renames,
163+
upstream_output_cols,
164+
upstream_column_defs,
165+
upstream_input_cols_removed,
196166
) = _analyze_upstream_project(upstream_project)
197167

198168
# Validate fusion possibility
199169
is_valid, missing_columns = _validate_fusion(
200170
downstream_project,
201171
upstream_has_star,
202-
upstream_output_columns,
203-
removed_by_renames,
172+
upstream_output_cols,
173+
upstream_input_cols_removed,
204174
)
205175

206176
if not is_valid:
207177
# Raise KeyError to match expected error type in tests
208178
raise KeyError(
209179
f"Column(s) {sorted(missing_columns)} not found. "
210-
f"Available columns: {sorted(upstream_output_columns) if not upstream_has_star else 'all columns (has star)'}"
180+
f"Available columns: {sorted(upstream_output_cols) if not upstream_has_star else 'all columns (has star)'}"
211181
)
212182

213-
rewritten_exprs: List[Expr] = []
214-
# Intersection case: This is when downstream is a selection (no star), and we need to recursively rewrite the downstream expressions into the upstream column definitions.
215-
# Example: Upstream: [col("a").alias("b")], Downstream: [col("b").alias("c")] → Rewritten: [col("a").alias("c")]
216-
if not downstream_has_star:
217-
for expr in downstream_project.exprs:
218-
rewritten = _ColumnRewriter(upstream_column_definitions).visit(expr)
219-
rewritten_exprs.append(rewritten)
220-
else:
221-
# Composition case: downstream has star(), and we need to merge both upstream and downstream expressions.
183+
# Following invariants are upheld for each ``Project`` logical op:
184+
#
185+
# 1. ``Project``s list of expressions are bound to op's input columns **only**
186+
# (ie there could be no inter-dependency b/w expressions themselves)
187+
#
188+
# 2. `Each of expressions on the `Project``s list constitutes an output
189+
# column definition, where column's name is derived from ``expr.name`` and
190+
# column itself is derived by executing that expression against the op's
191+
# input block.
192+
#
193+
# Therefore to abide by and satisfy aforementioned invariants, when fusing
194+
# 2 ``Project`` operators, following scenarios are considered:
195+
#
196+
# 1. Composition: downstream including (and potentially renaming) upstream
197+
# output columns (this is the case when downstream holds ``StarExpr``).
198+
#
199+
# 2. Projection: downstream projecting upstream output columns (by for ex,
200+
# only selecting & transforming some of the upstream output columns).
201+
#
202+
203+
# Upstream output column refs inside downstream expressions need to be bound
204+
# to upstream output column definitions to satisfy invariant #1 (common for both
205+
# composition/projection cases)
206+
v = _ColumnRefRebindingVisitor(upstream_column_defs)
207+
208+
rebound_downstream_exprs = [
209+
v.visit(e) for e in _filter_out_star(downstream_project.exprs)
210+
]
211+
212+
if not downstream_project.has_star_expr():
213+
# Projection case: this is when downstream is a *selection* (ie, not including
214+
# the upstream columns with ``StarExpr``)
215+
#
222216
# Example:
223-
# Upstream: [star(), col("a").alias("b")], Downstream: [star(), col("b").alias("c")] → Rewritten: [star(), col("a").alias("b"), col("b").alias("c")]
224-
rewritten_exprs = _compose_projects(
225-
upstream_project,
226-
downstream_project,
227-
upstream_has_star,
217+
# Upstream: Project([col("a").alias("b")])
218+
# Downstream: Project([col("b").alias("c")])
219+
#
220+
# Result: Project([col("a").alias("c")])
221+
new_exprs = rebound_downstream_exprs
222+
else:
223+
# Composition case: downstream has ``StarExpr`` (entailing that downstream
224+
# output will be including all of the upstream output columns)
225+
#
226+
# Example 1:
227+
# Upstream: [star(), col("a").alias("b")],
228+
# Downstream: [star(), col("b").alias("c")]
229+
#
230+
# Result: [star(), col("a").alias("b"), col("a").alias("c")]
231+
#
232+
# Example 2:
233+
# Input (columns): ["a", "b"]
234+
# Upstream: [star({"b": "z"}), col("a").alias("x")],
235+
# Downstream: [star({"x": "y"}), col("z")]
236+
#
237+
# Result: [star(), col("a").alias("y"), col("b").alias("z")]
238+
239+
# Extract downstream's input column rename map (downstream inputs are
240+
# upstream's outputs)
241+
downstream_input_column_rename_map = _extract_input_columns_renaming_mapping(
242+
downstream_project.exprs
228243
)
244+
# Collect upstream output column expression "projected" to become
245+
# downstream expressions
246+
projected_upstream_output_col_exprs = []
247+
248+
# When fusing 2 projections
249+
for e in upstream_project.exprs:
250+
# NOTE: We have to filter out upstream output columns that are
251+
# being *renamed* by downstream expression
252+
if e.name not in downstream_input_column_rename_map:
253+
projected_upstream_output_col_exprs.append(e)
254+
255+
new_exprs = projected_upstream_output_col_exprs + rebound_downstream_exprs
229256

230257
return Project(
231258
upstream_project.input_dependency,
232-
exprs=rewritten_exprs,
259+
exprs=new_exprs,
233260
ray_remote_args=downstream_project._ray_remote_args,
234261
)
235262

236263

264+
def _filter_out_star(exprs: List[Expr]) -> List[Expr]:
265+
return [e for e in exprs if not isinstance(e, StarExpr)]
266+
267+
237268
class ProjectionPushdown(Rule):
238269
"""
239270
Optimization rule that pushes projections (column selections) down the query plan.
@@ -270,7 +301,10 @@ def _try_fuse_projects(cls, op: LogicalOperator) -> LogicalOperator:
270301
return op
271302

272303
upstream_project: Project = current_project.input_dependency # type: ignore[assignment]
273-
return _try_fuse_consecutive_projects(upstream_project, current_project)
304+
305+
fused = _try_fuse(upstream_project, current_project)
306+
307+
return fused
274308

275309
@classmethod
276310
def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator:
@@ -287,28 +321,27 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator:
287321
and input_op.supports_projection_pushdown()
288322
):
289323
if current_project.has_star_expr():
290-
# If project has a star, than no projection is feasible
324+
# If project has a star, then projection is not feasible
291325
required_columns = None
292326
else:
293-
# Otherwise, collect required column for projection
327+
# Otherwise, collect required columns to push projection down
328+
# into the reader
294329
required_columns = _collect_referenced_columns(current_project.exprs)
295330

296331
# Check if it's a simple projection that could be pushed into
297332
# read as a whole
298-
is_simple_projection = all(
299-
_is_col_expr(expr)
300-
for expr in current_project.exprs
301-
if not isinstance(expr, StarExpr)
333+
is_projection = all(
334+
_is_col_expr(expr) for expr in _filter_out_star(current_project.exprs)
302335
)
303336

304-
if is_simple_projection:
337+
if is_projection:
305338
# NOTE: We only can rename output columns when it's a simple
306339
# projection and Project operator is discarded (otherwise
307340
# it might be holding expression referencing attributes
308341
# by original their names prior to renaming)
309342
#
310343
# TODO fix by instead rewriting exprs
311-
output_column_rename_map = _collect_output_column_rename_map(
344+
output_column_rename_map = _extract_input_columns_renaming_mapping(
312345
current_project.exprs
313346
)
314347

@@ -330,18 +363,35 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator:
330363
return current_project
331364

332365

333-
def _is_col_expr(expr: Expr) -> bool:
334-
return isinstance(expr, ColumnExpr) or (
335-
isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr)
366+
def _extract_input_columns_renaming_mapping(
367+
projection_exprs: List[Expr],
368+
) -> Dict[str, str]:
369+
"""Fetches renaming mapping of all input columns names being renamed (replaced).
370+
Format is source column name -> new column name.
371+
"""
372+
373+
return dict(
374+
[
375+
_get_renaming_mapping(expr)
376+
for expr in _filter_out_star(projection_exprs)
377+
if _is_renaming_expr(expr)
378+
]
336379
)
337380

338381

339-
def _collect_output_column_rename_map(exprs: List[Expr]) -> Dict[str, str]:
340-
# First, extract all potential rename pairs
341-
rename_map = {
342-
expr.expr.name: expr.name
343-
for expr in exprs
344-
if isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr)
345-
}
382+
def _get_renaming_mapping(expr: Expr) -> Tuple[str, str]:
383+
assert _is_renaming_expr(expr)
384+
385+
alias: AliasExpr = expr
386+
387+
return alias.expr.name, alias.name
388+
389+
390+
def _is_renaming_expr(expr: Expr) -> bool:
391+
is_renaming = isinstance(expr, AliasExpr) and expr._is_rename
392+
393+
assert not is_renaming or isinstance(
394+
expr.expr, ColumnExpr
395+
), f"Renaming expression expected to be of the shape alias(col('source'), 'target') (got {expr})"
346396

347-
return rename_map
397+
return is_renaming

python/ray/data/_internal/pandas_block.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,9 @@ def take(self, indices: List[int]) -> "pandas.DataFrame":
363363
table.reset_index(drop=True, inplace=True)
364364
return table
365365

366+
def drop(self, columns: List[str]) -> Block:
367+
return self._table.drop(columns, axis="columns")
368+
366369
def select(self, columns: List[str]) -> "pandas.DataFrame":
367370
if not all(isinstance(col, str) for col in columns):
368371
raise ValueError(

0 commit comments

Comments
 (0)