Skip to content

Commit 5e7a7c7

Browse files
committed
Add more coverage for scenarios that probably should not happen, unless there is a gap in support
1 parent 4908ae0 commit 5e7a7c7

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

modin/core/storage_formats/pandas/query_compiler_caster.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(self):
5151
self._qc_cls_set = set()
5252
self._result_type = None
5353

54-
def add_query_compiler(self, query_compiler: BaseQueryCompiler):
54+
def add_query_compiler(self, query_compiler):
5555
"""
5656
Add a query compiler to be considered for casting.
5757
@@ -83,7 +83,7 @@ def calculate(self):
8383
if self._result_type is not None:
8484
return self._result_type
8585
if len(self._qc_cls_set) == 1:
86-
return self._qc_cls_set.pop()
86+
return list(self._qc_cls_set)[0]
8787
if len(self._qc_cls_set) == 0:
8888
raise ValueError("No query compilers registered")
8989

@@ -133,7 +133,10 @@ def result_data_cls(self):
133133
DataFrame object associated with the preferred query compiler.
134134
"""
135135
qc_type = self.calculate()
136-
return self._compiler_class_to_data_class[qc_type]
136+
if qc_type in self._compiler_class_to_data_class:
137+
return self._compiler_class_to_data_class[qc_type]
138+
else:
139+
return None
137140

138141

139142
class QueryCompilerCaster:

modin/tests/pandas/native_df_interoperability/test_compiler_caster.py

+17
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import pandas
1515
import pytest
1616

17+
from modin.core.storage_formats.pandas.query_compiler_caster import (
18+
QueryCompilerCasterCalculator,
19+
)
1720
import modin.pandas as pd
1821
from modin.core.storage_formats.base.query_compiler import QCCoercionCost
1922
from modin.core.storage_formats.pandas.native_query_compiler import NativeQueryCompiler
@@ -249,3 +252,17 @@ def test_default_to_caller(default_df, default2_df):
249252
assert type(df3) is type(default_df) # should stay on caller
250253
df3 = default2_df.concat(axis=1, other=default_df)
251254
assert type(df3) is type(default2_df) # should stay on caller
255+
256+
257+
def test_no_qc_data_to_calculate():
258+
calculator = QueryCompilerCasterCalculator()
259+
calculator.add_query_compiler(ClusterQC)
260+
result = calculator.calculate()
261+
assert result is ClusterQC
262+
assert calculator.result_data_cls() is None
263+
264+
265+
def test_no_qc_to_calculate():
266+
calculator = QueryCompilerCasterCalculator()
267+
with pytest.raises(ValueError):
268+
calculator.calculate()

0 commit comments

Comments
 (0)