21
21
22
22
import functools
23
23
import inspect
24
+ from itertools import combinations
24
25
from types import FunctionType , MethodType
25
26
from typing import Any , Dict , Tuple , TypeVar
26
27
27
28
from pandas .core .indexes .frozen import FrozenList
28
29
29
- from modin .core .storage_formats .base .query_compiler import BaseQueryCompiler
30
+ from modin .core .storage_formats .base .query_compiler import BaseQueryCompiler , QCCoercionCost
30
31
31
32
Fn = TypeVar ("Fn" , bound = Any )
32
33
33
-
34
class QueryCompilerCasterCalculator:
    """
    Pick the query compiler type that a mixed-type operation should be cast to.

    Query compilers are registered with ``add_query_compiler``. ``calculate``
    asks every pair of registered instances for ``qc_engine_switch_cost`` in
    both directions, sums the reported costs per registered type and selects
    the type with the lowest total.
    """

    def __init__(self):
        # Registered query compiler type -> summed coercion cost.
        self._caster_costing_map = {}
        # Query compiler type -> ``_modin_frame`` of the most recently
        # registered instance of that type.
        self._data_cls_map = {}
        # Registered instances; only instances can be asked for costs.
        self._qc_list = []
        # Types of everything registered (instances and bare classes).
        self._qc_cls_list = []
        # Cached outcome of the last cost-based ``calculate`` call.
        self._result_type = None

    def add_query_compiler(self, query_compiler):
        """
        Register a query compiler instance or class as a candidate.

        Parameters
        ----------
        query_compiler : object or type
            An instance contributes its type, itself (for cost queries) and
            its ``_modin_frame``; a bare class contributes only the type.
        """
        if isinstance(query_compiler, type):
            qc_type = query_compiler
        else:
            qc_type = type(query_compiler)
            self._qc_list.append(query_compiler)
            self._data_cls_map[qc_type] = query_compiler._modin_frame
        self._qc_cls_list.append(qc_type)
        # A new candidate invalidates any previously computed decision;
        # without this, ``calculate`` would keep returning a stale type.
        self._caster_costing_map = {}
        self._result_type = None

    def calculate(self):
        """
        Return the registered query compiler type with the lowest total cost.

        Returns
        -------
        type
            The selected query compiler class.

        Raises
        ------
        ValueError
            If nothing was registered, or if no cost was reported for any
            registered type.
        """
        if self._result_type is not None:
            return self._result_type
        if not self._qc_cls_list:
            raise ValueError("No query compilers registered")

        first_type = self._qc_cls_list[0]
        if all(qc_type is first_type for qc_type in self._qc_cls_list):
            # Only one distinct type is involved, so nothing has to be cast
            # and no cost query is needed.
            return first_type

        # Start from a clean slate so an earlier, failed attempt cannot
        # leave partial sums behind.
        self._caster_costing_map = {}
        for qc_1, qc_2 in combinations(self._qc_list, 2):
            costs_1 = qc_1.qc_engine_switch_cost(qc_2)
            costs_2 = qc_2.qc_engine_switch_cost(qc_1)
            self._add_cost_data(costs_1)
            self._add_cost_data(costs_2)

        if not self._caster_costing_map:
            raise ValueError(
                "No coercion costs were reported for the registered query compilers"
            )

        # ``min`` returns the first key holding the minimal value, i.e. ties
        # are resolved in insertion order.
        self._result_type = min(
            self._caster_costing_map, key=self._caster_costing_map.get
        )
        return self._result_type

    def _add_cost_data(self, costs: dict):
        """
        Accumulate the costs reported by one ``qc_engine_switch_cost`` call.

        Parameters
        ----------
        costs : dict
            Mapping of query compiler type to the cost of coercing to it.
        """
        for qc_type, cost in costs.items():
            # Filter out any extraneous query compilers not in this operation.
            if qc_type not in self._qc_cls_list:
                continue
            QCCoercionCost.validate_coercsion_cost(cost)
            # Sum the costs of all coercions that target the same type.
            if qc_type in self._caster_costing_map:
                self._caster_costing_map[qc_type] = (
                    cost + self._caster_costing_map[qc_type]
                )
            else:
                self._caster_costing_map[qc_type] = cost

    def result_data_frame(self):
        """
        Return the ``_modin_frame`` recorded for the selected type.

        Raises
        ------
        KeyError
            If the selected type was only ever registered as a bare class,
            so no frame was recorded for it.
        """
        qc_type = self.calculate()
        return self._data_cls_map[qc_type]
86
+
87
+
34
88
class QueryCompilerCaster :
35
89
"""Cast all query compiler arguments of the member function to current query compiler."""
36
90
@@ -55,7 +109,9 @@ def __init_subclass__(
55
109
apply_argument_cast (cls )
56
110
57
111
58
- def cast_nested_args_to_current_qc_type (arguments , current_qc ):
112
+ def visit_nested_args (arguments ,
113
+ current_qc :BaseQueryCompiler ,
114
+ fn :callable ):
59
115
"""
60
116
Cast all arguments in nested fashion to current query compiler.
61
117
@@ -70,33 +126,25 @@ def cast_nested_args_to_current_qc_type(arguments, current_qc):
70
126
Returns args and kwargs with all query compilers casted to current_qc.
71
127
"""
72
128
73
- def cast_arg_to_current_qc (arg ):
74
- current_qc_type = type (current_qc )
75
- if isinstance (arg , BaseQueryCompiler ) and not isinstance (arg , current_qc_type ):
76
- data_cls = current_qc ._modin_frame
77
- return current_qc_type .from_pandas (arg .to_pandas (), data_cls )
78
- else :
79
- return arg
80
-
81
129
imutable_types = (FrozenList , tuple )
82
130
if isinstance (arguments , imutable_types ):
83
131
args_type = type (arguments )
84
132
arguments = list (arguments )
85
- arguments = cast_nested_args_to_current_qc_type (arguments , current_qc )
133
+ arguments = visit_nested_args (arguments , current_qc , fn )
86
134
87
135
return args_type (arguments )
88
136
if isinstance (arguments , list ):
89
137
for i in range (len (arguments )):
90
138
if isinstance (arguments [i ], (list , dict )):
91
- cast_nested_args_to_current_qc_type (arguments [i ], current_qc )
139
+ visit_nested_args (arguments [i ], current_qc , fn )
92
140
else :
93
- arguments [i ] = cast_arg_to_current_qc (arguments [i ])
141
+ arguments [i ] = fn (arguments [i ])
94
142
elif isinstance (arguments , dict ):
95
143
for key in arguments :
96
144
if isinstance (arguments [key ], (list , dict )):
97
- cast_nested_args_to_current_qc_type (arguments [key ], current_qc )
145
+ visit_nested_args (arguments [key ], current_qc , fn )
98
146
else :
99
- arguments [key ] = cast_arg_to_current_qc (arguments [key ])
147
+ arguments [key ] = fn (arguments [key ])
100
148
return arguments
101
149
102
150
@@ -116,6 +164,9 @@ def apply_argument_cast(obj: Fn) -> Fn:
116
164
if isinstance (obj , type ):
117
165
all_attrs = dict (inspect .getmembers (obj ))
118
166
all_attrs .pop ("__abstractmethods__" )
167
+ all_attrs .pop ("__init__" )
168
+ all_attrs .pop ("qc_engine_switch_cost" )
169
+ all_attrs .pop ("from_pandas" )
119
170
120
171
# This is required because inspect converts class methods to member functions
121
172
current_class_attrs = vars (obj )
@@ -150,10 +201,57 @@ def cast_args(*args: Tuple, **kwargs: Dict) -> Any:
150
201
-------
151
202
Any
152
203
"""
204
+ if len (args ) == 0 and len (kwargs ) == 0 :
205
+ return
206
+ print (f"Adding wrapper { obj } \n " )
153
207
current_qc = args [0 ]
208
+ calculator = QueryCompilerCasterCalculator ()
209
+ calculator .add_query_compiler (current_qc )
210
+
211
def arg_needs_casting(arg):
    """
    Tell whether `arg` is a query compiler of a different type than `current_qc`.

    `current_qc` is taken from the enclosing scope. Anything that is not a
    ``BaseQueryCompiler``, or that is already an instance of the type of
    `current_qc`, needs no casting.
    """
    return isinstance(arg, BaseQueryCompiler) and not isinstance(
        arg, type(current_qc)
    )
218
+
219
def register_query_compilers(arg):
    """
    Register `arg` with `calculator` when it is a foreign query compiler.

    `calculator` is taken from the enclosing scope. The argument is always
    returned unchanged, so this can be used as a visitor callback.
    """
    if arg_needs_casting(arg):
        calculator.add_query_compiler(arg)
    return arg
224
+
225
def cast_to_qc(arg):
    """
    Cast `arg` to the query compiler type selected by `calculator`.

    `calculator` is taken from the enclosing scope. Arguments that need no
    casting, or that already have the selected type, are returned as is;
    everything else is rebuilt via ``from_pandas`` on the selected type.
    """
    if not arg_needs_casting(arg):
        return arg
    qc_type = calculator.calculate()
    # Identity, not equality: these are comparisons against the ``None``
    # singleton and a class object.
    if qc_type is None or qc_type is type(arg):
        return arg
    frame_data = calculator.result_data_frame()
    return qc_type.from_pandas(arg.to_pandas(), frame_data)
234
+
235
+
154
236
if isinstance (current_qc , BaseQueryCompiler ):
155
- kwargs = cast_nested_args_to_current_qc_type (kwargs , current_qc )
156
- args = cast_nested_args_to_current_qc_type (args , current_qc )
237
+ visit_nested_args (kwargs , current_qc , register_query_compilers )
238
+ visit_nested_args (args , current_qc , register_query_compilers )
239
+
240
+ args = visit_nested_args (args , current_qc , cast_to_qc )
241
+ kwargs = visit_nested_args (kwargs , current_qc , cast_to_qc )
242
+
243
+
244
+ qc = calculator .calculate ()
245
+
246
+ if qc == None or qc == type (current_qc ):
247
+ return obj (* args , ** kwargs )
248
+
249
+ #breakpoint()
250
+ # we need to cast current_qc to a new query compiler
251
+ if qc != current_qc :
252
+ data_cls = current_qc ._modin_frame
253
+ return qc .from_pandas (current_qc .to_pandas (), data_cls )
254
+ # need to find the new function for obj
157
255
return obj (* args , ** kwargs )
158
256
159
257
return cast_args
0 commit comments