forked from paris-saclay-cds/ramp-workflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassifier.py
23 lines (19 loc) · 832 Bytes
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import imp
class Classifier(object):
    """Workflow that trains and tests a submitted classifier.

    A submission is a Python source file ``<module_path>/<name>.py``, where
    ``<name>`` is the first workflow element name. That file must define a
    ``Classifier`` class whose instances provide ``fit(X, y)`` and
    ``predict_proba(X)``.
    """

    def __init__(self, workflow_element_names=None):
        """Record the workflow element names.

        Parameters
        ----------
        workflow_element_names : list of str, optional
            Submission file names without the ``.py`` suffix. Only the first
            entry is used by this workflow. Defaults to ``['classifier']``.
        """
        # Build the default here rather than in the signature: a list literal
        # default is created once and shared by every instance.
        if workflow_element_names is None:
            workflow_element_names = ['classifier']
        self.element_names = workflow_element_names

    @staticmethod
    def _load_source(module_name, file_path):
        """Execute ``file_path`` as a module registered as ``module_name``.

        Stand-in for ``imp.load_source`` (the ``imp`` module was removed in
        Python 3.12). Like ``imp.load_source``, the new module is stored in
        ``sys.modules`` under ``module_name``.

        Raises
        ------
        ImportError
            If no import spec can be built for ``file_path``.
        """
        import importlib.util
        import sys

        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            raise ImportError(
                'Cannot load submission from {}'.format(file_path))
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as the importlib recipe does, so code in
        # the submission that looks itself up by name finds the module.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except Exception:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(module_name, None)
            raise
        return module

    def train_submission(self, module_path, X_array, y_array, train_is=None):
        """Load the submitted classifier and fit it on the training samples.

        Parameters
        ----------
        module_path : str
            Directory containing the submission file.
        X_array, y_array : array-like
            Features and targets. Both are indexed with ``train_is``
            (presumably numpy arrays -- TODO confirm with callers).
        train_is : index or slice, optional
            Selects the training samples; ``None`` selects every sample.

        Returns
        -------
        object
            The fitted instance of the submission's ``Classifier`` class.
        """
        import os

        if train_is is None:
            train_is = slice(None, None, None)
        element_name = self.element_names[0]
        submitted_classifier_file = os.path.join(
            module_path, '{}.py'.format(element_name))
        classifier = self._load_source(element_name, submitted_classifier_file)
        clf = classifier.Classifier()
        clf.fit(X_array[train_is], y_array[train_is])
        return clf

    def test_submission(self, trained_model, X_array):
        """Return the class probabilities predicted by a trained model.

        Parameters
        ----------
        trained_model : object
            A model returned by ``train_submission``.
        X_array : array-like
            Features to predict on.

        Returns
        -------
        Whatever ``trained_model.predict_proba(X_array)`` returns.
        """
        return trained_model.predict_proba(X_array)