forked from paris-saclay-cds/ramp-workflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path regressor.py
23 lines (19 loc) · 816 Bytes
/
regressor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import imp
import importlib.util
import os
import sys
class Regressor(object):
    """Workflow that trains and applies a submitted regressor.

    A submission is a Python file ``<module_path>/<element_name>.py`` that
    defines a ``Regressor`` class with ``fit(X, y)`` and ``predict(X)``
    methods.

    Parameters
    ----------
    workflow_element_names : list of str, optional
        Names of the workflow elements. Only the first entry is used: it is
        both the stem of the submission file and the module name the file is
        loaded under. Defaults to ``['regressor']``.
    """

    def __init__(self, workflow_element_names=None):
        # Build the default per call: a list literal as the default value
        # would be one object shared (and mutable) across every instance.
        if workflow_element_names is None:
            workflow_element_names = ['regressor']
        self.element_names = workflow_element_names

    @staticmethod
    def _load_module(name, path):
        """Execute the Python source file at *path* as module *name*.

        Replacement for the deprecated ``imp.load_source`` (removed in
        Python 3.12). As before, the module is registered in ``sys.modules``
        under *name*, so objects defined in the submission can be looked up
        by module name (e.g. when a fitted model is pickled).

        Raises
        ------
        ImportError
            If no loader can be built for *path*.
        Any exception raised while executing the file is propagated.
        """
        spec = importlib.util.spec_from_file_location(name, path)
        if spec is None or spec.loader is None:
            raise ImportError('Cannot load submission file {}'.format(path))
        module = importlib.util.module_from_spec(spec)
        previous = sys.modules.get(name)
        sys.modules[name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module registered.
            if previous is None:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
            raise
        return module

    def train_submission(self, module_path, X_array, y_array, train_is=None):
        """Load the submitted regressor, fit it on the training rows, return it.

        Parameters
        ----------
        module_path : str
            Directory containing the submission file.
        X_array, y_array : indexable
            Features and targets; both are indexed with ``train_is``.
        train_is : slice or index sequence, optional
            Rows used for training. ``None`` means all rows.

        Returns
        -------
        The fitted instance of the submission's ``Regressor`` class.
        """
        if train_is is None:
            train_is = slice(None, None, None)
        element_name = self.element_names[0]
        submitted_regressor_file = os.path.join(
            module_path, element_name + '.py')
        regressor = self._load_module(element_name, submitted_regressor_file)
        reg = regressor.Regressor()
        reg.fit(X_array[train_is], y_array[train_is])
        return reg

    def test_submission(self, trained_model, X_array):
        """Return ``trained_model.predict(X_array)``."""
        return trained_model.predict(X_array)