diff --git a/openfe/FeatureSelector.py b/openfe/FeatureSelector.py index b0e7337..c232c9a 100644 --- a/openfe/FeatureSelector.py +++ b/openfe/FeatureSelector.py @@ -13,7 +13,7 @@ from .FeatureGenerator import * from concurrent.futures import ProcessPoolExecutor from sklearn.feature_selection import mutual_info_regression, mutual_info_classif -from sklearn.metrics import mean_squared_error, log_loss, roc_auc_score +from sklearn.metrics import root_mean_squared_error, log_loss, roc_auc_score import scipy.special from datetime import datetime import warnings @@ -562,7 +562,7 @@ def get_init_metric(self, pred, label): init_metric = log_loss(label, scipy.special.softmax(pred, axis=1), labels=list(range(pred.shape[1]))) elif self.metric == 'rmse': - init_metric = mean_squared_error(label, pred, squared=False) + init_metric = root_mean_squared_error(label, pred) elif self.metric == 'auc': init_metric = roc_auc_score(label, scipy.special.expit(pred)) else: diff --git a/openfe/openfe.py b/openfe/openfe.py index 8d7596c..4f344fc 100644 --- a/openfe/openfe.py +++ b/openfe/openfe.py @@ -11,7 +11,7 @@ from .utils import tree_to_formula, check_xor, formula_to_tree from sklearn.inspection import permutation_importance from sklearn.feature_selection import mutual_info_regression, mutual_info_classif -from sklearn.metrics import mean_squared_error, log_loss, roc_auc_score +from sklearn.metrics import root_mean_squared_error, log_loss, roc_auc_score import scipy.special from copy import deepcopy from tqdm import tqdm @@ -569,7 +569,7 @@ def get_init_metric(self, pred, label): init_metric = log_loss(label, scipy.special.softmax(pred, axis=1), labels=list(range(pred.shape[1]))) elif self.metric == 'rmse': - init_metric = mean_squared_error(label, pred, squared=False) + init_metric = root_mean_squared_error(label, pred) elif self.metric == 'auc': init_metric = roc_auc_score(label, scipy.special.expit(pred)) else: