diff --git a/ml_pipelines/deployment/local/serve.py b/ml_pipelines/deployment/local/serve.py index 39a19c2..244d60d 100644 --- a/ml_pipelines/deployment/local/serve.py +++ b/ml_pipelines/deployment/local/serve.py @@ -7,7 +7,10 @@ import typer import uvicorn -from ml_pipelines.deployment.local.common import get_latest_version, get_train_artifacts +from ml_pipelines.deployment.local.common import ( + get_best_version, + get_train_artifacts, +) from ml_pipelines.logic.serve.serve import Point, create_fastapi_app @@ -48,10 +51,11 @@ def main( logger.addHandler(logging.StreamHandler(sys.stdout)) if train_version is None: - train_version = get_latest_version( - train_artifacts_root_path, # type: ignore - "model.pickle", - ) + train_version = get_best_version(train_artifacts_root_path) # type: ignore + # train_version = get_latest_version( + # train_artifacts_root_path, # type: ignore + # "model.pickle", + # ) uvicorn_kwargs: dict = {} run_serve( # noqa: PLR0913