openfe/FeatureSelector.py: 4 changes (2 additions, 2 deletions)

@@ -13,7 +13,7 @@
 from .FeatureGenerator import *
 from concurrent.futures import ProcessPoolExecutor
 from sklearn.feature_selection import mutual_info_regression, mutual_info_classif
-from sklearn.metrics import mean_squared_error, log_loss, roc_auc_score
+from sklearn.metrics import root_mean_squared_error, log_loss, roc_auc_score
 import scipy.special
 from datetime import datetime
 import warnings
@@ -562,7 +562,7 @@ def get_init_metric(self, pred, label):
             init_metric = log_loss(label, scipy.special.softmax(pred, axis=1),
                                    labels=list(range(pred.shape[1])))
         elif self.metric == 'rmse':
-            init_metric = mean_squared_error(label, pred, squared=False)
+            init_metric = root_mean_squared_error(label, pred)
         elif self.metric == 'auc':
             init_metric = roc_auc_score(label, scipy.special.expit(pred))
         else:
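For context, the new call computes the same quantity as the old one. A minimal standalone check (not part of the patch), assuming scikit-learn >= 1.4, where root_mean_squared_error is available:

    # Sanity check: root_mean_squared_error matches the RMSE that
    # mean_squared_error(..., squared=False) used to return.
    import numpy as np
    from sklearn.metrics import root_mean_squared_error

    label = np.array([3.0, -0.5, 2.0, 7.0])
    pred = np.array([2.5, 0.0, 2.0, 8.0])

    print(root_mean_squared_error(label, pred))  # ~0.6124, i.e. sqrt(MSE)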
openfe/openfe.py: 4 changes (2 additions, 2 deletions)

@@ -11,7 +11,7 @@
 from .utils import tree_to_formula, check_xor, formula_to_tree
 from sklearn.inspection import permutation_importance
 from sklearn.feature_selection import mutual_info_regression, mutual_info_classif
-from sklearn.metrics import mean_squared_error, log_loss, roc_auc_score
+from sklearn.metrics import root_mean_squared_error, log_loss, roc_auc_score
 import scipy.special
 from copy import deepcopy
 from tqdm import tqdm
@@ -569,7 +569,7 @@ def get_init_metric(self, pred, label):
             init_metric = log_loss(label, scipy.special.softmax(pred, axis=1),
                                    labels=list(range(pred.shape[1])))
         elif self.metric == 'rmse':
-            init_metric = mean_squared_error(label, pred, squared=False)
+            init_metric = root_mean_squared_error(label, pred)
         elif self.metric == 'auc':
             init_metric = roc_auc_score(label, scipy.special.expit(pred))
         else:
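If the project had to keep supporting scikit-learn releases older than 1.4 (which, as far as I know, lack root_mean_squared_error), an import fallback along these lines would preserve both spellings. This shim is illustrative only and not part of the diff:

    # Hypothetical compatibility shim: use the new function when present,
    # otherwise rebuild it from the old squared=False spelling.
    try:
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        from functools import partial
        from sklearn.metrics import mean_squared_error
        root_mean_squared_error = partial(mean_squared_error, squared=False)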