Skip to content

Commit 4133713

Browse files
authored
update CI with the latest version of onnx, onnxruntime (#706)
* update CI Signed-off-by: xadupre <[email protected]> * lint Signed-off-by: xadupre <[email protected]> * fix architecture Signed-off-by: xadupre <[email protected]> * remove architecture Signed-off-by: xadupre <[email protected]> * move one line * fix xgboost Signed-off-by: xadupre <[email protected]> * catch unable to import Signed-off-by: xadupre <[email protected]> * adjust import Signed-off-by: xadupre <[email protected]> * import Signed-off-by: xadupre <[email protected]> * none or Signed-off-by: xadupre <[email protected]> * another fix Signed-off-by: xadupre <[email protected]> * xgb Signed-off-by: xadupre <[email protected]> * fix comparison Signed-off-by: xadupre <[email protected]> * fix ut Signed-off-by: xadupre <[email protected]> * fix unittest Signed-off-by: xadupre <[email protected]> * atol Signed-off-by: xadupre <[email protected]> --------- Signed-off-by: xadupre <[email protected]>
1 parent 3ae696a commit 4133713

10 files changed

+163
-38
lines changed

.github/workflows/ci.yml

+11-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@ jobs:
99
os: [ubuntu-latest, macos-latest, windows-latest]
1010
python_version: ['3.12', '3.11', '3.10', '3.9']
1111
include:
12+
- python_version: '3.12'
13+
documentation: 1
14+
numpy_version: '>=1.21.1'
15+
scipy_version: '>=1.7.0'
16+
onnx_version: 'onnx==1.17.0'
17+
onnxrt_version: 'onnxruntime==1.20.1'
18+
sklearn_version: '==1.6.0'
19+
lgbm_version: ">=4"
20+
xgboost_version: ">=2"
1221
- python_version: '3.12'
1322
documentation: 0
1423
numpy_version: '>=1.21.1'
@@ -19,7 +28,7 @@ jobs:
1928
lgbm_version: ">=4"
2029
xgboost_version: ">=2"
2130
- python_version: '3.11'
22-
documentation: 1
31+
documentation: 0
2332
numpy_version: '>=1.21.1'
2433
scipy_version: '>=1.7.0'
2534
onnx_version: 'onnx<1.16.0'
@@ -82,20 +91,19 @@ jobs:
8291

8392
- name: versions
8493
run: |
85-
python -c "from numpy import __version__;print('numpy', __version__)"
8694
python -c "from pandas import __version__;print('pandas', __version__)"
8795
python -c "from scipy import __version__;print('scipy', __version__)"
8896
python -c "from sklearn import __version__;print('sklearn', __version__)"
8997
python -c "from onnxruntime import __version__;print('onnxruntime', __version__)"
9098
python -c "from onnx import __version__;print('onnx', __version__)"
91-
python -c "from xgboost import __version__;print('xgboost', __version__)"
9299
python -c "from catboost import __version__;print('catboost', __version__)"
93100
python -c "import onnx.defs;print('onnx_opset_version', onnx.defs.onnx_opset_version())"
94101
95102
- name: versions lightgbm
96103
if: matrix.os != 'macos-latest'
97104
run: |
98105
python -c "from lightgbm import __version__;print('lightgbm', __version__)"
106+
python -c "from xgboost import __version__;print('xgboost', __version__)"
99107
100108
- name: Run tests baseline
101109
run: pytest --maxfail=10 --durations=10 tests/baseline

onnxmltools/convert/coreml/shape_calculators/OneHotEncoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def calculate_one_hot_encoder_output_shapes(operator):
2626
operator.outputs[0].type = FloatTensorType(
2727
[N, len(int_categories)], doc_string=operator.outputs[0].type.doc_string
2828
)
29-
elif len(str_categories) > 0 and type(operator.inputs[0].type) == StringTensorType:
29+
elif len(str_categories) > 0 and type(operator.inputs[0].type) is StringTensorType:
3030
operator.outputs[0].type = FloatTensorType(
3131
[N, len(str_categories)], doc_string=operator.outputs[0].type.doc_string
3232
)

onnxmltools/utils/tests_helper.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import unittest
34
import pickle
45
import os
56
import numpy
@@ -87,16 +88,22 @@ def dump_data_and_model(
8788
if not os.path.exists(folder):
8889
os.makedirs(folder)
8990

90-
if hasattr(model, "predict"):
91+
if "LGBM" in model.__class__.__name__:
9192
try:
9293
import lightgbm
9394
except ImportError:
94-
lightgbm = None
95+
raise unittest.SkipTest("lightgbm cannot be imported.")
96+
else:
97+
lightgbm = None
98+
if "XGB" in model.__class__.__name__ or "Booster" in model.__class__.__name__:
9599
try:
96100
import xgboost
97101
except ImportError:
98-
xgboost = None
102+
raise unittest.SkipTest("xgboost cannot be imported.")
103+
else:
104+
xgboost = None
99105

106+
if hasattr(model, "predict"):
100107
if lightgbm is not None and isinstance(model, lightgbm.Booster):
101108
# LightGBM Booster
102109
model_dict = model.dump_model()

onnxmltools/utils/utils_backend_onnxruntime.py

+24
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,30 @@ def _compare_expected(
297297
(len(expected), len(output.ravel()) // len(expected))
298298
)
299299
if len(expected) != len(output):
300+
if (
301+
len(output) == 2
302+
and len(expected) == 1
303+
and output[0].dtype in (numpy.int64, numpy.int32)
304+
):
305+
# a classifier
306+
if len(expected[0].shape) == 1:
307+
expected = [
308+
numpy.hstack(
309+
[
310+
1 - expected[0].reshape((-1, 1)),
311+
expected[0].reshape((-1, 1)),
312+
]
313+
)
314+
]
315+
return _compare_expected(
316+
expected,
317+
output[1:],
318+
sess,
319+
onnx,
320+
decimal=5,
321+
onnx_shape=None,
322+
**kwargs
323+
)
300324
raise OnnxRuntimeAssertionError(
301325
"Unexpected number of outputs '{0}', expected={1}, got={2}".format(
302326
onnx, len(expected), len(output)

0 commit comments

Comments
 (0)