Skip to content

Commit 0953819

Browse files
support for "huber" objective in the LGBM Booster (#705)
* support for huber objective in the LGBM Booster Signed-off-by: Łukasz Ćmielowski <[email protected]> * black & ruff updates Signed-off-by: Łukasz Ćmielowski <[email protected]> --------- Signed-off-by: Łukasz Ćmielowski <[email protected]>
1 parent 2834b4d commit 0953819

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

onnxmltools/convert/lightgbm/_parse.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ def __init__(self, booster):
3636
elif self.objective_.startswith("multiclass"):
3737
self.operator_name = "LgbmClassifier"
3838
self.classes_ = self._generate_classes(booster)
39-
elif self.objective_.startswith("regression"):
39+
elif self.objective_.startswith(
40+
("regression", "poisson", "gamma", "quantile", "huber")
41+
):
4042
self.operator_name = "LgbmRegressor"
4143
else:
4244
raise NotImplementedError(

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def convert_lightgbm(scope, operator, container):
555555
elif gbm_text["objective"].startswith("multiclass"):
556556
n_classes = gbm_text["num_class"]
557557
attrs["post_transform"] = "SOFTMAX"
558-
elif gbm_text["objective"].startswith(("regression", "quantile")):
558+
elif gbm_text["objective"].startswith(("regression", "quantile", "huber")):
559559
n_classes = 1 # Regressor has only one output variable
560560
attrs["post_transform"] = "NONE"
561561
attrs["n_targets"] = n_classes

tests/lightgbm/test_objective_functions.py

+44-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from onnxruntime import InferenceSession
1313
from pandas.core.frame import DataFrame
1414

15-
from lightgbm import LGBMRegressor
15+
from lightgbm import LGBMRegressor, Booster, Dataset
1616

1717
_N_ROWS = 10_000
1818
_N_COLS = 10
@@ -31,7 +31,13 @@
3131

3232

3333
class ObjectiveTest(unittest.TestCase):
34-
_objectives: Tuple[str] = ("regression", "poisson", "gamma", "quantile")
34+
_regressor_objectives: Tuple[str] = (
35+
"regression",
36+
"poisson",
37+
"gamma",
38+
"quantile",
39+
"huber",
40+
)
3541

3642
@staticmethod
3743
def _calc_initial_types(X: DataFrame) -> List[Tuple[str, TensorType]]:
@@ -83,7 +89,7 @@ def _assert_almost_equal(
8389
tuple(int(ver) for ver in onnxruntime.__version__.split(".")[:2]) < (1, 3),
8490
"not supported in this library version",
8591
)
86-
def test_objective(self):
92+
def test_objective_LGBMRegressor(self):
8793
"""
8894
Test if a LGBMRegressor a with certain objective (e.g. 'poisson')
8995
can be converted to ONNX
@@ -95,7 +101,7 @@ def test_objective(self):
95101
and therefore sometimes fails randomly. In these cases,
96102
a retry should resolve the issue.
97103
"""
98-
for objective in self._objectives:
104+
for objective in self._regressor_objectives:
99105
with self.subTest(X=_X, objective=objective):
100106
regressor = LGBMRegressor(objective=objective, num_thread=1)
101107
regressor.fit(_X, _Y)
@@ -113,6 +119,40 @@ def test_objective(self):
113119
frac=_FRAC,
114120
)
115121

122+
def test_objective_Booster(self):
123+
"""
124+
Test if a Booster a with certain objective (e.g. 'poisson')
125+
can be converted to ONNX
126+
and whether the ONNX graph and the original model produce
127+
almost equal predictions.
128+
129+
Note that this tests is a bit flaky because of precision
130+
differences with ONNX and LightGBM
131+
and therefore sometimes fails randomly. In these cases,
132+
a retry should resolve the issue.
133+
"""
134+
for objective in self._regressor_objectives:
135+
with self.subTest(X=_X, objective=objective):
136+
ds = Dataset(_X, feature_name="auto").construct()
137+
ds.set_label(_Y)
138+
regressor = Booster(params={"objective": objective}, train_set=ds)
139+
for k in range(10):
140+
regressor.update()
141+
142+
regressor_onnx: ModelProto = convert_lightgbm(
143+
regressor,
144+
initial_types=self._calc_initial_types(_X),
145+
target_opset=TARGET_OPSET,
146+
)
147+
y_pred = regressor.predict(_X)
148+
y_pred_onnx = self._predict_with_onnx(regressor_onnx, _X)
149+
self._assert_almost_equal(
150+
y_pred,
151+
y_pred_onnx,
152+
decimal=_N_DECIMALS,
153+
frac=_FRAC,
154+
)
155+
116156

117157
if __name__ == "__main__":
118158
unittest.main()

0 commit comments

Comments
 (0)