forked from paris-saclay-cds/ramp-workflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfeature_extractor_classifier.py
29 lines (25 loc) · 1.21 KB
/
feature_extractor_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from .feature_extractor import FeatureExtractor
from .classifier import Classifier
class FeatureExtractorClassifier(object):
    """Two-stage workflow: a feature extractor followed by a classifier.

    The submission's feature extractor transforms the raw ``X_df``; its
    output is then fed to the submission's classifier. Each stage is
    delegated to its own workflow object (``FeatureExtractor`` and
    ``Classifier``), each constructed with a single element name.

    Parameters
    ----------
    workflow_element_names : list of str, optional
        Element names for the two stages, in order: the first is used for
        the feature extractor, the second for the classifier. Defaults to
        ``['feature_extractor', 'classifier']``.
    """

    def __init__(self, workflow_element_names=None):
        # Build a fresh list per call. A list literal as the default value
        # would be evaluated once and shared (and aliased through
        # self.element_names) by every instance relying on the default.
        if workflow_element_names is None:
            workflow_element_names = ['feature_extractor', 'classifier']
        self.element_names = workflow_element_names
        self.feature_extractor_workflow = FeatureExtractor(
            [self.element_names[0]])
        self.classifier_workflow = Classifier([self.element_names[1]])

    def train_submission(self, module_path, X_df, y_array, train_is=None):
        """Train the feature extractor, then the classifier on its output.

        Parameters
        ----------
        module_path : str
            Passed through unchanged to both stage workflows; presumably
            the location of the submission files -- TODO confirm.
        X_df : pandas.DataFrame
            Raw input data; must support ``.iloc`` positional indexing.
        y_array : array-like
            Targets; indexed with ``train_is``.
        train_is : slice or array-like, optional
            Rows used for training, applied via ``X_df.iloc[train_is]`` and
            ``y_array[train_is]``. ``None`` selects all rows.

        Returns
        -------
        tuple
            ``(fe, clf)``, the trained feature extractor and classifier as
            returned by the respective stage workflows.
        """
        if train_is is None:
            train_is = slice(None, None, None)
        fe = self.feature_extractor_workflow.train_submission(
            module_path, X_df, y_array, train_is)
        # The classifier is fit on extracted features of the training rows
        # only, so transform exactly those rows with the fitted extractor.
        X_train_array = self.feature_extractor_workflow.test_submission(
            fe, X_df.iloc[train_is])
        clf = self.classifier_workflow.train_submission(
            module_path, X_train_array, y_array[train_is])
        return fe, clf

    def test_submission(self, trained_model, X_df):
        """Transform ``X_df`` with the trained extractor and run the classifier.

        Parameters
        ----------
        trained_model : tuple
            The ``(fe, clf)`` pair returned by ``train_submission``.
        X_df : pandas.DataFrame
            Raw input data to predict on.

        Returns
        -------
        object
            Whatever ``classifier_workflow.test_submission`` returns;
            presumably class probabilities (hence ``y_proba``) -- confirm.
        """
        fe, clf = trained_model
        X_test_array = self.feature_extractor_workflow.test_submission(
            fe, X_df)
        y_proba = self.classifier_workflow.test_submission(clf, X_test_array)
        return y_proba