4040import ai .djl .translate .TranslateException ;
4141import ai .djl .util .NeuronUtils ;
4242import ai .djl .util .Utils ;
43+ import ai .djl .util .ZipUtils ;
4344import ai .djl .util .cuda .CudaUtils ;
4445
4546import org .slf4j .Logger ;
6566import java .util .concurrent .ConcurrentHashMap ;
6667import java .util .regex .Matcher ;
6768import java .util .regex .Pattern ;
69+ import java .util .stream .Collectors ;
6870import java .util .stream .Stream ;
6971
7072/** A class represent a loaded model and it's metadata. */
@@ -764,6 +766,8 @@ private String inferEngine() throws ModelException {
764766 String groupId = mrl .getGroupId ();
765767 ModelZoo zoo = ModelZoo .getModelZoo (groupId );
766768 return zoo .getSupportedEngines ().iterator ().next ();
769+ } else if (isTorchServeModel ()) {
770+ return "Python" ;
767771 } else if (Files .isRegularFile (modelDir .resolve (prefix + ".pt" ))
768772 || Files .isRegularFile (modelDir .resolve ("model.pt" ))) {
769773 return "PyTorch" ;
@@ -781,8 +785,6 @@ private String inferEngine() throws ModelException {
781785 || Files .isRegularFile (modelDir .resolve ("model.bst" ))
782786 || Files .isRegularFile (modelDir .resolve ("model.xgb" ))) {
783787 return "XGBoost" ;
784- } else if (Files .isRegularFile (modelDir .resolve (prefix + ".gguf" ))) {
785- return "Llama" ;
786788 } else if (isPythonModel (prefix )) {
787789 // TODO: How to differentiate Rust model from Python
788790 return "Python" ;
@@ -811,26 +813,60 @@ private boolean isPythonModel(String prefix) {
811813 return Files .isRegularFile (modelDir .resolve ("model.py" ))
812814 || Files .isRegularFile (modelDir .resolve (prefix + ".py" ))
813815 || prop .getProperty ("option.model_id" ) != null
814- || Files .isRegularFile (modelDir .resolve ("config.json" ))
815- || isTorchServeModel ();
816+ || Files .isRegularFile (modelDir .resolve ("config.json" ));
816817 }
817818
818819 private void downloadModel () throws ModelException , IOException {
819820 if (resolvedModelUrl .startsWith ("s3://" )) {
820821 modelDir = downloadS3ToDownloadDir (resolvedModelUrl );
821822 resolvedModelUrl = modelDir .toUri ().toURL ().toString ();
822- return ;
823- }
824- Repository repository = Repository .newInstance ("modelStore" , resolvedModelUrl );
825- List <MRL > mrls = repository .getResources ();
826- if (mrls .isEmpty ()) {
827- throw new ModelNotFoundException ("Invalid model url: " + resolvedModelUrl );
823+ } else {
824+ Repository repository = Repository .newInstance ("modelStore" , resolvedModelUrl );
825+ List <MRL > mrls = repository .getResources ();
826+ if (mrls .isEmpty ()) {
827+ throw new ModelNotFoundException ("Invalid model url: " + resolvedModelUrl );
828+ }
829+
830+ Artifact artifact = mrls .get (0 ).getDefaultArtifact ();
831+ repository .prepare (artifact );
832+ modelDir = Utils .getNestedModelDir (repository .getResourceDirectory (artifact ));
833+ artifactName = artifact .getName ();
834+ if (Files .isRegularFile (modelDir )) {
835+ modelDir = modelDir .getParent ();
836+ return ;
837+ }
828838 }
829839
830- Artifact artifact = mrls .get (0 ).getDefaultArtifact ();
831- repository .prepare (artifact );
832- modelDir = Utils .getNestedModelDir (repository .getResourceDirectory (artifact ));
833- artifactName = artifact .getName ();
840+ try (Stream <Path > stream = Files .list (modelDir )) {
841+ List <Path > list = stream .collect (Collectors .toList ());
842+ if (list .size () == 1 ) {
843+ Path match = list .get (0 );
844+ String name = Objects .requireNonNull (match .getFileName ()).toString ();
845+ String type = FilenameUtils .getFileType (name );
846+ if ("zip" .equals (type )) {
847+ String hash = Utils .hash (match .toAbsolutePath ().toString ());
848+ String download = Utils .getenv ("SERVING_DOWNLOAD_DIR" , null );
849+ Path parent = download == null ? Utils .getCacheDir () : Paths .get (download );
850+ parent = parent .resolve ("download" );
851+ Path extracted = parent .resolve (hash );
852+ if (Files .exists (extracted )) {
853+ logger .info ("archive already extracted: {}" , extracted );
854+ } else {
855+ Files .createDirectories (parent );
856+ Path tmp = Files .createTempDirectory (parent , "tmp" );
857+ try (InputStream is = Files .newInputStream (match )) {
858+ ZipUtils .unzip (is , tmp );
859+ Utils .moveQuietly (tmp , extracted );
860+ logger .info ("Archive file extracted to {}" , extracted );
861+ } finally {
862+ Utils .deleteQuietly (tmp );
863+ }
864+ }
865+ modelDir = extracted ;
866+ resolvedModelUrl = modelDir .toUri ().toURL ().toString ();
867+ }
868+ }
869+ }
834870 }
835871
836872 private void loadServingProperties () {
0 commit comments