Skip to content

Commit 64bc965

Browse files
committed
Support deploying a .zip file from an S3 bucket
1 parent 2a4a59f commit 64bc965

File tree

3 files changed

+63
-28
lines changed

3 files changed

+63
-28
lines changed

engines/python/setup/djl_python/ts_service_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def invoke_handler(self, function_name, inputs):
3939
for k, v in inputs.get_properties().items():
4040
header = dict()
4141
header["name"] = k.encode("utf-8")
42-
header["value"] = v.encode("utf-8")
42+
header["value"] = str(v).encode("utf-8")
4343
request["headers"].append(header)
4444

4545
content = inputs.get_content()

serving/src/main/java/ai/djl/serving/util/ModelStore.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import ai.djl.modality.Input;
1616
import ai.djl.modality.Output;
1717
import ai.djl.repository.FilenameUtils;
18-
import ai.djl.serving.ModelServer;
1918
import ai.djl.serving.models.ModelManager;
2019
import ai.djl.serving.wlm.ModelInfo;
2120
import ai.djl.serving.workflow.BadWorkflowException;
@@ -49,7 +48,7 @@
4948
/** A class that represents the model server's model store. */
5049
public final class ModelStore {
5150

52-
private static final Logger logger = LoggerFactory.getLogger(ModelServer.class);
51+
private static final Logger logger = LoggerFactory.getLogger(ModelStore.class);
5352
private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[?([^?]+?)]?=)?(.+)");
5453

5554
private static final ModelStore INSTANCE = new ModelStore();
@@ -100,13 +99,7 @@ public void initialize() throws IOException, BadWorkflowException {
10099
// contains only directory or archive files
101100
boolean isMultiModelsDirectory;
102101
try (Stream<Path> stream = Files.list(modelStore)) {
103-
isMultiModelsDirectory =
104-
stream.filter(p -> !p.getFileName().toString().startsWith("."))
105-
.allMatch(
106-
p ->
107-
Files.isDirectory(p)
108-
|| FilenameUtils.isArchiveFile(
109-
p.toString()));
102+
isMultiModelsDirectory = stream.allMatch(ModelStore::isModel);
110103
}
111104

112105
if (isMultiModelsDirectory) {
@@ -208,10 +201,7 @@ public List<Workflow> getWorkflows() {
208201
*/
209202
public static String mapModelUrl(Path path) {
210203
try {
211-
if (!Files.exists(path)
212-
|| Files.isHidden(path)
213-
|| (!Files.isDirectory(path)
214-
&& !FilenameUtils.isArchiveFile(path.toString()))) {
204+
if (!isModel(path)) {
215205
return null;
216206
}
217207
try (Stream<Path> stream = Files.list(path)) {
@@ -233,6 +223,15 @@ public static String mapModelUrl(Path path) {
233223
}
234224
}
235225

226+
private static boolean isModel(Path path) {
227+
String fileName = Objects.requireNonNull(path.getFileName()).toString();
228+
if (fileName.startsWith(".")) {
229+
return false;
230+
}
231+
return Files.exists(path)
232+
&& (Files.isDirectory(path) || FilenameUtils.isArchiveFile(fileName));
233+
}
234+
236235
private String createHuggingFaceModel(String modelId) throws IOException {
237236
if (modelId.startsWith("djl://") || modelId.startsWith("s3://")) {
238237
return modelId;

wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import ai.djl.translate.TranslateException;
4141
import ai.djl.util.NeuronUtils;
4242
import ai.djl.util.Utils;
43+
import ai.djl.util.ZipUtils;
4344
import ai.djl.util.cuda.CudaUtils;
4445

4546
import org.slf4j.Logger;
@@ -65,6 +66,7 @@
6566
import java.util.concurrent.ConcurrentHashMap;
6667
import java.util.regex.Matcher;
6768
import java.util.regex.Pattern;
69+
import java.util.stream.Collectors;
6870
import java.util.stream.Stream;
6971

7072
/** A class that represents a loaded model and its metadata. */
@@ -764,6 +766,8 @@ private String inferEngine() throws ModelException {
764766
String groupId = mrl.getGroupId();
765767
ModelZoo zoo = ModelZoo.getModelZoo(groupId);
766768
return zoo.getSupportedEngines().iterator().next();
769+
} else if (isTorchServeModel()) {
770+
return "Python";
767771
} else if (Files.isRegularFile(modelDir.resolve(prefix + ".pt"))
768772
|| Files.isRegularFile(modelDir.resolve("model.pt"))) {
769773
return "PyTorch";
@@ -781,8 +785,6 @@ private String inferEngine() throws ModelException {
781785
|| Files.isRegularFile(modelDir.resolve("model.bst"))
782786
|| Files.isRegularFile(modelDir.resolve("model.xgb"))) {
783787
return "XGBoost";
784-
} else if (Files.isRegularFile(modelDir.resolve(prefix + ".gguf"))) {
785-
return "Llama";
786788
} else if (isPythonModel(prefix)) {
787789
// TODO: How to differentiate Rust model from Python
788790
return "Python";
@@ -811,26 +813,60 @@ private boolean isPythonModel(String prefix) {
811813
return Files.isRegularFile(modelDir.resolve("model.py"))
812814
|| Files.isRegularFile(modelDir.resolve(prefix + ".py"))
813815
|| prop.getProperty("option.model_id") != null
814-
|| Files.isRegularFile(modelDir.resolve("config.json"))
815-
|| isTorchServeModel();
816+
|| Files.isRegularFile(modelDir.resolve("config.json"));
816817
}
817818

818819
private void downloadModel() throws ModelException, IOException {
819820
if (resolvedModelUrl.startsWith("s3://")) {
820821
modelDir = downloadS3ToDownloadDir(resolvedModelUrl);
821822
resolvedModelUrl = modelDir.toUri().toURL().toString();
822-
return;
823-
}
824-
Repository repository = Repository.newInstance("modelStore", resolvedModelUrl);
825-
List<MRL> mrls = repository.getResources();
826-
if (mrls.isEmpty()) {
827-
throw new ModelNotFoundException("Invalid model url: " + resolvedModelUrl);
823+
} else {
824+
Repository repository = Repository.newInstance("modelStore", resolvedModelUrl);
825+
List<MRL> mrls = repository.getResources();
826+
if (mrls.isEmpty()) {
827+
throw new ModelNotFoundException("Invalid model url: " + resolvedModelUrl);
828+
}
829+
830+
Artifact artifact = mrls.get(0).getDefaultArtifact();
831+
repository.prepare(artifact);
832+
modelDir = Utils.getNestedModelDir(repository.getResourceDirectory(artifact));
833+
artifactName = artifact.getName();
834+
if (Files.isRegularFile(modelDir)) {
835+
modelDir = modelDir.getParent();
836+
return;
837+
}
828838
}
829839

830-
Artifact artifact = mrls.get(0).getDefaultArtifact();
831-
repository.prepare(artifact);
832-
modelDir = Utils.getNestedModelDir(repository.getResourceDirectory(artifact));
833-
artifactName = artifact.getName();
840+
try (Stream<Path> stream = Files.list(modelDir)) {
841+
List<Path> list = stream.collect(Collectors.toList());
842+
if (list.size() == 1) {
843+
Path match = list.get(0);
844+
String name = Objects.requireNonNull(match.getFileName()).toString();
845+
String type = FilenameUtils.getFileType(name);
846+
if ("zip".equals(type)) {
847+
String hash = Utils.hash(match.toAbsolutePath().toString());
848+
String download = Utils.getenv("SERVING_DOWNLOAD_DIR", null);
849+
Path parent = download == null ? Utils.getCacheDir() : Paths.get(download);
850+
parent = parent.resolve("download");
851+
Path extracted = parent.resolve(hash);
852+
if (Files.exists(extracted)) {
853+
logger.info("archive already extracted: {}", extracted);
854+
} else {
855+
Files.createDirectories(parent);
856+
Path tmp = Files.createTempDirectory(parent, "tmp");
857+
try (InputStream is = Files.newInputStream(match)) {
858+
ZipUtils.unzip(is, tmp);
859+
Utils.moveQuietly(tmp, extracted);
860+
logger.info("Archive file extracted to {}", extracted);
861+
} finally {
862+
Utils.deleteQuietly(tmp);
863+
}
864+
}
865+
modelDir = extracted;
866+
resolvedModelUrl = modelDir.toUri().toURL().toString();
867+
}
868+
}
869+
}
834870
}
835871

836872
private void loadServingProperties() {

0 commit comments

Comments
 (0)