Merged
42 changes: 27 additions & 15 deletions engines/python/setup/djl_python/encode_decode.py
@@ -21,21 +21,30 @@
 import numpy as np
 
 
-def decode_csv(inputs: Input): # type: (str) -> np.array
+def decode_csv(inputs: Input, require_headers=True): # type: (str) -> np.array
     csv_content = inputs.get_as_string()
-    stream = StringIO(csv_content)
-    # detects if the incoming csv has headers
-    if not any(header in csv_content.splitlines()[0].lower()
-               for header in ["question", "context", "inputs"]):
-        raise ValueError(
-            "You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.",
-        )
-    # reads csv as io
-    request_list = list(csv.DictReader(stream))
-    if "inputs" in request_list[0].keys():
-        return {"inputs": [entry["inputs"] for entry in request_list]}
-    else:
-        return {"inputs": request_list}
+
+    if require_headers:
+        if not any(header in csv_content.splitlines()[0].lower()
+                   for header in ["question", "context", "inputs"]):
+            raise ValueError(
+                "You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.",
+            )
+        stream = StringIO(csv_content)
+        request_list = list(csv.DictReader(stream))
+        if "inputs" in request_list[0].keys():
+            return {"inputs": [entry["inputs"] for entry in request_list]}
+        else:
+            return {"inputs": request_list}
+    # for predictive ML inputs
+    result = np.genfromtxt(StringIO(csv_content), delimiter=",")
+    # Check for NaN values, which indicate non-numeric data
+    if np.isnan(result).any():
+        raise ValueError(
+            "CSV contains non-numeric data. Please provide numeric data only."
+        )
+    return result


@@ -51,7 +60,10 @@ def encode_csv(content): # type: (str) -> np.array
     return stream.getvalue()
 
 
-def decode(inputs: Input, content_type: str, key=None):
+def decode(inputs: Input,
+           content_type: str,
+           key=None,
+           require_csv_headers=True):
     if not content_type:
         ret = inputs.get_as_bytes(key=key)
         if not ret:
@@ -60,7 +72,7 @@ def decode(inputs: Input, content_type: str, key=None):
     elif "application/json" in content_type:
         return inputs.get_as_json(key=key)
     elif "text/csv" in content_type:
-        return decode_csv(inputs)
+        return decode_csv(inputs, require_headers=require_csv_headers)
     elif "text/plain" in content_type:
         return {"inputs": [inputs.get_as_string(key=key)]}
     if content_type.startswith("image/"):
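For context, a minimal self-contained sketch of the two CSV decode paths introduced above. It exercises the same csv.DictReader and np.genfromtxt logic directly, since Input objects are normally constructed by the serving layer; the sample payloads are illustrative.

import csv
import numpy as np
from io import StringIO

# Headered path (require_headers=True): DictReader keys rows by column name.
headered = "inputs\nhello\nworld"
rows = list(csv.DictReader(StringIO(headered)))
assert [r["inputs"] for r in rows] == ["hello", "world"]

# Headerless path (require_headers=False): parse a raw numeric feature matrix;
# non-numeric cells become NaN, which decode_csv rejects with a ValueError.
headerless = "1.0,2.0,3.0\n4.0,5.0,6.0"
matrix = np.genfromtxt(StringIO(headerless), delimiter=",")
assert matrix.shape == (2, 3)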
68 changes: 68 additions & 0 deletions engines/python/setup/djl_python/import_utils.py
@@ -0,0 +1,68 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import importlib.util
import importlib.metadata


def _is_package_available(pkg_name: str) -> bool:
"""Check if a package is available"""
package_exists = importlib.util.find_spec(pkg_name) is not None
if package_exists:
try:
importlib.metadata.version(pkg_name)
except importlib.metadata.PackageNotFoundError:
package_exists = False
return package_exists


# SKLearn model persistence libraries
_joblib_available = _is_package_available("joblib")
_cloudpickle_available = _is_package_available("cloudpickle")
_skops_available = _is_package_available("skops")

# XGBoost
_xgboost_available = _is_package_available("xgboost")


def is_joblib_available() -> bool:
return _joblib_available


def is_cloudpickle_available() -> bool:
return _cloudpickle_available


def is_skops_available() -> bool:
return _skops_available


def is_xgboost_available() -> bool:
return _xgboost_available


joblib = None
if _joblib_available:
import joblib

cloudpickle = None
if _cloudpickle_available:
import cloudpickle

skops_io = None
if _skops_available:
import skops.io as skops_io

xgboost = None
if _xgboost_available:
import xgboost
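For reference, a hedged sketch of how downstream code might combine these availability checks with the lazily bound module objects; load_joblib_model below is a hypothetical helper, not part of this PR.

from djl_python.import_utils import is_joblib_available, joblib


def load_joblib_model(path: str):
    # Fail fast with an actionable message when the optional dependency is
    # missing, rather than hitting an AttributeError on the None placeholder.
    if not is_joblib_available():
        raise ImportError(
            "model_format=joblib requires the 'joblib' package to be installed"
        )
    return joblib.load(path)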
162 changes: 162 additions & 0 deletions engines/python/setup/djl_python/sklearn_handler.py
@@ -0,0 +1,162 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import pickle
import numpy as np
import os
from io import StringIO
from typing import Optional
from djl_python import Input, Output
from djl_python.encode_decode import decode
from djl_python.utils import find_model_file
from djl_python.service_loader import get_annotated_function
from djl_python.import_utils import joblib, cloudpickle, skops_io as sio


class SklearnHandler:

def __init__(self):
self.model = None
self.initialized = False
self.custom_input_formatter = None
self.custom_output_formatter = None
self.custom_predict_formatter = None

def _get_trusted_types(self, properties: dict):
trusted_types_str = properties.get("skops_trusted_types", "")
if not trusted_types_str:
raise ValueError(
"option.skops_trusted_types must be set to load skops models. "
"Example: option.skops_trusted_types='sklearn.ensemble._forest.RandomForestClassifier,numpy.ndarray'"
)
trusted_types = [
t.strip() for t in trusted_types_str.split(",") if t.strip()
]
print(f"Using trusted types for skops model loading: {trusted_types}")
return trusted_types

def initialize(self, properties: dict):
model_dir = properties.get("model_dir")
model_format = properties.get("model_format", "skops")

format_extensions = {
"skops": ["skops"],
"joblib": ["joblib", "jl"],
"pickle": ["pkl", "pickle"],
"cloudpickle": ["pkl", "pickle", "cloudpkl"]
}

extensions = format_extensions.get(model_format)
if not extensions:
raise ValueError(
f"Unsupported model format: {model_format}. Supported formats: skops, joblib, pickle, cloudpickle"
)

model_file = find_model_file(model_dir, extensions)
if not model_file:
raise FileNotFoundError(
f"No model file found with format '{model_format}' in {model_dir}"
)

if model_format == "skops":
trusted_types = self._get_trusted_types(properties)
self.model = sio.load(model_file, trusted=trusted_types)
else:
if properties.get("trust_insecure_model_files",
"false").lower() != "true":
raise ValueError(
f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)"
)

if model_format == "joblib":
self.model = joblib.load(model_file)
elif model_format == "pickle":
with open(model_file, 'rb') as f:
self.model = pickle.load(f)
elif model_format == "cloudpickle":
with open(model_file, 'rb') as f:
self.model = cloudpickle.load(f)

self.custom_input_formatter = get_annotated_function(
model_dir, "is_input_formatter")
self.custom_output_formatter = get_annotated_function(
model_dir, "is_output_formatter")
self.custom_predict_formatter = get_annotated_function(
model_dir, "is_predict_formatter")

self.initialized = True

def inference(self, inputs: Input) -> Output:
content_type = inputs.get_property("Content-Type")
accept = inputs.get_property("Accept") or "application/json"

# Validate accept type (skip validation if custom output formatter is provided)
if not self.custom_output_formatter:
supported_accept_types = ["application/json", "text/csv"]
if not any(supported_type in accept
for supported_type in supported_accept_types):
raise ValueError(
f"Unsupported Accept type: {accept}. Supported types: {supported_accept_types}"
)

# Input processing
X = None
if self.custom_input_formatter:
X = self.custom_input_formatter(inputs)
elif "text/csv" in content_type:
X = decode(inputs, content_type, require_csv_headers=False)
else:
input_map = decode(inputs, content_type)
data = input_map.get("inputs") if isinstance(input_map,
dict) else input_map
X = np.array(data)

if X is None or not hasattr(X, 'ndim'):
raise ValueError(
f"Input processing failed for content type {content_type}")

if X.ndim == 1:
X = X.reshape(1, -1)

if self.custom_predict_formatter:
predictions = self.custom_predict_formatter(self.model, X)
else:
predictions = self.model.predict(X)

# Output processing
if self.custom_output_formatter:
return self.custom_output_formatter(predictions)

# Supports CSV/JSON outputs by default
outputs = Output()
if "text/csv" in accept:
csv_buffer = StringIO()
np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',')
outputs.add(csv_buffer.getvalue().rstrip())
outputs.add_property("Content-Type", "text/csv")
else:
outputs.add_as_json({"predictions": predictions.tolist()})
return outputs


service = SklearnHandler()


def handle(inputs: Input) -> Optional[Output]:
if not service.initialized:
service.initialize(inputs.get_properties())

if inputs.is_empty():
return None

return service.inference(inputs)
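As a usage sketch, a model artifact for the default skops path could be produced as below. The estimator and trusted-type list are illustrative (they mirror the example in _get_trusted_types' error message), and assume skops' dump/load API with a trusted-types list.

import skops.io as sio
from sklearn.ensemble import RandomForestClassifier

# Persist a fitted estimator as model.skops so find_model_file() discovers it.
model = RandomForestClassifier(n_estimators=10).fit([[0, 0], [1, 1]], [0, 1])
sio.dump(model, "model.skops")

# Loading mirrors SklearnHandler.initialize() with
# option.skops_trusted_types=sklearn.ensemble._forest.RandomForestClassifier,numpy.ndarray
trusted = ["sklearn.ensemble._forest.RandomForestClassifier", "numpy.ndarray"]
restored = sio.load("model.skops", trusted=trusted)
print(restored.predict([[0.5, 0.5]]))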
27 changes: 27 additions & 0 deletions engines/python/setup/djl_python/utils.py
@@ -10,7 +10,10 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
+import glob
 import logging
+import os
+from typing import Optional, List
 
 from djl_python import Output
 from djl_python.inputs import Input
@@ -161,3 +164,27 @@ def get_input_details(requests, errors, batch):
         idx += 1
     adapters = adapters if adapters else None
     return input_data, input_size, parameters, adapters
+
+
+def find_model_file(model_dir: str, extensions: List[str]) -> Optional[str]:
+    """Find model file with given extensions in model directory
+
+    Args:
+        model_dir: Directory to search for model files
+        extensions: List of file extensions to search for (without dots)
+
+    Returns:
+        Path to matching model file, or None if not found
+    """
+    all_matches = []
+    for ext in extensions:
+        pattern = os.path.join(model_dir, f"*.{ext}")
+        matches = glob.glob(pattern)
+        all_matches.extend(matches)
+
+    if len(all_matches) > 1:
+        raise ValueError(
+            f"Multiple model files found in {model_dir}: {all_matches}. Only one model file is supported per directory."
+        )
+
+    return all_matches[0] if all_matches else None
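And a quick illustrative call for the new helper; the SageMaker-style directory path is hypothetical.

from djl_python.utils import find_model_file

model_file = find_model_file("/opt/ml/model", ["pkl", "pickle"])
if model_file is None:
    raise FileNotFoundError("expected one .pkl/.pickle file in /opt/ml/model")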