Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix converting xgboost with user given feature_names #1395

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions coremltools/converters/xgboost/_tree_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def recurse_json(
if "leaf" not in xgb_tree_json:
branch_mode = "BranchOnValueLessThan"
split_name = xgb_tree_json["split"]
feature_index = split_name if not feature_map else feature_map[split_name]
if split_name in feature_map:
feature_index = feature_map[split_name]
else:
feature_index = int(split_name)

# xgboost internally uses float32, but the parsing from json pulls it out
# as a 64bit double. To trigger the internal float32 detection in the
Expand Down Expand Up @@ -157,7 +160,6 @@ def convert_tree_ensemble(
import json
import os

feature_map = None
if isinstance(
model, (_xgboost.core.Booster, _xgboost.XGBRegressor, _xgboost.XGBClassifier)
):
Expand Down Expand Up @@ -202,15 +204,13 @@ def convert_tree_ensemble(
raise ValueError(
"The XGBoost model does not have feature names. They must be provided in convert method."
)
feature_names = model.feature_names
# Use user given feature names if they exist
if feature_names is None:
feature_names = model.feature_names

feature_map = {f: i for i, f in enumerate(feature_names)}

xgb_model_str = model.get_dump(with_stats=True, dump_format="json")

if model.feature_names:
feature_map = {f: i for i, f in enumerate(model.feature_names)}

# Path on the file system where the XGboost model exists.
elif isinstance(model, str):
if not os.path.exists(model):
Expand Down
15 changes: 15 additions & 0 deletions coremltools/test/xgboost_tests/test_boosted_trees_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,18 @@ def test_conversion_bad_inputs(self):
with self.assertRaises(TypeError):
model = OneHotEncoder()
spec = xgb_converter.convert(model, "data", "out")

def test_conversion_model_without_feature_names(self):
    # Regression check: a booster trained from a DMatrix that was built
    # without a feature_names argument must still convert when the caller
    # passes feature_names to the converter.
    unnamed_matrix = xgboost.DMatrix(
        self.scikit_data.data, label=self.scikit_data.target
    )
    model = xgboost.train({}, unnamed_matrix, 1)

    spec = xgb_converter.convert(model, feature_names=self.feature_names)

    # The converted spec's inputs should be exactly the names we supplied
    # (compared order-insensitively).
    input_names = [feature.name for feature in spec.description.input]
    self.assertEqual(sorted(self.feature_names), sorted(input_names))