Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add code to serve #15

Merged
merged 2 commits into from
Feb 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions ml_pipelines/deployment/aws/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,46 @@

import typer

from ml_pipelines.deployment.common.serve import run_serve
from ml_pipelines.deployment.local.io import (
from ml_pipelines.deployment.aws.io import (
get_best_version,
get_latest_version,
get_train_artifacts,
prediction_logging_func,
)
from ml_pipelines.deployment.common.serve import run_serve


def main(
train_artifacts_root_path: Union[str, None] = None, # noqa: UP007
train_artifacts_bucket: Union[str, None] = None, # noqa: UP007
train_version: Union[str, None] = None, # noqa: UP007
):
"""
Serves a model in the /predict/ endpoint of a FastAPI app.

If `train_artifacts_root_path` is null, the command searches for the
TRAIN_ARTIFACTS_ROOT_PATH environment variable, and if not present,
assumes this to be "/".
If `train_artifacts_bucket` is null, the command searches for the
TRAIN_ARTIFACTS_BUCKET environment variable.

If `train_version` is null, the command loads the model tagged as 'best_version'
in the `train_artifacts_root_path`, and if not found, loads the latest model.
"""
logger = Logger(__file__)
logger.addHandler(logging.StreamHandler(sys.stdout))

if train_artifacts_root_path is None:
train_artifacts_root_path = os.environ.get("TRAIN_ARTIFACTS_ROOT_PATH", "/")
if train_artifacts_bucket is None:
train_artifacts_bucket = os.environ["TRAIN_ARTIFACTS_BUCKET"]

if train_version is None:
train_version = get_best_version(train_artifacts_root_path) # type: ignore
train_version = get_best_version(train_artifacts_bucket) # type: ignore

if train_version is None:
train_version = get_latest_version(
train_artifacts_root_path, # type: ignore
train_artifacts_bucket, # type: ignore
"model.pickle",
)

get_train_artifacts_func = partial(
get_train_artifacts,
train_artifacts_root_path, # type: ignore
train_artifacts_bucket, # type: ignore
load_data=False,
)

Expand Down
Loading