Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/ts_service_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def invoke_handler(self, function_name, inputs):
for k, v in inputs.get_properties().items():
header = dict()
header["name"] = k.encode("utf-8")
header["value"] = v.encode("utf-8")
header["value"] = str(v).encode("utf-8")
request["headers"].append(header)

content = inputs.get_content()
Expand Down
25 changes: 12 additions & 13 deletions serving/src/main/java/ai/djl/serving/util/ModelStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.FilenameUtils;
import ai.djl.serving.ModelServer;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.workflow.BadWorkflowException;
Expand Down Expand Up @@ -49,7 +48,7 @@
/** A class that represents the model server's model store. */
public final class ModelStore {

private static final Logger logger = LoggerFactory.getLogger(ModelServer.class);
private static final Logger logger = LoggerFactory.getLogger(ModelStore.class);
private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[?([^?]+?)]?=)?(.+)");

private static final ModelStore INSTANCE = new ModelStore();
Expand Down Expand Up @@ -100,13 +99,7 @@ public void initialize() throws IOException, BadWorkflowException {
// contains only directory or archive files
boolean isMultiModelsDirectory;
try (Stream<Path> stream = Files.list(modelStore)) {
isMultiModelsDirectory =
stream.filter(p -> !p.getFileName().toString().startsWith("."))
.allMatch(
p ->
Files.isDirectory(p)
|| FilenameUtils.isArchiveFile(
p.toString()));
isMultiModelsDirectory = stream.allMatch(ModelStore::isModel);
}

if (isMultiModelsDirectory) {
Expand Down Expand Up @@ -208,10 +201,7 @@ public List<Workflow> getWorkflows() {
*/
public static String mapModelUrl(Path path) {
try {
if (!Files.exists(path)
|| Files.isHidden(path)
|| (!Files.isDirectory(path)
&& !FilenameUtils.isArchiveFile(path.toString()))) {
if (!isModel(path)) {
return null;
}
try (Stream<Path> stream = Files.list(path)) {
Expand All @@ -233,6 +223,15 @@ public static String mapModelUrl(Path path) {
}
}

/**
 * Checks whether a path qualifies as a model entry.
 *
 * <p>A path qualifies when its file name does not start with a dot, the path exists, and it
 * is either a directory or a file that {@code FilenameUtils.isArchiveFile} accepts.
 *
 * @param path the path to inspect
 * @return {@code true} if the path is a non-dot-prefixed existing directory or archive file
 * @throws NullPointerException if the path has no file name component
 */
private static boolean isModel(Path path) {
    Path leaf = Objects.requireNonNull(path.getFileName());
    String name = leaf.toString();

    // Dot-prefixed entries are never treated as models.
    boolean dotPrefixed = name.startsWith(".");
    if (dotPrefixed) {
        return false;
    }

    if (!Files.exists(path)) {
        return false;
    }

    // The archive check looks at the bare file name, not the full path.
    if (Files.isDirectory(path)) {
        return true;
    }
    return FilenameUtils.isArchiveFile(name);
}

private String createHuggingFaceModel(String modelId) throws IOException {
if (modelId.startsWith("djl://") || modelId.startsWith("s3://")) {
return modelId;
Expand Down
64 changes: 50 additions & 14 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import ai.djl.translate.TranslateException;
import ai.djl.util.NeuronUtils;
import ai.djl.util.Utils;
import ai.djl.util.ZipUtils;
import ai.djl.util.cuda.CudaUtils;

import org.slf4j.Logger;
Expand All @@ -65,6 +66,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** A class that represents a loaded model and its metadata. */
Expand Down Expand Up @@ -764,6 +766,8 @@ private String inferEngine() throws ModelException {
String groupId = mrl.getGroupId();
ModelZoo zoo = ModelZoo.getModelZoo(groupId);
return zoo.getSupportedEngines().iterator().next();
} else if (isTorchServeModel()) {
return "Python";
} else if (Files.isRegularFile(modelDir.resolve(prefix + ".pt"))
|| Files.isRegularFile(modelDir.resolve("model.pt"))) {
return "PyTorch";
Expand All @@ -781,8 +785,6 @@ private String inferEngine() throws ModelException {
|| Files.isRegularFile(modelDir.resolve("model.bst"))
|| Files.isRegularFile(modelDir.resolve("model.xgb"))) {
return "XGBoost";
} else if (Files.isRegularFile(modelDir.resolve(prefix + ".gguf"))) {
return "Llama";
} else if (isPythonModel(prefix)) {
// TODO: How to differentiate Rust model from Python
return "Python";
Expand Down Expand Up @@ -811,26 +813,60 @@ private boolean isPythonModel(String prefix) {
return Files.isRegularFile(modelDir.resolve("model.py"))
|| Files.isRegularFile(modelDir.resolve(prefix + ".py"))
|| prop.getProperty("option.model_id") != null
|| Files.isRegularFile(modelDir.resolve("config.json"))
|| isTorchServeModel();
|| Files.isRegularFile(modelDir.resolve("config.json"));
}

private void downloadModel() throws ModelException, IOException {
if (resolvedModelUrl.startsWith("s3://")) {
modelDir = downloadS3ToDownloadDir(resolvedModelUrl);
resolvedModelUrl = modelDir.toUri().toURL().toString();
return;
}
Repository repository = Repository.newInstance("modelStore", resolvedModelUrl);
List<MRL> mrls = repository.getResources();
if (mrls.isEmpty()) {
throw new ModelNotFoundException("Invalid model url: " + resolvedModelUrl);
} else {
Repository repository = Repository.newInstance("modelStore", resolvedModelUrl);
List<MRL> mrls = repository.getResources();
if (mrls.isEmpty()) {
throw new ModelNotFoundException("Invalid model url: " + resolvedModelUrl);
}

Artifact artifact = mrls.get(0).getDefaultArtifact();
repository.prepare(artifact);
modelDir = Utils.getNestedModelDir(repository.getResourceDirectory(artifact));
artifactName = artifact.getName();
if (Files.isRegularFile(modelDir)) {
modelDir = modelDir.getParent();
return;
}
}

Artifact artifact = mrls.get(0).getDefaultArtifact();
repository.prepare(artifact);
modelDir = Utils.getNestedModelDir(repository.getResourceDirectory(artifact));
artifactName = artifact.getName();
try (Stream<Path> stream = Files.list(modelDir)) {
List<Path> list = stream.collect(Collectors.toList());
if (list.size() == 1) {
Path match = list.get(0);
String name = Objects.requireNonNull(match.getFileName()).toString();
String type = FilenameUtils.getFileType(name);
if ("zip".equals(type)) {
String hash = Utils.hash(match.toAbsolutePath().toString());
String download = Utils.getenv("SERVING_DOWNLOAD_DIR", null);
Path parent = download == null ? Utils.getCacheDir() : Paths.get(download);
parent = parent.resolve("download");
Path extracted = parent.resolve(hash);
if (Files.exists(extracted)) {
logger.info("archive already extracted: {}", extracted);
} else {
Files.createDirectories(parent);
Path tmp = Files.createTempDirectory(parent, "tmp");
try (InputStream is = Files.newInputStream(match)) {
ZipUtils.unzip(is, tmp);
Utils.moveQuietly(tmp, extracted);
logger.info("Archive file extracted to {}", extracted);
} finally {
Utils.deleteQuietly(tmp);
}
}
modelDir = extracted;
resolvedModelUrl = modelDir.toUri().toURL().toString();
}
}
}
}

private void loadServingProperties() {
Expand Down
Loading