Skip to content

Commit 8155fbf

Browse files
committed
More tests
1 parent 486f515 commit 8155fbf

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

modin/core/storage_formats/pandas/query_compiler_caster.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def calculate(self):
6464
costs_2 = qc_2.qc_engine_switch_cost(qc_1)
6565
self._add_cost_data(costs_1)
6666
self._add_cost_data(costs_2)
67-
67+
if len(self._caster_costing_map) <= 0 and len(self._qc_cls_list) > 0:
68+
self._result_type = self._qc_cls_list[0]
69+
return self._result_type
6870
min_value = min(self._caster_costing_map.values())
6971
for key, value in self._caster_costing_map.items():
7072
if min_value == value:
@@ -163,16 +165,16 @@ def apply_argument_cast(obj: Fn) -> Fn:
163165
"""
164166
if isinstance(obj, type):
165167
all_attrs = dict(inspect.getmembers(obj))
166-
all_attrs.pop("__abstractmethods__")
167-
all_attrs.pop("__init__")
168-
all_attrs.pop("qc_engine_switch_cost")
169-
all_attrs.pop("from_pandas")
168+
170169

171170
# This is required because inspect converts class methods to member functions
172171
current_class_attrs = vars(obj)
173172
for key in current_class_attrs:
174173
all_attrs[key] = current_class_attrs[key]
175-
174+
all_attrs.pop("__abstractmethods__")
175+
all_attrs.pop("__init__")
176+
all_attrs.pop("qc_engine_switch_cost")
177+
all_attrs.pop("from_pandas")
176178
for attr_name, attr_value in all_attrs.items():
177179
if isinstance(
178180
attr_value, (FunctionType, MethodType, classmethod, staticmethod)
@@ -203,7 +205,6 @@ def cast_args(*args: Tuple, **kwargs: Dict) -> Any:
203205
"""
204206
if len(args) == 0 and len(kwargs) == 0:
205207
return
206-
print(f"Adding wrapper {obj}\n")
207208
current_qc = args[0]
208209
calculator = QueryCompilerCasterCalculator()
209210
calculator.add_query_compiler(current_qc)

modin/tests/pandas/native_df_interoperability/test_compiler_caster.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,18 @@ def test_two_two_qc_types_lhs():
7474
assert(type(df3) == type(df2)) # should move to cluster
7575

7676
def test_three_two_qc_types_rhs():
77-
pass
77+
df = CloudQC(pandas.DataFrame([0, 1, 2]))
78+
df2 = CloudQC(pandas.DataFrame([0, 1, 2]))
79+
df3 = PicoQC(pandas.DataFrame([0, 1, 2]))
80+
df4 = df3.concat(axis=1, other=[df, df2])
81+
assert(type(df) == type(df4)) # should move to cloud
7882

7983
def test_three_two_qc_types_lhs():
80-
pass
84+
df = CloudQC(pandas.DataFrame([0, 1, 2]))
85+
df2 = CloudQC(pandas.DataFrame([0, 1, 2]))
86+
df3 = PicoQC(pandas.DataFrame([0, 1, 2]))
87+
df4 = df.concat(axis=1, other=[df2, df3])
88+
assert(type(df) == type(df4)) # should move to cloud
8189

8290
def test_three_two_qc_types_middle():
8391
pass

0 commit comments

Comments
 (0)