Skip to content

Commit

Permalink
add code to serve (#15)
Browse files Browse the repository at this point in the history
* add code to serve

* update import
  • Loading branch information
azuur authored Feb 5, 2024
1 parent f2063cc commit 5f9d5e5
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions ml_pipelines/deployment/aws/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,46 @@

import typer

from ml_pipelines.deployment.common.serve import run_serve
from ml_pipelines.deployment.local.io import (
from ml_pipelines.deployment.aws.io import (
get_best_version,
get_latest_version,
get_train_artifacts,
prediction_logging_func,
)
from ml_pipelines.deployment.common.serve import run_serve


def main(
train_artifacts_root_path: Union[str, None] = None, # noqa: UP007
train_artifacts_bucket: Union[str, None] = None, # noqa: UP007
train_version: Union[str, None] = None, # noqa: UP007
):
"""
Serves a model in the /predict/ endpoint of a FastAPI app.
If `train_artifacts_root_path` is null, the command searches for the
TRAIN_ARTIFACTS_ROOT_PATH environment variable, and if not present,
assumes this to be "/".
If `train_artifacts_bucket` is null, the command searches for the
TRAIN_ARTIFACTS_BUCKET environment variable.
If `train_version` is null, the command loads the model tagged as 'best_version'
in the `train_artifacts_root_path`, and if not found, loads the latest model.
"""
logger = Logger(__file__)
logger.addHandler(logging.StreamHandler(sys.stdout))

if train_artifacts_root_path is None:
train_artifacts_root_path = os.environ.get("TRAIN_ARTIFACTS_ROOT_PATH", "/")
if train_artifacts_bucket is None:
train_artifacts_bucket = os.environ["TRAIN_ARTIFACTS_BUCKET"]

if train_version is None:
train_version = get_best_version(train_artifacts_root_path) # type: ignore
train_version = get_best_version(train_artifacts_bucket) # type: ignore

if train_version is None:
train_version = get_latest_version(
train_artifacts_root_path, # type: ignore
train_artifacts_bucket, # type: ignore
"model.pickle",
)

get_train_artifacts_func = partial(
get_train_artifacts,
train_artifacts_root_path, # type: ignore
train_artifacts_bucket, # type: ignore
load_data=False,
)

Expand Down

0 comments on commit 5f9d5e5

Please sign in to comment.