12
12
from onnxruntime import InferenceSession
13
13
from pandas .core .frame import DataFrame
14
14
15
- from lightgbm import LGBMRegressor
15
+ from lightgbm import LGBMRegressor , Booster , Dataset
16
16
17
17
_N_ROWS = 10_000
18
18
_N_COLS = 10
31
31
32
32
33
33
class ObjectiveTest (unittest .TestCase ):
34
- _objectives : Tuple [str ] = ("regression" , "poisson" , "gamma" , "quantile" )
34
+ _regressor_objectives : Tuple [str ] = (
35
+ "regression" ,
36
+ "poisson" ,
37
+ "gamma" ,
38
+ "quantile" ,
39
+ "huber" ,
40
+ )
35
41
36
42
@staticmethod
37
43
def _calc_initial_types (X : DataFrame ) -> List [Tuple [str , TensorType ]]:
@@ -83,7 +89,7 @@ def _assert_almost_equal(
83
89
tuple (int (ver ) for ver in onnxruntime .__version__ .split ("." )[:2 ]) < (1 , 3 ),
84
90
"not supported in this library version" ,
85
91
)
86
- def test_objective (self ):
92
+ def test_objective_LGBMRegressor (self ):
87
93
"""
88
94
Test if a LGBMRegressor a with certain objective (e.g. 'poisson')
89
95
can be converted to ONNX
@@ -95,7 +101,7 @@ def test_objective(self):
95
101
and therefore sometimes fails randomly. In these cases,
96
102
a retry should resolve the issue.
97
103
"""
98
- for objective in self ._objectives :
104
+ for objective in self ._regressor_objectives :
99
105
with self .subTest (X = _X , objective = objective ):
100
106
regressor = LGBMRegressor (objective = objective , num_thread = 1 )
101
107
regressor .fit (_X , _Y )
@@ -113,6 +119,40 @@ def test_objective(self):
113
119
frac = _FRAC ,
114
120
)
115
121
122
+ def test_objective_Booster (self ):
123
+ """
124
+ Test if a Booster a with certain objective (e.g. 'poisson')
125
+ can be converted to ONNX
126
+ and whether the ONNX graph and the original model produce
127
+ almost equal predictions.
128
+
129
+ Note that this tests is a bit flaky because of precision
130
+ differences with ONNX and LightGBM
131
+ and therefore sometimes fails randomly. In these cases,
132
+ a retry should resolve the issue.
133
+ """
134
+ for objective in self ._regressor_objectives :
135
+ with self .subTest (X = _X , objective = objective ):
136
+ ds = Dataset (_X , feature_name = "auto" ).construct ()
137
+ ds .set_label (_Y )
138
+ regressor = Booster (params = {"objective" : objective }, train_set = ds )
139
+ for k in range (10 ):
140
+ regressor .update ()
141
+
142
+ regressor_onnx : ModelProto = convert_lightgbm (
143
+ regressor ,
144
+ initial_types = self ._calc_initial_types (_X ),
145
+ target_opset = TARGET_OPSET ,
146
+ )
147
+ y_pred = regressor .predict (_X )
148
+ y_pred_onnx = self ._predict_with_onnx (regressor_onnx , _X )
149
+ self ._assert_almost_equal (
150
+ y_pred ,
151
+ y_pred_onnx ,
152
+ decimal = _N_DECIMALS ,
153
+ frac = _FRAC ,
154
+ )
155
+
116
156
117
157
if __name__ == "__main__" :
118
158
unittest .main ()
0 commit comments