diff --git a/engines/python/setup/djl_python/encode_decode.py b/engines/python/setup/djl_python/encode_decode.py index cfe46c21b..be6851e69 100644 --- a/engines/python/setup/djl_python/encode_decode.py +++ b/engines/python/setup/djl_python/encode_decode.py @@ -21,21 +21,30 @@ import numpy as np -def decode_csv(inputs: Input): # type: (str) -> np.array +def decode_csv(inputs: Input, require_headers=True): # type: (str) -> np.array csv_content = inputs.get_as_string() - stream = StringIO(csv_content) - # detects if the incoming csv has headers - if not any(header in csv_content.splitlines()[0].lower() - for header in ["question", "context", "inputs"]): - raise ValueError( - "You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.", - ) - # reads csv as io - request_list = list(csv.DictReader(stream)) - if "inputs" in request_list[0].keys(): - return {"inputs": [entry["inputs"] for entry in request_list]} + + if require_headers: + if not any(header in csv_content.splitlines()[0].lower() + for header in ["question", "context", "inputs"]): + raise ValueError( + "You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.", + ) + stream = StringIO(csv_content) + request_list = list(csv.DictReader(stream)) + if "inputs" in request_list[0].keys(): + return {"inputs": [entry["inputs"] for entry in request_list]} + else: + return {"inputs": request_list} else: - return {"inputs": request_list} + # for predictive ML inputs + result = np.genfromtxt(StringIO(csv_content), delimiter=",") + # Check for NaN values which indicate non-numeric data + if np.isnan(result).any(): + raise ValueError( + "CSV contains non-numeric data. Please provide numeric data only." 
+ ) + return result def encode_csv(content): # type: (str) -> np.array @@ -51,7 +60,10 @@ def encode_csv(content): # type: (str) -> np.array return stream.getvalue() -def decode(inputs: Input, content_type: str, key=None): +def decode(inputs: Input, + content_type: str, + key=None, + require_csv_headers=True): if not content_type: ret = inputs.get_as_bytes(key=key) if not ret: @@ -60,7 +72,7 @@ def decode(inputs: Input, content_type: str, key=None): elif "application/json" in content_type: return inputs.get_as_json(key=key) elif "text/csv" in content_type: - return decode_csv(inputs) + return decode_csv(inputs, require_headers=require_csv_headers) elif "text/plain" in content_type: return {"inputs": [inputs.get_as_string(key=key)]} if content_type.startswith("image/"): diff --git a/engines/python/setup/djl_python/import_utils.py b/engines/python/setup/djl_python/import_utils.py new file mode 100644 index 000000000..6713022e8 --- /dev/null +++ b/engines/python/setup/djl_python/import_utils.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+ +import importlib.util +import importlib.metadata + + +def _is_package_available(pkg_name: str) -> bool: + """Check if a package is available""" + package_exists = importlib.util.find_spec(pkg_name) is not None + if package_exists: + try: + importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + package_exists = False + return package_exists + + +# SKLearn model persistence libraries +_joblib_available = _is_package_available("joblib") +_cloudpickle_available = _is_package_available("cloudpickle") +_skops_available = _is_package_available("skops") + +# XGBoost +_xgboost_available = _is_package_available("xgboost") + + +def is_joblib_available() -> bool: + return _joblib_available + + +def is_cloudpickle_available() -> bool: + return _cloudpickle_available + + +def is_skops_available() -> bool: + return _skops_available + + +def is_xgboost_available() -> bool: + return _xgboost_available + + +joblib = None +if _joblib_available: + import joblib + +cloudpickle = None +if _cloudpickle_available: + import cloudpickle + +skops_io = None +if _skops_available: + import skops.io as skops_io + +xgboost = None +if _xgboost_available: + import xgboost diff --git a/engines/python/setup/djl_python/sklearn_handler.py b/engines/python/setup/djl_python/sklearn_handler.py new file mode 100644 index 000000000..fa9d4b03b --- /dev/null +++ b/engines/python/setup/djl_python/sklearn_handler.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. 
See the License for +# the specific language governing permissions and limitations under the License. + +import pickle +import numpy as np +import os +from io import StringIO +from typing import Optional +from djl_python import Input, Output +from djl_python.encode_decode import decode +from djl_python.utils import find_model_file +from djl_python.service_loader import get_annotated_function +from djl_python.import_utils import joblib, cloudpickle, skops_io as sio + + +class SklearnHandler: + + def __init__(self): + self.model = None + self.initialized = False + self.custom_input_formatter = None + self.custom_output_formatter = None + self.custom_predict_formatter = None + + def _get_trusted_types(self, properties: dict): + trusted_types_str = properties.get("skops_trusted_types", "") + if not trusted_types_str: + raise ValueError( + "option.skops_trusted_types must be set to load skops models. " + "Example: option.skops_trusted_types='sklearn.ensemble._forest.RandomForestClassifier,numpy.ndarray'" + ) + trusted_types = [ + t.strip() for t in trusted_types_str.split(",") if t.strip() + ] + print(f"Using trusted types for skops model loading: {trusted_types}") + return trusted_types + + def initialize(self, properties: dict): + model_dir = properties.get("model_dir") + model_format = properties.get("model_format", "skops") + + format_extensions = { + "skops": ["skops"], + "joblib": ["joblib", "jl"], + "pickle": ["pkl", "pickle"], + "cloudpickle": ["pkl", "pickle", "cloudpkl"] + } + + extensions = format_extensions.get(model_format) + if not extensions: + raise ValueError( + f"Unsupported model format: {model_format}. 
Supported formats: skops, joblib, pickle, cloudpickle" + ) + + model_file = find_model_file(model_dir, extensions) + if not model_file: + raise FileNotFoundError( + f"No model file found with format '{model_format}' in {model_dir}" + ) + + if model_format == "skops": + trusted_types = self._get_trusted_types(properties) + self.model = sio.load(model_file, trusted=trusted_types) + else: + if properties.get("trust_insecure_model_files", + "false").lower() != "true": + raise ValueError( + f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)" + ) + + if model_format == "joblib": + self.model = joblib.load(model_file) + elif model_format == "pickle": + with open(model_file, 'rb') as f: + self.model = pickle.load(f) + elif model_format == "cloudpickle": + with open(model_file, 'rb') as f: + self.model = cloudpickle.load(f) + + self.custom_input_formatter = get_annotated_function( + model_dir, "is_input_formatter") + self.custom_output_formatter = get_annotated_function( + model_dir, "is_output_formatter") + self.custom_predict_formatter = get_annotated_function( + model_dir, "is_predict_formatter") + + self.initialized = True + + def inference(self, inputs: Input) -> Output: + content_type = inputs.get_property("Content-Type") + accept = inputs.get_property("Accept") or "application/json" + + # Validate accept type (skip validation if custom output formatter is provided) + if not self.custom_output_formatter: + supported_accept_types = ["application/json", "text/csv"] + if not any(supported_type in accept + for supported_type in supported_accept_types): + raise ValueError( + f"Unsupported Accept type: {accept}. 
Supported types: {supported_accept_types}" + ) + + # Input processing + X = None + if self.custom_input_formatter: + X = self.custom_input_formatter(inputs) + elif content_type and "text/csv" in content_type: # Content-Type may be absent (None); let decode() handle that case + X = decode(inputs, content_type, require_csv_headers=False) + else: + input_map = decode(inputs, content_type) + data = input_map.get("inputs") if isinstance(input_map, + dict) else input_map + X = np.array(data) + + if X is None or not hasattr(X, 'ndim'): + raise ValueError( + f"Input processing failed for content type {content_type}") + + if X.ndim == 1: + X = X.reshape(1, -1) + + if self.custom_predict_formatter: + predictions = self.custom_predict_formatter(self.model, X) + else: + predictions = self.model.predict(X) + + # Output processing + if self.custom_output_formatter: + return self.custom_output_formatter(predictions) + + # Supports CSV/JSON outputs by default + outputs = Output() + if "text/csv" in accept: + csv_buffer = StringIO() + np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',') + outputs.add(csv_buffer.getvalue().rstrip()) + outputs.add_property("Content-Type", "text/csv") + else: + outputs.add_as_json({"predictions": predictions.tolist()}) + return outputs + + +service = SklearnHandler() + + +def handle(inputs: Input) -> Optional[Output]: + if not service.initialized: + service.initialize(inputs.get_properties()) + + if inputs.is_empty(): + return None + + return service.inference(inputs) diff --git a/engines/python/setup/djl_python/utils.py b/engines/python/setup/djl_python/utils.py index a5c8cc48a..4b2a4c3f7 100644 --- a/engines/python/setup/djl_python/utils.py +++ b/engines/python/setup/djl_python/utils.py @@ -10,7 +10,10 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. 
+import glob import logging +import os +from typing import Optional, List from djl_python import Output from djl_python.inputs import Input @@ -161,3 +164,27 @@ def get_input_details(requests, errors, batch): idx += 1 adapters = adapters if adapters else None return input_data, input_size, parameters, adapters + + +def find_model_file(model_dir: str, extensions: List[str]) -> Optional[str]: + """Find model file with given extensions in model directory + + Args: + model_dir: Directory to search for model files + extensions: List of file extensions to search for (without dots) + + Returns: + Path to matching model file, or None if not found + """ + all_matches = [] + for ext in extensions: + pattern = os.path.join(model_dir, f"*.{ext}") + matches = glob.glob(pattern) + all_matches.extend(matches) + + if len(all_matches) > 1: + raise ValueError( + f"Multiple model files found in {model_dir}: {all_matches}. Only one model file is supported per directory." + ) + + return all_matches[0] if all_matches else None diff --git a/engines/python/setup/djl_python/xgboost_handler.py b/engines/python/setup/djl_python/xgboost_handler.py new file mode 100644 index 000000000..0f08508aa --- /dev/null +++ b/engines/python/setup/djl_python/xgboost_handler.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python +# +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+ +import numpy as np +import os +import pickle as pkl +from io import StringIO +from typing import Optional +from djl_python import Input, Output +from djl_python.encode_decode import decode +from djl_python.utils import find_model_file +from djl_python.service_loader import get_annotated_function +from djl_python.import_utils import xgboost as xgb + + +class XGBoostHandler: + + def __init__(self): + self.model = None + self.initialized = False + + def initialize(self, properties: dict): + model_dir = properties.get("model_dir") + model_format = (properties.get("model_format") + or os.environ.get("MODEL_FORMAT") or "json") + + format_extensions = { + "json": ["json"], + "ubj": ["ubj"], + "pickle": ["pkl", "pickle"], + "xgb": ["xgb", "model", "bst"] + } + + extensions = format_extensions.get(model_format) + if not extensions: + raise ValueError( + f"Unsupported model format: {model_format}. Supported formats: json, ubj, pickle, xgb" + ) + + model_file = find_model_file(model_dir, extensions) + if not model_file: + raise FileNotFoundError( + f"No model file found with format '{model_format}' in {model_dir}" + ) + + if model_format in ["json", "ubj"]: + self.model = xgb.Booster() + self.model.load_model(model_file) + else: # unsafe formats: pickle, xgb + trust_insecure = (properties.get("trust_insecure_model_files") + or os.environ.get("TRUST_INSECURE_MODEL_FILES") + or "false") + if trust_insecure.lower() != "true": + raise ValueError( + "option.trust_insecure_model_files must be set to 'true' to use unsafe formats (only json/ubj are secure by default)" + ) + if model_format == "pickle": + with open(model_file, 'rb') as f: + self.model = pkl.load(f) + else: # xgb format + self.model = xgb.Booster() + self.model.load_model(model_file) + + self.custom_input_formatter = get_annotated_function( + model_dir, "is_input_formatter") + self.custom_output_formatter = get_annotated_function( + model_dir, "is_output_formatter") + self.custom_predict_formatter = 
get_annotated_function( + model_dir, "is_predict_formatter") + + self.initialized = True + + def inference(self, inputs: Input) -> Output: + content_type = inputs.get_property("Content-Type") + accept = inputs.get_property("Accept") or "application/json" + + # Validate accept type (skip validation if custom output formatter is provided) + if not self.custom_output_formatter: + supported_accept_types = ["application/json", "text/csv"] + if not any(supported_type in accept + for supported_type in supported_accept_types): + raise ValueError( + f"Unsupported Accept type: {accept}. Supported types: {supported_accept_types}" + ) + + # Input processing + X = None + if self.custom_input_formatter: + X = self.custom_input_formatter(inputs) + elif content_type and "text/csv" in content_type: # Content-Type may be absent (None); let decode() handle that case + X = decode(inputs, content_type, require_csv_headers=False) + else: + input_map = decode(inputs, content_type) + data = input_map.get("inputs") if isinstance(input_map, + dict) else input_map + X = np.array(data) + + if X is None or not hasattr(X, 'ndim'): + raise ValueError( + f"Input processing failed for content type {content_type}") + + if X.ndim == 1: + X = X.reshape(1, -1) + if self.custom_predict_formatter: + predictions = self.custom_predict_formatter(self.model, X) + else: + dmatrix = xgb.DMatrix(X) + predictions = self.model.predict(dmatrix, validate_features=False) + + # Output processing + if self.custom_output_formatter: + return self.custom_output_formatter(predictions) + + # Supports CSV/JSON outputs by default + outputs = Output() + if "text/csv" in accept: + csv_buffer = StringIO() + np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',') + outputs.add(csv_buffer.getvalue().rstrip()) + outputs.add_property("Content-Type", "text/csv") + else: + outputs.add_as_json({"predictions": predictions.tolist()}) + return outputs + + +service = XGBoostHandler() + + +def handle(inputs: Input) -> Optional[Output]: + if not service.initialized: + service.initialize(inputs.get_properties()) + + 
if inputs.is_empty(): + return None + + return service.inference(inputs) diff --git a/serving/docker/Dockerfile b/serving/docker/Dockerfile index ddaf6e533..0851df618 100644 --- a/serving/docker/Dockerfile +++ b/serving/docker/Dockerfile @@ -64,21 +64,27 @@ LABEL djl-serving-version=$djl_serving_version FROM base AS cpu-full -ARG python_version=3.10 +ARG python_version=3.11 ARG torch_version=2.7.1 ARG onnx_version=1.20.0 +ARG sklearn_version=1.7.2 +ARG xgboost_version=3.0.5 +ARG pydantic_version=2.12.2 ENV PYTORCH_LIBRARY_PATH=/usr/local/lib/python${python_version}/dist-packages/torch/lib ENV PYTORCH_VERSION=${torch_version} ENV PYTORCH_FLAVOR=cpu COPY scripts scripts/ -RUN scripts/install_python.sh && \ +RUN scripts/install_python.sh ${python_version} && \ scripts/install_djl_serving.sh $djl_version $djl_serving_version $torch_version && \ djl-serving -i ai.djl.onnxruntime:onnxruntime-engine:$djl_version && \ djl-serving -i com.microsoft.onnxruntime:onnxruntime:$onnx_version && \ + djl-serving -i ai.djl:basicdataset:$djl_version && \ scripts/patch_oss_dlc.sh python && \ pip3 install torch=="${torch_version}" torchvision --extra-index-url https://download.pytorch.org/whl/cpu && \ + pip3 install scikit-learn=="${sklearn_version}" skops cloudpickle xgboost=="${xgboost_version}" pydantic=="${pydantic_version}" && \ + pip3 install --upgrade numpy && \ echo "${djl_serving_version} cpufull" > /opt/djl/bin/telemetry && \ rm -rf /opt/djl/logs && \ chown -R djl:djl /opt/djl && \ diff --git a/serving/docker/scripts/install_python.sh b/serving/docker/scripts/install_python.sh index 3ab0b2d33..f8f93d095 100755 --- a/serving/docker/scripts/install_python.sh +++ b/serving/docker/scripts/install_python.sh @@ -9,7 +9,7 @@ apt-get update if [ -z "$PYTHON_VERSION" ]; then DEBIAN_FRONTEND=noninteractive apt-get install -yq python3-dev python3-pip python3-venv git else - DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends curl software-properties-common git + 
DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends curl software-properties-common git gnupg add-apt-repository -y ppa:deadsnakes/ppa apt-get autoremove -y python3 python_minor_version=$(echo "$PYTHON_VERSION" | awk -F"." '{print $2}') diff --git a/tests/integration/.gitignore b/tests/integration/.gitignore index 65e4bde6c..a5ecb0643 100644 --- a/tests/integration/.gitignore +++ b/tests/integration/.gitignore @@ -2,3 +2,4 @@ /logs /all_logs /models +/client_logs diff --git a/tests/integration/download_models.sh b/tests/integration/download_models.sh index 468e9e1ce..03afd0932 100755 --- a/tests/integration/download_models.sh +++ b/tests/integration/download_models.sh @@ -26,6 +26,25 @@ inf2_models_urls=( "https://resources.djl.ai/test-models/pytorch/resnet18_no_reqs_inf2_2_4.tar.gz" ) +python_skl_models_urls=( + "https://resources.djl.ai/test-models/python/sklearn/sklearn_model_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_joblib_model_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_cloudpickle_model_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_skops_model_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_multi_model_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_unsafe_model_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_custom_model_v2.zip" + "https://resources.djl.ai/test-models/python/sklearn/sklearn_skops_model_env_v2.zip" +) + +python_xgb_models_urls=( + "https://resources.djl.ai/test-models/python/xgboost/xgboost_model_v2.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_ubj_model_v2.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_deprecated_model_v2.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_unsafe_model_v2.zip" + "https://resources.djl.ai/test-models/python/xgboost/xgboost_custom_model_v2.zip" +) + download() { urls=("$@") for url 
in "${urls[@]}"; do @@ -38,8 +57,13 @@ download() { } case $platform in -cpu | cpu-full | pytorch-gpu) +cpu | pytorch-gpu) + download "${general_platform_models_urls[@]}" + ;; +cpu-full) download "${general_platform_models_urls[@]}" + download "${python_skl_models_urls[@]}" + download "${python_xgb_models_urls[@]}" ;; pytorch-inf2) download "${inf2_models_urls[@]}" diff --git a/tests/integration/test_xgb_skl.py b/tests/integration/test_xgb_skl.py new file mode 100644 index 000000000..f9f675539 --- /dev/null +++ b/tests/integration/test_xgb_skl.py @@ -0,0 +1,705 @@ +#!/usr/bin/env python3 + +import os +import pytest +import requests +import json +from tests import Runner + + +@pytest.mark.cpu +class TestXgbSkl: + + # Basic model tests + def test_sklearn_model(self): + with Runner('cpu-full', 'sklearn_model', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "predictions" in result + assert len(result["predictions"]) == 1 + + def test_xgboost_model(self): + with Runner('cpu-full', 'xgboost_model', download=True) as r: + r.launch( + cmd= + "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip" + ) + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_test", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "predictions" in result + assert len(result["predictions"]) == 1 + + # CSV input/output tests + def 
test_sklearn_csv_input(self): + with Runner('cpu-full', 'sklearn_csv', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + csv_data = "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0" + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + data=csv_data, + headers={ + "Content-Type": "text/csv", + "Accept": "text/csv" + }) + assert response.status_code == 200 + + def test_xgboost_csv_input(self): + with Runner('cpu-full', 'xgboost_csv', download=True) as r: + r.launch( + cmd= + "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip" + ) + csv_data = "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0" + response = requests.post( + "http://localhost:8080/predictions/xgboost_test", + data=csv_data, + headers={ + "Content-Type": "text/csv", + "Accept": "text/csv" + }) + assert response.status_code == 200 + + def test_sklearn_json_input_csv_output(self): + with Runner('cpu-full', 'sklearn_json_csv', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "text/csv" + }) + assert response.status_code == 200 + assert "text/csv" in response.headers.get("Content-Type", "") + + def test_sklearn_csv_input_json_output(self): + with Runner('cpu-full', 'sklearn_csv_json', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip" + ) + csv_data = "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0" + response = requests.post( + "http://localhost:8080/predictions/sklearn_test", + data=csv_data, + headers={ + "Content-Type": "text/csv", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + 
assert "predictions" in result + + def test_xgboost_csv_input_json_output(self): + with Runner('cpu-full', 'xgboost_csv_json', download=True) as r: + r.launch( + cmd= + "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip" + ) + csv_data = "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0" + response = requests.post( + "http://localhost:8080/predictions/xgboost_test", + data=csv_data, + headers={ + "Content-Type": "text/csv", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "predictions" in result + + def test_xgboost_json_input_csv_output(self): + with Runner('cpu-full', 'xgboost_json_csv', download=True) as r: + r.launch( + cmd= + "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip" + ) + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_test", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "text/csv" + }) + assert response.status_code == 200 + assert "text/csv" in response.headers.get("Content-Type", "") + + # Model format tests + def test_sklearn_joblib_format(self): + with Runner('cpu-full', 'sklearn_joblib', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_joblib::Python=file:/opt/ml/model/sklearn_joblib_model_v2.zip" + ) + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_joblib", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "predictions" in result + + def test_sklearn_cloudpickle_format(self): + with Runner('cpu-full', 'sklearn_cloudpickle', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_cloudpickle::Python=file:/opt/ml/model/sklearn_cloudpickle_model_v2.zip" + ) + test_data 
= { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_cloudpickle", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "predictions" in result + + def test_xgboost_ubj_format(self): + with Runner('cpu-full', 'xgboost_ubj', download=True) as r: + r.launch( + cmd= + "serve -m xgboost_ubj::Python=file:/opt/ml/model/xgboost_ubj_model_v2.zip" + ) + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_ubj", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "predictions" in result + + def test_xgboost_deprecated_format(self): + with Runner('cpu-full', 'xgboost_deprecated', download=True) as r: + r.launch( + cmd= + "serve -m xgboost_deprecated::Python=file:/opt/ml/model/xgboost_deprecated_model_v2.zip" + ) + test_data = { + "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]] + } + response = requests.post( + "http://localhost:8080/predictions/xgboost_deprecated", + json=test_data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "predictions" in result + + # Custom formatter tests + def test_sklearn_custom_formatters(self): + with Runner('cpu-full', 'sklearn_custom', download=True) as r: + r.launch( + cmd= + "serve -m sklearn_custom::Python=file:/opt/ml/model/sklearn_custom_model_v2.zip" + ) + test_data = { + "features": + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + } + response = requests.post( + "http://localhost:8080/predictions/sklearn_custom", + data=json.dumps(test_data), + headers={ + 
"Content-Type": "application/json", + "Accept": "application/json" + }) + assert response.status_code == 200 + result = response.json() + assert "result" in result + assert "confidence" in result + assert "model_type" in result + assert result["model_type"] == "sklearn_custom" + assert result["confidence"] == 0.95
+
+    def test_xgboost_custom_formatters(self):
+        """POST a JSON payload to xgboost_custom and check the response fields."""
+        with Runner('cpu-full', 'xgboost_custom', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m xgboost_custom::Python=file:/opt/ml/model/xgboost_custom_model_v2.zip"
+            )
+            test_data = {
+                "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+            }
+            response = requests.post(
+                "http://localhost:8080/predictions/xgboost_custom",
+                data=json.dumps(test_data),
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept": "application/json"
+                })
+            assert response.status_code == 200
+            result = response.json()
+            assert "probability" in result
+            assert "prediction" in result
+            assert "model_version" in result
+            assert "processed_by" in result
+            assert result["model_version"] == "1.0"
+            assert result["processed_by"] == "xgboost_custom"
+            assert isinstance(result["probability"], float)
+            assert result["prediction"] in [0, 1]
+
+    # Error handling tests - CSV format errors
+    def test_sklearn_csv_with_headers(self):
+        """A CSV body whose first row is a header row must return HTTP 424."""
+        with Runner('cpu-full', 'sklearn_csv_headers', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip"
+            )
+            csv_data = "f1,f2,f3,f4,f5,f6,f7,f8,f9,f10\n1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0"
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_test",
+                data=csv_data,
+                headers={
+                    "Content-Type": "text/csv",
+                    "Accept": "text/csv"
+                })
+            assert response.status_code == 424
+
+    def test_sklearn_ragged_csv(self):
+        """CSV rows with differing column counts (5 vs 7) must return HTTP 424."""
+        with Runner('cpu-full', 'sklearn_ragged_csv', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip"
+            )
+            csv_data = "1.0,2.0,3.0,4.0,5.0\n6.0,7.0,8.0,9.0,10.0,11.0,12.0"
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_test",
+                data=csv_data,
+                headers={
+                    "Content-Type": "text/csv",
+                    "Accept": "text/csv"
+                })
+            assert response.status_code == 424
+
+    def test_sklearn_empty_rows_csv(self):
+        """A blank line between two CSV rows is tolerated (HTTP 200)."""
+        with Runner('cpu-full', 'sklearn_empty_csv', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip"
+            )
+            csv_data = "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0\n\n1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0"
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_test",
+                data=csv_data,
+                headers={
+                    "Content-Type": "text/csv",
+                    "Accept": "text/csv"
+                })
+            assert response.status_code == 200  # Skips empty rows
+
+    def test_sklearn_non_numeric_csv(self):
+        """A non-numeric CSV field ("abc") must return HTTP 424."""
+        with Runner('cpu-full', 'sklearn_non_numeric_csv', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip"
+            )
+            csv_data = "1.0,2.0,abc,4.0,5.0,6.0,7.0,8.0,9.0,10.0"
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_test",
+                data=csv_data,
+                headers={
+                    "Content-Type": "text/csv",
+                    "Accept": "text/csv"
+                })
+            assert response.status_code == 424
+
+    def test_xgboost_csv_with_headers(self):
+        """A CSV body whose first row is a header row must return HTTP 424."""
+        with Runner('cpu-full', 'xgboost_csv_headers', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip"
+            )
+            csv_data = "f1,f2,f3,f4,f5,f6,f7,f8,f9,f10\n1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0"
+            response = requests.post(
+                "http://localhost:8080/predictions/xgboost_test",
+                data=csv_data,
+                headers={
+                    "Content-Type": "text/csv",
+                    "Accept": "text/csv"
+                })
+            assert response.status_code == 424
+
+    def test_xgboost_ragged_csv(self):
+        """CSV rows with differing column counts (5 vs 7) must return HTTP 424."""
+        with Runner('cpu-full', 'xgboost_ragged_csv', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip"
+            )
+            csv_data = "1.0,2.0,3.0,4.0,5.0\n6.0,7.0,8.0,9.0,10.0,11.0,12.0"
+            response = requests.post(
+                "http://localhost:8080/predictions/xgboost_test",
+                data=csv_data,
+                headers={
+                    "Content-Type": "text/csv",
+                    "Accept": "text/csv"
+                })
+            assert response.status_code == 424
+
+    def test_xgboost_empty_rows_csv(self):
+        """A blank line between two CSV rows is tolerated (HTTP 200)."""
+        with Runner('cpu-full', 'xgboost_empty_csv', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip"
+            )
+            csv_data = "1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0\n\n1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0"
+            response = requests.post(
+                "http://localhost:8080/predictions/xgboost_test",
+                data=csv_data,
+                headers={
+                    "Content-Type": "text/csv",
+                    "Accept": "text/csv"
+                })
+            assert response.status_code == 200
+
+    def test_xgboost_non_numeric_csv(self):
+        """A non-numeric CSV field ("abc") must return HTTP 424."""
+        with Runner('cpu-full', 'xgboost_non_numeric_csv', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip"
+            )
+            csv_data = "1.0,2.0,abc,4.0,5.0,6.0,7.0,8.0,9.0,10.0"
+            response = requests.post(
+                "http://localhost:8080/predictions/xgboost_test",
+                data=csv_data,
+                headers={
+                    "Content-Type": "text/csv",
+                    "Accept": "text/csv"
+                })
+            assert response.status_code == 424
+
+    # Error handling tests - Input shape errors
+    def test_sklearn_wrong_input_shape(self):
+        """A JSON row of 5 values (other tests send 10) must return HTTP 424."""
+        with Runner('cpu-full', 'sklearn_wrong_shape', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip"
+            )
+            test_data = {"inputs": [[1.0, 2.0, 3.0, 4.0, 5.0]]}
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_test",
+                json=test_data,
+                headers={"Content-Type": "application/json"})
+            assert response.status_code == 424
+
+    def test_sklearn_ragged_arrays(self):
+        """JSON rows of unequal length (10 and 3 values) must return HTTP 424."""
+        with Runner('cpu-full', 'sklearn_ragged', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip"
+            )
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+                           [1.0, 2.0, 3.0]]
+            }
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_test",
+                json=test_data,
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept": "application/json"
+                })
+            assert response.status_code == 424
+
+    def test_xgboost_wrong_input_shape(self):
+        """A JSON row of 5 values (other tests send 10) must return HTTP 424."""
+        with Runner('cpu-full', 'xgboost_wrong_shape', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip"
+            )
+            test_data = {"inputs": [[1.0, 2.0, 3.0, 4.0, 5.0]]}
+            response = requests.post(
+                "http://localhost:8080/predictions/xgboost_test",
+                json=test_data,
+                headers={"Content-Type": "application/json"})
+            assert response.status_code == 424
+
+    def test_xgboost_ragged_arrays(self):
+        """JSON rows of unequal length (10 and 3 values) must return HTTP 424."""
+        with Runner('cpu-full', 'xgboost_ragged', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip"
+            )
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+                           [1.0, 2.0, 3.0]]
+            }
+            response = requests.post(
+                "http://localhost:8080/predictions/xgboost_test",
+                json=test_data,
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept": "application/json"
+                })
+            assert response.status_code == 424
+
+    # Error handling tests - Content type errors
+    def test_sklearn_invalid_accept_type(self):
+        """An Accept header of text/xml must return HTTP 424."""
+        with Runner('cpu-full', 'sklearn_unsupported_accept',
+                    download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip"
+            )
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_test",
+                json=test_data,
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept": "text/xml"
+                })
+            assert response.status_code == 424
+
+    def test_sklearn_invalid_content_type(self):
+        """A Content-Type header of application/xml must return HTTP 424."""
+        with Runner('cpu-full', 'sklearn_invalid_content', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_model_v2.zip"
+            )
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_test",
+                json=test_data,
+                headers={
+                    "Content-Type": "application/xml",
+                    "Accept": "application/json"
+                })
+            assert response.status_code == 424
+
+    def test_xgboost_invalid_accept_type(self):
+        """An Accept header of application/xml must return HTTP 424."""
+        with Runner('cpu-full', 'xgboost_invalid_accept', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip"
+            )
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            response = requests.post(
+                "http://localhost:8080/predictions/xgboost_test",
+                json=test_data,
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept": "application/xml"
+                })
+            assert response.status_code == 424
+
+    def test_xgboost_invalid_content_type(self):
+        """A Content-Type header of application/xml must return HTTP 424."""
+        with Runner('cpu-full', 'xgboost_invalid_content', download=True) as r:
+            r.launch(
+                cmd=
+                "serve -m xgboost_test::Python=file:/opt/ml/model/xgboost_model_v2.zip"
+            )
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            response = requests.post(
+                "http://localhost:8080/predictions/xgboost_test",
+                json=test_data,
+                headers={
+                    "Content-Type": "application/xml",
+                    "Accept": "application/json"
+                })
+            assert response.status_code == 424
+
+    # Error handling tests - Configuration errors
+    def test_multiple_artifacts(self):
+        """Serving sklearn_multi_model_v2.zip must not produce an HTTP 200.
+
+        A failure to launch the container or to reach the endpoint is an
+        accepted outcome; a successful prediction is not.
+        NOTE(review): presumably the archive holds several model artifacts --
+        confirm against the test fixture.
+        """
+        with Runner('cpu-full', 'sklearn_multi_artifacts', download=True) as r:
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            try:
+                r.launch(
+                    cmd=
+                    "serve -m sklearn_multi::Python=file:/opt/ml/model/sklearn_multi_model_v2.zip"
+                )
+                response = requests.post(
+                    "http://localhost:8080/predictions/sklearn_multi",
+                    json=test_data,
+                    headers={
+                        "Content-Type": "application/json",
+                        "Accept": "application/json"
+                    })
+            except Exception:
+                # Launch or connection failure already proves the model was
+                # not served, which is what this negative test expects.
+                return
+            # Asserted outside the try block: AssertionError is a subclass of
+            # Exception, so asserting inside it could never fail the test.
+            assert response.status_code != 200
+
+    def test_sklearn_bad_env_variable(self):
+        """An invalid OPTION_SKOPS_TRUSTED_TYPES value must not yield HTTP 200.
+
+        A failure to launch the container or to reach the endpoint is an
+        accepted outcome; a successful prediction is not.
+        """
+        with Runner('cpu-full', 'sklearn_bad_env', download=True) as r:
+            env = ["OPTION_SKOPS_TRUSTED_TYPES=invalid_type"]
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            try:
+                r.launch(
+                    env_vars=env,
+                    cmd=
+                    "serve -m sklearn_test::Python=file:/opt/ml/model/sklearn_skops_model_v2.zip"
+                )
+                response = requests.post(
+                    "http://localhost:8080/predictions/sklearn_test",
+                    json=test_data,
+                    headers={
+                        "Content-Type": "application/json",
+                        "Accept": "application/json"
+                    })
+            except Exception:
+                # Launch or connection failure already proves the model was
+                # not served, which is what this negative test expects.
+                return
+            # Asserted outside the try block: AssertionError is a subclass of
+            # Exception, so asserting inside it could never fail the test.
+            assert response.status_code != 200
+
+    def test_sklearn_skops_with_valid_trusted_types(self):
+        """With RandomForestClassifier trusted, the skops model returns predictions."""
+        with Runner('cpu-full', 'sklearn_skops_valid', download=True) as r:
+            env = [
+                "OPTION_SKOPS_TRUSTED_TYPES=sklearn.ensemble._forest.RandomForestClassifier"
+            ]
+            r.launch(
+                env_vars=env,
+                cmd=
+                "serve -m sklearn_skops::Python=file:/opt/ml/model/sklearn_skops_model_v2.zip"
+            )
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_skops",
+                json=test_data,
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept": "application/json"
+                })
+            assert response.status_code == 200
+            result = response.json()
+            assert "predictions" in result
+
+    # Security tests
+    def test_sklearn_unsafe_format_without_trust(self):
+        """sklearn_unsafe_model_v2.zip, launched with no env vars, must not yield HTTP 200.
+
+        A failure to launch the container or to reach the endpoint is an
+        accepted outcome; a successful prediction is not.
+        """
+        with Runner('cpu-full', 'sklearn_unsafe', download=True) as r:
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            try:
+                r.launch(
+                    cmd=
+                    "serve -m sklearn_unsafe::Python=file:/opt/ml/model/sklearn_unsafe_model_v2.zip"
+                )
+                response = requests.post(
+                    "http://localhost:8080/predictions/sklearn_unsafe",
+                    json=test_data,
+                    headers={
+                        "Content-Type": "application/json",
+                        "Accept": "application/json"
+                    })
+            except Exception:
+                # Launch or connection failure already proves the model was
+                # not served, which is what this negative test expects.
+                return
+            # Asserted outside the try block: AssertionError is a subclass of
+            # Exception, so asserting inside it could never fail the test.
+            assert response.status_code != 200
+
+    def test_xgboost_unsafe_format_without_trust(self):
+        """xgboost_unsafe_model_v2.zip, launched with no env vars, must not yield HTTP 200.
+
+        A failure to launch the container or to reach the endpoint is an
+        accepted outcome; a successful prediction is not.
+        """
+        with Runner('cpu-full', 'xgboost_unsafe', download=True) as r:
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            try:
+                r.launch(
+                    cmd=
+                    "serve -m xgboost_unsafe::Python=file:/opt/ml/model/xgboost_unsafe_model_v2.zip"
+                )
+                response = requests.post(
+                    "http://localhost:8080/predictions/xgboost_unsafe",
+                    json=test_data,
+                    headers={
+                        "Content-Type": "application/json",
+                        "Accept": "application/json"
+                    })
+            except Exception:
+                # Launch or connection failure already proves the model was
+                # not served, which is what this negative test expects.
+                return
+            # Asserted outside the try block: AssertionError is a subclass of
+            # Exception, so asserting inside it could never fail the test.
+            assert response.status_code != 200
+
+    def test_sklearn_skops_env_variables_only(self):
+        """Model format, trust flag and trusted types passed as env vars serve predictions."""
+        with Runner('cpu-full', 'sklearn_skops_env_only', download=True) as r:
+            env = [
+                "OPTION_MODEL_FORMAT=skops",
+                "OPTION_TRUST_INSECURE_MODEL_FILES=true",
+                "OPTION_SKOPS_TRUSTED_TYPES=sklearn.ensemble._forest.RandomForestClassifier"
+            ]
+            r.launch(
+                env_vars=env,
+                cmd=
+                "serve -m sklearn_skops_env::Python=file:/opt/ml/model/sklearn_skops_model_env_v2.zip"
+            )
+            test_data = {
+                "inputs": [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]
+            }
+            response = requests.post(
+                "http://localhost:8080/predictions/sklearn_skops_env",
+                json=test_data,
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept": "application/json"
+                })
+            assert response.status_code == 200
+            result = response.json()
+            assert "predictions" in result
diff --git a/tests/integration/tests.py b/tests/integration/tests.py index dfaabe1be..64addde87 100644 --- a/tests/integration/tests.py +++ b/tests/integration/tests.py @@ -4,6 +4,8 @@ import subprocess import logging import pytest +import requests +import json import llm.prepare as prepare import llm.client as client import test_client