diff --git a/examples/estimator/classifier/DecisionTreeClassifier/c/basics.pct.multilabel.py b/examples/estimator/classifier/DecisionTreeClassifier/c/basics.pct.multilabel.py new file mode 100644 index 00000000..d3ef575c --- /dev/null +++ b/examples/estimator/classifier/DecisionTreeClassifier/c/basics.pct.multilabel.py @@ -0,0 +1,75 @@ +# %% [markdown] +# # sklearn-porter +# +# Repository: [https://github.com/nok/sklearn-porter](https://github.com/nok/sklearn-porter) +# +# ## DecisionTreeClassifier +# +# Documentation: [sklearn.tree.DecisionTreeClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html) + +# %% +import sys +sys.path.append('../../../../..') + +# %% [markdown] +# ### Load data + +# %% +from sklearn.datasets import load_iris + +iris_data = load_iris() + +X = iris_data.data +y = iris_data.target + +print(X.shape, y.shape) + +# %% [markdown] +# ### Train classifier + +# %% +from sklearn.tree import tree +import numpy as np + +# transfer single-output into multi-labels +y_multi_label = [] +for x in iris_data.target: + if x == 0: + y_multi_label.append([1,1,0]) + elif x == 1: + y_multi_label.append([0,1,1]) + else: + y_multi_label.append([1,0,1]) +y = np.array(y_multi_label) + +from sklearn.model_selection import train_test_split +X_train, X_test, y_train, y_test = train_test_split( + iris_data.data, y, test_size=0.33, random_state=42) + +clf = tree.DecisionTreeClassifier(random_state=0) +clf.fit(X_train, y_train) + +# %% [markdown] +# ### Transpile classifier + +# %% +from sklearn_porter import Porter + +porter = Porter(clf, language='c') +output = porter.export() + +print(output) + +# %% [markdown] +# ### Run classification in C + +# %% +# Save model: +# with open('tree.c', 'w') as f: +# f.write(output) + +# Compile model: +# $ gcc tree.c -std=c99 -lm -o tree + +# Run classification: +# $ ./tree 1 2 3 4 diff --git a/sklearn_porter/estimator/classifier/DecisionTreeClassifier/__init__.py 
b/sklearn_porter/estimator/classifier/DecisionTreeClassifier/__init__.py index c0ab1191..b2e8a854 100644 --- a/sklearn_porter/estimator/classifier/DecisionTreeClassifier/__init__.py +++ b/sklearn_porter/estimator/classifier/DecisionTreeClassifier/__init__.py @@ -188,6 +188,19 @@ def export(self, class_name, method_name, export_data=False, classes = ', '.join([temp_arr_scope.format(v) for v in classes]) classes = temp_arr__.format(type='int', name='classes', values=classes, n=n, m=m) + + # transfer dt to c language and is multilabel task + if self.target_language in ['c'] and self.estimator.tree_.value.ndim == 3: + import numpy as np + classes = np.argmax(self.estimator.tree_.value, axis=2).tolist() + n = len(classes) + m = len(classes[0]) + classes = [', '.join([str(int(x)) for x in e]) for e in classes] + classes = ', '.join([temp_arr_scope.format(v) for v in classes]) + classes = temp_arr__.format(type='int', name='classes', values=classes, + n=n, m=m) + self.n_outputs_ = self.estimator.n_outputs_ + self.classes = classes if self.target_method == 'predict': @@ -254,6 +267,9 @@ def predict(self, temp_type='separated'): if temp_type == 'separated': separated_temp = self.temp('separated.class') + # transfer dt to c language and is multilabel task + if self.target_language in ['c'] and self.estimator.tree_.value.ndim == 3: + separated_temp = self.temp('separated.multilabels.class') return separated_temp.format(**self.__dict__) if temp_type == 'embedded': diff --git a/sklearn_porter/estimator/classifier/DecisionTreeClassifier/templates/c/separated.multilabels.class.txt b/sklearn_porter/estimator/classifier/DecisionTreeClassifier/templates/c/separated.multilabels.class.txt new file mode 100644 index 00000000..1284b876 --- /dev/null +++ b/sklearn_porter/estimator/classifier/DecisionTreeClassifier/templates/c/separated.multilabels.class.txt @@ -0,0 +1,41 @@ +#include <stdio.h> +#include <stdlib.h> +#include <math.h> + +#define N_FEATURES {n_features} +#define N_OUTPUTS {n_outputs_} + +{left_childs}
+{right_childs} +{thresholds} +{indices} +{classes} + + +int* {method_name}(double features[N_FEATURES]) {{ + int node = 0; //root node id is 0 + while (thresholds[node] != -2) {{ + if (features[indices[node]] <= thresholds[node]) {{ + node = lChilds[node]; + }} else {{ + node = rChilds[node]; + }} + }} + return classes[node]; +}} + +int main(int argc, const char * argv[]) {{ + + /* Features: */ + double features[argc-1]; + int i; + for (i = 1; i < argc; i++) {{ + features[i-1] = atof(argv[i]); + }} + + /* Prediction: */ + int* output = {method_name}(features); + for(int i=0;i<N_OUTPUTS;i++) {{ + printf("%d ", output[i]); + }} + printf("\n"); + + return 0; +}}