|
35 | 35 | from huggingface_hub import snapshot_download # noqa
|
36 | 36 | from datasets import load_dataset # noqa
|
37 | 37 | from app import __version__ # noqa
|
| 38 | +from app.config import Settings # noqa |
38 | 39 | from app.domain import ModelType, TrainingType, BuildBackend, Device, ArchiveFormat, LlmEngine # noqa
|
39 | 40 | from app.registry import model_service_registry # noqa
|
40 | 41 | from app.api.api import (
|
|
44 | 45 | get_vllm_server,
|
45 | 46 | get_app_for_api_docs,
|
46 | 47 | ) # noqa
|
47 |
| -from app.utils import get_settings, send_gelf_message, download_model_package # noqa |
| 48 | +from app.utils import get_settings, send_gelf_message, download_model_package, get_model_data_package_base_name # noqa |
48 | 49 | from app.management.model_manager import ModelManager # noqa
|
49 | 50 | from app.api.dependencies import ModelServiceDep, ModelManagerDep # noqa
|
50 | 51 | from app.management.tracker_client import TrackerClient # noqa
|
@@ -113,10 +114,7 @@ def serve_model(
|
113 | 114 | model_service_dep = ModelServiceDep(model_type, config, model_name)
|
114 | 115 | cms_globals.model_service_dep = model_service_dep
|
115 | 116 |
|
116 |
| - dst_model_path = os.path.join(parent_dir, "model", "model.zip" if model_path.endswith(".zip") else "model.tar.gz") |
117 |
| - config.BASE_MODEL_FILE = "model.zip" if model_path.endswith(".zip") else "model.tar.gz" |
118 |
| - if dst_model_path and os.path.exists(os.path.splitext(dst_model_path)[0]): |
119 |
| - shutil.rmtree(os.path.splitext(dst_model_path)[0]) |
| 117 | + dst_model_path = _ensure_dst_model_path(model_path, parent_dir, config) |
120 | 118 |
|
121 | 119 | if model_path:
|
122 | 120 | if model_path.startswith("http://") or model_path.startswith("https://"):
|
@@ -221,15 +219,13 @@ def train_model(
|
221 | 219 | model_service_dep = ModelServiceDep(model_type, config)
|
222 | 220 | cms_globals.model_service_dep = model_service_dep
|
223 | 221 |
|
224 |
| - dst_model_path = os.path.join(parent_dir, "model", "model.zip" if base_model_path.endswith(".zip") else "model.tar.gz") |
225 |
| - config.BASE_MODEL_FILE = "model.zip" if base_model_path.endswith(".zip") else "model.tar.gz" |
226 |
| - if dst_model_path and os.path.exists(os.path.splitext(dst_model_path)[0]): |
227 |
| - shutil.rmtree(os.path.splitext(dst_model_path)[0]) |
| 222 | + dst_model_path = _ensure_dst_model_path(base_model_path, parent_dir, config) |
228 | 223 |
|
229 | 224 | if base_model_path:
|
230 | 225 | try:
|
231 | 226 | shutil.copy2(base_model_path, dst_model_path)
|
232 | 227 | except shutil.SameFileError:
|
| 228 | + logger.warning("Source and destination are the same model package file.") |
233 | 229 | pass
|
234 | 230 | model_service = model_service_dep()
|
235 | 231 | model_service.model_name = model_name if model_name is not None else "CMS model"
|
@@ -708,6 +704,23 @@ def show_banner() -> None:
|
708 | 704 | typer.echo(banner)
|
709 | 705 |
|
710 | 706 |
|
| 707 | +def _ensure_dst_model_path(model_path: str, parent_dir: str, config: Settings) -> str: |
| 708 | + if model_path.endswith(".zip"): |
| 709 | + dst_model_path = os.path.join(parent_dir, "model", "model.zip") |
| 710 | + config.BASE_MODEL_FILE = "model.zip" |
| 711 | + else: |
| 712 | + dst_model_path = os.path.join(parent_dir, "model", "model.tar.gz") |
| 713 | + config.BASE_MODEL_FILE = "model.tar.gz" |
| 714 | + model_dir = os.path.join(parent_dir, "model", "model") |
| 715 | + if os.path.exists(model_dir): |
| 716 | + shutil.rmtree(model_dir) |
| 717 | + if dst_model_path.endswith(".zip") and os.path.exists(dst_model_path.replace(".zip", ".tar.gz")): |
| 718 | + os.remove(dst_model_path.replace(".zip", ".tar.gz")) |
| 719 | + if dst_model_path.endswith(".tar.gz") and os.path.exists(dst_model_path.replace(".tar.gz", ".zip")): |
| 720 | + os.remove(dst_model_path.replace(".tar.gz", ".zip")) |
| 721 | + return dst_model_path |
| 722 | + |
| 723 | + |
711 | 724 | def _get_logger(
|
712 | 725 | debug: Optional[bool] = None,
|
713 | 726 | model_type: Optional[ModelType] = None,
|
|
0 commit comments