Skip to content

Commit a3e1679

Browse files
committed
Update tests
1 parent 8155fbf commit a3e1679

File tree

1 file changed

+52
-38
lines changed

1 file changed

+52
-38
lines changed

modin/tests/pandas/native_df_interoperability/test_compiler_caster.py

+52-38
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11

22

33
import pandas
4-
from modin.core.storage_formats.base.query_compiler import BaseQueryCompiler, QCCoercionCost
4+
import pytest
5+
from modin.core.storage_formats.base.query_compiler import QCCoercionCost
56
from modin.core.storage_formats.pandas.native_query_compiler import NativeQueryCompiler
6-
from modin.core.storage_formats.pandas.query_compiler import PandasQueryCompiler
7-
from modin.utils import _inherit_docstrings
87

98

109
class CloudQC(NativeQueryCompiler):
@@ -55,46 +54,61 @@ def qc_engine_switch_cost(self, other_qc):
5554
LocalMachineQC: QCCoercionCost.COST_LOW,
5655
PicoQC: QCCoercionCost.COST_ZERO}
5756

58-
def test_two_same_qc_types_noop():
59-
df = PicoQC(pandas.DataFrame([0, 1, 2]))
60-
df2 = PicoQC(pandas.DataFrame([0, 1, 2]))
61-
df3 = df.concat(axis=1, other=df2)
62-
assert(type(df3) == type(df2))
57+
@pytest.fixture()
58+
def cloud_df():
59+
return CloudQC(pandas.DataFrame([0, 1, 2]))
6360

64-
def test_two_two_qc_types_rhs():
65-
df = PicoQC(pandas.DataFrame([0, 1, 2]))
66-
df2 = ClusterQC(pandas.DataFrame([0, 1, 2]))
67-
df3 = df.concat(axis=1, other=df2)
68-
assert(type(df3) == type(df2))
61+
@pytest.fixture()
62+
def cluster_df():
63+
return ClusterQC(pandas.DataFrame([0, 1, 2]))
6964

70-
def test_two_two_qc_types_lhs():
71-
df = PicoQC(pandas.DataFrame([0, 1, 2]))
72-
df2 = ClusterQC(pandas.DataFrame([0, 1, 2]))
73-
df3 = df2.concat(axis=1, other=df)
74-
assert(type(df3) == type(df2)) # should move to cluster
65+
@pytest.fixture()
66+
def local_df():
67+
return LocalMachineQC(pandas.DataFrame([0, 1, 2]))
7568

76-
def test_three_two_qc_types_rhs():
77-
df = CloudQC(pandas.DataFrame([0, 1, 2]))
78-
df2 = CloudQC(pandas.DataFrame([0, 1, 2]))
79-
df3 = PicoQC(pandas.DataFrame([0, 1, 2]))
80-
df4 = df3.concat(axis=1, other=[df, df2])
81-
assert(type(df) == type(df4)) # should move to cloud
69+
@pytest.fixture()
70+
def pico_df():
71+
return PicoQC(pandas.DataFrame([0, 1, 2]))
8272

83-
def test_three_two_qc_types_lhs():
84-
df = CloudQC(pandas.DataFrame([0, 1, 2]))
85-
df2 = CloudQC(pandas.DataFrame([0, 1, 2]))
86-
df3 = PicoQC(pandas.DataFrame([0, 1, 2]))
87-
df4 = df.concat(axis=1, other=[df2, df3])
88-
assert(type(df) == type(df4)) # should move to cloud
73+
def test_two_same_qc_types_noop(pico_df):
74+
df3 = pico_df.concat(axis=1, other=pico_df)
75+
assert(type(df3) == type(pico_df))
8976

90-
def test_three_two_qc_types_middle():
91-
pass
77+
def test_two_two_qc_types_rhs(pico_df, cluster_df):
78+
df3 = pico_df.concat(axis=1, other=cluster_df)
79+
assert(type(df3) == type(cluster_df)) # should move to cluster
9280

93-
def test_three_three_qc_types_rhs():
94-
pass
81+
def test_two_two_qc_types_lhs(pico_df, cluster_df):
82+
df3 = cluster_df.concat(axis=1, other=pico_df)
83+
assert(type(df3) == type(cluster_df)) # should move to cluster
9584

96-
def test_three_three_qc_types_lhs():
97-
pass
85+
@pytest.mark.parametrize(
86+
"df1, df2, df3, df4, result_type",
87+
[
88+
# no-op
89+
("cloud_df", "cloud_df", "cloud_df", "cloud_df", CloudQC),
90+
# moving all dfs to cloud is 1250, moving to cluster is 1000
91+
# regardless of how they are ordered
92+
("pico_df", "local_df", "cluster_df", "cloud_df", ClusterQC),
93+
("cloud_df", "local_df", "cluster_df", "pico_df", ClusterQC),
94+
("cloud_df", "cluster_df", "local_df", "pico_df", ClusterQC),
95+
("cloud_df", "cloud_df", "local_df", "pico_df", CloudQC),
96+
# Still move everything to cloud
97+
("pico_df", "pico_df", "pico_df", "cloud_df", CloudQC),
98+
],
99+
)
100+
def test_mixed_dfs(df1, df2, df3, df4, result_type, request):
101+
df1 = request.getfixturevalue(df1)
102+
df2 = request.getfixturevalue(df2)
103+
df3 = request.getfixturevalue(df3)
104+
df4 = request.getfixturevalue(df4)
105+
result = df1.concat(axis=1, other=[df2, df3, df4])
106+
assert(type(result) == result_type)
98107

99-
def test_three_three_qc_types_middle():
100-
pass
108+
# This currently passes because we have no "max cost" associated
109+
# with a particular QC, so we would move all data to the PicoQC.
110+
# As soon as we can represent "max-cost", the result of this operation
111+
# should be to move all dfs to the CloudQC
112+
def test_extreme_pico(pico_df, cloud_df):
113+
result = cloud_df.concat(axis=1, other=[pico_df, pico_df, pico_df, pico_df, pico_df, pico_df, pico_df])
114+
assert(type(result) == PicoQC)

0 commit comments

Comments
 (0)