diff --git a/engines/python/setup/djl_python/ts_service_loader.py b/engines/python/setup/djl_python/ts_service_loader.py index b84f96c74..d7b57aa17 100644 --- a/engines/python/setup/djl_python/ts_service_loader.py +++ b/engines/python/setup/djl_python/ts_service_loader.py @@ -39,7 +39,7 @@ def invoke_handler(self, function_name, inputs): for k, v in inputs.get_properties().items(): header = dict() header["name"] = k.encode("utf-8") - header["value"] = v.encode("utf-8") + header["value"] = str(v).encode("utf-8") request["headers"].append(header) content = inputs.get_content() diff --git a/serving/src/main/java/ai/djl/serving/util/ModelStore.java b/serving/src/main/java/ai/djl/serving/util/ModelStore.java index 8f2304372..d22155969 100644 --- a/serving/src/main/java/ai/djl/serving/util/ModelStore.java +++ b/serving/src/main/java/ai/djl/serving/util/ModelStore.java @@ -15,7 +15,6 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.repository.FilenameUtils; -import ai.djl.serving.ModelServer; import ai.djl.serving.models.ModelManager; import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.workflow.BadWorkflowException; @@ -49,7 +48,7 @@ /** A class represent model server's model store. */ public final class ModelStore { - private static final Logger logger = LoggerFactory.getLogger(ModelServer.class); + private static final Logger logger = LoggerFactory.getLogger(ModelStore.class); private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[?([^?]+?)]?=)?(.+)"); private static final ModelStore INSTANCE = new ModelStore(); @@ -100,13 +99,7 @@ public void initialize() throws IOException, BadWorkflowException { // contains only directory or archive files boolean isMultiModelsDirectory; try (Stream stream = Files.list(modelStore)) { - isMultiModelsDirectory = - stream.filter(p -> !p.getFileName().toString().startsWith(".")) - .allMatch( - p -> - Files.isDirectory(p) - || FilenameUtils.isArchiveFile( - p.toString())); + isMultiModelsDirectory = stream.allMatch(ModelStore::isModel); } if (isMultiModelsDirectory) { @@ -208,10 +201,7 @@ public List getWorkflows() { */ public static String mapModelUrl(Path path) { try { - if (!Files.exists(path) - || Files.isHidden(path) - || (!Files.isDirectory(path) - && !FilenameUtils.isArchiveFile(path.toString()))) { + if (!isModel(path)) { return null; } try (Stream stream = Files.list(path)) { @@ -233,6 +223,15 @@ public static String mapModelUrl(Path path) { } } + private static boolean isModel(Path path) { + String fileName = Objects.requireNonNull(path.getFileName()).toString(); + if (fileName.startsWith(".")) { + return false; + } + return Files.exists(path) + && (Files.isDirectory(path) || FilenameUtils.isArchiveFile(fileName)); + } + private String createHuggingFaceModel(String modelId) throws IOException { if (modelId.startsWith("djl://") || modelId.startsWith("s3://")) { return modelId; diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index e102faa12..5f7b064cc 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -40,6 +40,7 @@ import ai.djl.translate.TranslateException; import ai.djl.util.NeuronUtils; import ai.djl.util.Utils; +import ai.djl.util.ZipUtils; import ai.djl.util.cuda.CudaUtils; import org.slf4j.Logger; @@ -65,6 +66,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; import java.util.stream.Stream; /** A class represent a loaded model and it's metadata. */ @@ -764,6 +766,8 @@ private String inferEngine() throws ModelException { String groupId = mrl.getGroupId(); ModelZoo zoo = ModelZoo.getModelZoo(groupId); return zoo.getSupportedEngines().iterator().next(); + } else if (isTorchServeModel()) { + return "Python"; } else if (Files.isRegularFile(modelDir.resolve(prefix + ".pt")) || Files.isRegularFile(modelDir.resolve("model.pt"))) { return "PyTorch"; @@ -781,8 +785,6 @@ private String inferEngine() throws ModelException { || Files.isRegularFile(modelDir.resolve("model.bst")) || Files.isRegularFile(modelDir.resolve("model.xgb"))) { return "XGBoost"; - } else if (Files.isRegularFile(modelDir.resolve(prefix + ".gguf"))) { - return "Llama"; } else if (isPythonModel(prefix)) { // TODO: How to differentiate Rust model from Python return "Python"; @@ -811,26 +813,60 @@ private boolean isPythonModel(String prefix) { return Files.isRegularFile(modelDir.resolve("model.py")) || Files.isRegularFile(modelDir.resolve(prefix + ".py")) || prop.getProperty("option.model_id") != null - || Files.isRegularFile(modelDir.resolve("config.json")) - || isTorchServeModel(); + || Files.isRegularFile(modelDir.resolve("config.json")); } private void downloadModel() throws ModelException, IOException { if (resolvedModelUrl.startsWith("s3://")) { modelDir = downloadS3ToDownloadDir(resolvedModelUrl); resolvedModelUrl = modelDir.toUri().toURL().toString(); - return; - } - Repository repository = Repository.newInstance("modelStore", resolvedModelUrl); - List mrls = repository.getResources(); - if (mrls.isEmpty()) { - throw new ModelNotFoundException("Invalid model url: " + resolvedModelUrl); + } else { + Repository repository = Repository.newInstance("modelStore", resolvedModelUrl); + List mrls = repository.getResources(); + if (mrls.isEmpty()) { + throw new ModelNotFoundException("Invalid model url: " + resolvedModelUrl); + } + + Artifact artifact = mrls.get(0).getDefaultArtifact(); + repository.prepare(artifact); + modelDir = Utils.getNestedModelDir(repository.getResourceDirectory(artifact)); + artifactName = artifact.getName(); + if (Files.isRegularFile(modelDir)) { + modelDir = modelDir.getParent(); + return; + } } - Artifact artifact = mrls.get(0).getDefaultArtifact(); - repository.prepare(artifact); - modelDir = Utils.getNestedModelDir(repository.getResourceDirectory(artifact)); - artifactName = artifact.getName(); + try (Stream stream = Files.list(modelDir)) { + List list = stream.collect(Collectors.toList()); + if (list.size() == 1) { + Path match = list.get(0); + String name = Objects.requireNonNull(match.getFileName()).toString(); + String type = FilenameUtils.getFileType(name); + if ("zip".equals(type)) { + String hash = Utils.hash(match.toAbsolutePath().toString()); + String download = Utils.getenv("SERVING_DOWNLOAD_DIR", null); + Path parent = download == null ? Utils.getCacheDir() : Paths.get(download); + parent = parent.resolve("download"); + Path extracted = parent.resolve(hash); + if (Files.exists(extracted)) { + logger.info("archive already extracted: {}", extracted); + } else { + Files.createDirectories(parent); + Path tmp = Files.createTempDirectory(parent, "tmp"); + try (InputStream is = Files.newInputStream(match)) { + ZipUtils.unzip(is, tmp); + Utils.moveQuietly(tmp, extracted); + logger.info("Archive file extracted to {}", extracted); + } finally { + Utils.deleteQuietly(tmp); + } + } + modelDir = extracted; + resolvedModelUrl = modelDir.toUri().toURL().toString(); + } + } + } } private void loadServingProperties() {