From 71a4e1e9db61541aa2e4e6b7078b00823cb4b87f Mon Sep 17 00:00:00 2001 From: Adrian Zuur Date: Mon, 5 Feb 2024 08:50:27 -0500 Subject: [PATCH 1/2] add code to serve --- ml_pipelines/deployment/aws/serve.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/ml_pipelines/deployment/aws/serve.py b/ml_pipelines/deployment/aws/serve.py index e0f54aa..c9d5ae1 100644 --- a/ml_pipelines/deployment/aws/serve.py +++ b/ml_pipelines/deployment/aws/serve.py @@ -17,15 +17,14 @@ def main( - train_artifacts_root_path: Union[str, None] = None, # noqa: UP007 + train_artifacts_bucket: Union[str, None] = None, # noqa: UP007 train_version: Union[str, None] = None, # noqa: UP007 ): """ Serves a model in the /predict/ endpoint of a FastAPI app. - If `train_artifacts_root_path` is null, the command searches for the - TRAIN_ARTIFACTS_ROOT_PATH environment variable, and if not present, - assumes this to be "/". + If `train_artifacts_bucket` is null, the command searches for the + TRAIN_ARTIFACTS_BUCKET environment variable. If `train_version` is null, the command loads the model tagged as 'best_version' in the `train_artifacts_root_path`, and if not found, loads the latest model. @@ -33,21 +32,21 @@ def main( logger = Logger(__file__) logger.addHandler(logging.StreamHandler(sys.stdout)) - if train_artifacts_root_path is None: - train_artifacts_root_path = os.environ.get("TRAIN_ARTIFACTS_ROOT_PATH", "/") + if train_artifacts_bucket is None: + train_artifacts_bucket = os.environ["TRAIN_ARTIFACTS_BUCKET"] if train_version is None: - train_version = get_best_version(train_artifacts_root_path) # type: ignore + train_version = get_best_version(train_artifacts_bucket) # type: ignore if train_version is None: train_version = get_latest_version( - train_artifacts_root_path, # type: ignore + train_artifacts_bucket, # type: ignore "model.pickle", ) get_train_artifacts_func = partial( get_train_artifacts, - train_artifacts_root_path, # type: ignore + train_artifacts_bucket, # type: ignore load_data=False, ) From acbebd43a9c43504df372c8a6fa4235b03d18739 Mon Sep 17 00:00:00 2001 From: Adrian Zuur Date: Mon, 5 Feb 2024 08:52:36 -0500 Subject: [PATCH 2/2] update import --- ml_pipelines/deployment/aws/serve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml_pipelines/deployment/aws/serve.py b/ml_pipelines/deployment/aws/serve.py index c9d5ae1..49e32b4 100644 --- a/ml_pipelines/deployment/aws/serve.py +++ b/ml_pipelines/deployment/aws/serve.py @@ -7,13 +7,13 @@ import typer -from ml_pipelines.deployment.common.serve import run_serve -from ml_pipelines.deployment.local.io import ( +from ml_pipelines.deployment.aws.io import ( get_best_version, get_latest_version, get_train_artifacts, prediction_logging_func, ) +from ml_pipelines.deployment.common.serve import run_serve def main(