Skip to content

Commit 3a45f2e

Browse files
authored
Fix Projection meta (#78)
1 parent c9c10d4 commit 3a45f2e

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

dask_expr/expr.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
_get_meta_map_partitions,
1818
apply_and_enforce,
1919
is_dataframe_like,
20+
is_index_like,
21+
is_series_like,
2022
)
21-
from dask.utils import M, apply, funcname, import_required
23+
from dask.utils import M, apply, funcname, import_required, is_arraylike
2224

2325
replacement_rules = []
2426

@@ -83,6 +85,14 @@ def _tree_repr_lines(self, indent=0, recursive=True):
8385

8486
if isinstance(op, pd.core.base.PandasObject):
8587
op = "<pandas>"
88+
elif is_dataframe_like(op):
89+
op = "<dataframe>"
90+
elif is_index_like(op):
91+
op = "<index>"
92+
elif is_series_like(op):
93+
op = "<series>"
94+
elif is_arraylike(op):
95+
op = "<array>"
8696

8797
elif repr(op) != repr(default):
8898
if param:
@@ -777,6 +787,13 @@ def columns(self):
777787
else:
778788
return self.operand("columns")
779789

790+
@property
791+
def _meta(self):
792+
if is_dataframe_like(self.frame._meta):
793+
return super()._meta
794+
# Avoid column selection for Series/Index
795+
return self.frame._meta
796+
780797
def _node_label_args(self):
781798
return [self.frame, self.operand("columns")]
782799

dask_expr/reductions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,6 @@ def chunk_kwargs(self):
188188
min_count=self.min_count,
189189
)
190190

191-
@property
192-
def _meta(self):
193-
return self.frame._meta.sum(**self.chunk_kwargs)
194-
195191
def _simplify_up(self, parent):
196192
if isinstance(parent, Projection):
197193
return self.frame[parent.operand("columns")].sum(*self.operands[1:])

dask_expr/tests/test_collection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pickle
33

44
import dask
5+
import numpy as np
56
import pandas as pd
67
import pytest
78
from dask.dataframe.utils import assert_eq
@@ -47,7 +48,7 @@ def test_meta_divisions_name():
4748
assert list(df.columns) == list(a.columns)
4849
assert df.npartitions == 2
4950

50-
assert df.x.sum()._meta == 0
51+
assert np.isscalar(df.x.sum()._meta)
5152
assert df.x.sum().npartitions == 1
5253

5354
assert "mul" in df._name

0 commit comments

Comments
 (0)