-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* rename common to io * refactor deployment for modularity * WIP aws train * fix mistakes with partial + test train * Update all except serve
- Loading branch information
Showing
18 changed files
with
698 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import logging | ||
import os | ||
import sys | ||
from functools import partial | ||
from logging import Logger | ||
from typing import Union | ||
|
||
import typer | ||
|
||
from ml_pipelines.deployment.aws.io import ( | ||
get_all_available_train_versions, | ||
get_latest_version, | ||
get_raw_data, | ||
get_train_artifacts, | ||
tag_best_version, | ||
) | ||
from ml_pipelines.deployment.common.eval import run_eval_comparison_pipeline | ||
|
||
|
||
def main( | ||
raw_data_bucket: Union[str, None] = None, # noqa: UP007 | ||
train_artifacts_bucket: Union[str, None] = None, # noqa: UP007, E501 | ||
raw_data_version: Union[str, None] = None, # noqa: UP007 | ||
train_versions: Union[list[str], None] = None, # noqa: UP007 | ||
tag_best_model: bool = False, | ||
): | ||
""" | ||
Runs the model evaluation and comparison pipeline using AWS S3 buckets | ||
for inputs and outputs. | ||
If `raw_data_bucket` is null, the command searches for the RAW_DATA_BUCKET | ||
environment variable. | ||
If `train_artifacts_bucket` is null, the command searches for the | ||
TRAIN_ARTIFACTS_BUCKET environment variable. | ||
If `raw_data_version` is null, the command searches for the latest version in | ||
`raw_data_bucket`. | ||
If `train_versions` is null or empty, the command automatically evaluates all | ||
models found in `train_artifacts_bucket`. | ||
If `tag_best_model` is set (to true) and more than one model version is evaluated, | ||
the best performing one is tagged as the best version. | ||
""" | ||
logger = Logger(__file__) | ||
logger.addHandler(logging.StreamHandler(sys.stdout)) | ||
|
||
if raw_data_bucket is None: | ||
raw_data_bucket = os.environ["RAW_DATA_BUCKET"] | ||
if train_artifacts_bucket is None: | ||
train_artifacts_bucket = os.environ["TRAIN_ARTIFACTS_BUCKET"] | ||
|
||
if raw_data_version is None: | ||
raw_data_version = get_latest_version( | ||
raw_data_bucket, # type: ignore | ||
"raw_data.csv", | ||
) | ||
|
||
if not train_versions: | ||
train_versions = get_all_available_train_versions( | ||
train_artifacts_bucket # type: ignore | ||
) | ||
|
||
get_raw_data_func = partial( | ||
get_raw_data, | ||
raw_data_bucket, # type: ignore | ||
) | ||
get_train_artifacts_func = partial( | ||
get_train_artifacts, | ||
train_artifacts_bucket, # type: ignore | ||
load_data=False, | ||
) | ||
tag_best_version_func = partial( | ||
tag_best_version, | ||
train_artifacts_bucket, # type: ignore | ||
) | ||
|
||
run_eval_comparison_pipeline( # noqa: PLR0913 | ||
raw_data_version=raw_data_version, | ||
train_versions=train_versions, # type: ignore | ||
get_raw_data_func=get_raw_data_func, | ||
get_train_artifacts_func=get_train_artifacts_func, | ||
tag_best_model_func=tag_best_version_func, | ||
tag_best_model=tag_best_model, | ||
logger=logger, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
typer.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
import json | ||
import pickle | ||
from io import BytesIO, StringIO | ||
from tempfile import NamedTemporaryFile | ||
from uuid import uuid4 | ||
|
||
import boto3 | ||
import pandas as pd | ||
from botocore.exceptions import ClientError | ||
from matplotlib.figure import Figure | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
from ml_pipelines.logic.common.feature_eng import FeatureEngineeringParams | ||
from ml_pipelines.logic.serve.serve import Point | ||
from ml_pipelines.pipeline.train_pipeline import TrainArtifacts | ||
|
||
|
||
def make_version(prefix: str) -> str: | ||
version = str(uuid4()) | ||
return f"{prefix}_{version}" | ||
|
||
|
||
def get_raw_data(bucket_name: str, version: str) -> pd.DataFrame: | ||
s3 = boto3.client("s3") | ||
s3_key = f"{version}/raw_data.csv" | ||
|
||
response = s3.get_object(Bucket=bucket_name, Key=s3_key) | ||
content = response["Body"].read().decode("utf-8") | ||
|
||
return pd.read_csv(StringIO(content)) | ||
|
||
|
||
class ModelVersionAlreadyExists(Exception): | ||
pass | ||
|
||
|
||
def save_train_artifacts( | ||
bucket_name: str, | ||
version: str, | ||
artifacts: TrainArtifacts, | ||
): | ||
s3 = boto3.client("s3") | ||
|
||
# Create a directory structure within the bucket for the version | ||
s3_prefix = f"{version}/" | ||
|
||
# Check if the version already exists in the bucket | ||
head_object = None | ||
try: | ||
head_object = s3.head_object(Bucket=bucket_name, Key=s3_prefix) | ||
except ClientError as ex: | ||
if ex.response["Error"]["Code"] == "404": | ||
pass | ||
else: | ||
raise | ||
if head_object is not None: | ||
raise ModelVersionAlreadyExists() | ||
|
||
model_pickle_key = s3_prefix + "model.pickle" | ||
s3.upload_fileobj( | ||
Fileobj=BytesIO(pickle.dumps(artifacts["model"])), | ||
Bucket=bucket_name, | ||
Key=model_pickle_key, | ||
) | ||
|
||
feature_eng_params_key = s3_prefix + "feature_eng_params.json" | ||
s3.upload_fileobj( | ||
Fileobj=BytesIO(artifacts["feature_eng_params"].json().encode()), | ||
Bucket=bucket_name, | ||
Key=feature_eng_params_key, | ||
) | ||
|
||
data_keys = ["raw_train_data", "raw_test_data", "train_data", "test_data"] | ||
for key in data_keys: | ||
if key in artifacts: | ||
dataset: pd.DataFrame = artifacts[key] # type: ignore | ||
csv_key = s3_prefix + f"{key}.csv" | ||
with NamedTemporaryFile(mode="w") as f: | ||
dataset.to_csv(f.name, index=False) | ||
s3.upload_file(f.name, bucket_name, csv_key) | ||
|
||
|
||
def get_train_artifacts(bucket_name: str, version: str, load_data: bool = True): | ||
s3 = boto3.client("s3") | ||
|
||
s3_prefix = f"{version}/" | ||
|
||
model_key = s3_prefix + "model.pickle" | ||
|
||
model: LogisticRegression = pickle.load( | ||
s3.get_object(Bucket=bucket_name, Key=model_key)["Body"] | ||
) | ||
|
||
feature_eng_params_key = s3_prefix + "feature_eng_params.json" | ||
fep = s3.get_object(Bucket=bucket_name, Key=feature_eng_params_key)["Body"].read() | ||
feature_eng_params = FeatureEngineeringParams(**json.loads(fep.decode())) | ||
|
||
if not load_data: | ||
return TrainArtifacts( | ||
model=model, | ||
feature_eng_params=feature_eng_params, | ||
) | ||
|
||
data_keys = ["raw_train_data", "raw_test_data", "train_data", "test_data"] | ||
data_dict = {} | ||
for key in data_keys: | ||
csv_key = s3_prefix + f"{key}.csv" | ||
data_dict[key] = pd.read_csv( | ||
StringIO( | ||
s3.get_object(Bucket=bucket_name, Key=csv_key)["Body"].read().decode() | ||
) | ||
) | ||
|
||
return TrainArtifacts( | ||
model=model, | ||
feature_eng_params=feature_eng_params, | ||
raw_train_data=data_dict["raw_train_data"], | ||
raw_test_data=data_dict["raw_test_data"], | ||
train_data=data_dict["train_data"], | ||
test_data=data_dict["test_data"], | ||
) | ||
|
||
|
||
def save_eval_artifacts(bucket_name: str, version: str, metrics: float, plots: Figure): | ||
s3 = boto3.client("s3") | ||
|
||
s3_prefix = f"{version}/" | ||
|
||
metrics_key = s3_prefix + "metrics.txt" | ||
s3.upload_fileobj( | ||
Fileobj=BytesIO(str(metrics).encode()), | ||
Bucket=bucket_name, | ||
Key=metrics_key, | ||
) | ||
|
||
plot_key = s3_prefix + "calibration_plot.png" | ||
buf = BytesIO() | ||
plots.savefig(buf) | ||
buf.seek(0) | ||
s3.upload_fileobj( | ||
Fileobj=buf, | ||
Bucket=bucket_name, | ||
Key=plot_key, | ||
) | ||
|
||
|
||
def get_all_available_train_versions(bucket_name: str): | ||
s3 = boto3.client("s3") | ||
|
||
# List objects in the specified prefix | ||
objects = s3.list_objects(Bucket=bucket_name, Prefix="") | ||
|
||
return list( | ||
{ | ||
obj["Key"].split("/")[0] | ||
for obj in objects.get("Contents", []) | ||
if len(obj["Key"].split("/")) == 2 # noqa: PLR2004 | ||
} | ||
) | ||
|
||
|
||
def get_latest_version(bucket_name: str, filename: str) -> str: | ||
s3 = boto3.client("s3") | ||
objects = s3.list_objects(Bucket=bucket_name, Prefix="") | ||
depth = 2 | ||
versions = [ | ||
(obj["Key"].split("/")[0], obj["LastModified"]) | ||
for obj in objects.get("Contents", []) | ||
if obj["Key"].endswith(filename) and len(obj["Key"].split("/")) == depth | ||
] | ||
return max(versions, key=lambda t: t[1])[0] | ||
|
||
|
||
def get_best_version(bucket_name: str): | ||
s3 = boto3.client("s3") | ||
objects = s3.list_objects(Bucket=bucket_name, Prefix="").get("Contents", []) | ||
keys = {o["Key"] for o in objects} | ||
if "best_model" not in keys: | ||
return None | ||
return ( | ||
s3.get_object(Bucket=bucket_name, Key="best_model")["Body"] | ||
.read() | ||
.decode("utf-8") | ||
) | ||
|
||
|
||
def tag_best_version(bucket_name: str, train_version: str): | ||
s3 = boto3.client("s3") | ||
s3.upload_fileobj( | ||
Fileobj=BytesIO(train_version.encode()), | ||
Bucket=bucket_name, | ||
Key="best_model", | ||
) | ||
|
||
|
||
def prediction_logging_func(predictions: list[tuple[Point, float]]): | ||
print("logged preds!") # noqa: T201 | ||
return True, None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import logging | ||
import os | ||
import sys | ||
from functools import partial | ||
from logging import Logger | ||
from typing import Union | ||
|
||
import typer | ||
|
||
from ml_pipelines.deployment.common.serve import run_serve | ||
from ml_pipelines.deployment.local.io import ( | ||
get_best_version, | ||
get_latest_version, | ||
get_train_artifacts, | ||
prediction_logging_func, | ||
) | ||
|
||
|
||
def main( | ||
train_artifacts_root_path: Union[str, None] = None, # noqa: UP007 | ||
train_version: Union[str, None] = None, # noqa: UP007 | ||
): | ||
""" | ||
Serves a model in the /predict/ endpoint of a FastAPI app. | ||
If `train_artifacts_root_path` is null, the command searches for the | ||
TRAIN_ARTIFACTS_ROOT_PATH environment variable, and if not present, | ||
assumes this to be "/". | ||
If `train_version` is null, the command loads the model tagged as 'best_version' | ||
in the `train_artifacts_root_path`, and if not found, loads the latest model. | ||
""" | ||
logger = Logger(__file__) | ||
logger.addHandler(logging.StreamHandler(sys.stdout)) | ||
|
||
if train_artifacts_root_path is None: | ||
train_artifacts_root_path = os.environ.get("TRAIN_ARTIFACTS_ROOT_PATH", "/") | ||
|
||
if train_version is None: | ||
train_version = get_best_version(train_artifacts_root_path) # type: ignore | ||
|
||
if train_version is None: | ||
train_version = get_latest_version( | ||
train_artifacts_root_path, # type: ignore | ||
"model.pickle", | ||
) | ||
|
||
get_train_artifacts_func = partial( | ||
get_train_artifacts, | ||
train_artifacts_root_path, # type: ignore | ||
load_data=False, | ||
) | ||
|
||
uvicorn_kwargs: dict = {} | ||
run_serve( # noqa: PLR0913 | ||
train_version=train_version, | ||
get_train_artifacts_func=get_train_artifacts_func, | ||
prediction_logging_func=prediction_logging_func, | ||
logger=logger, | ||
uvicorn_kwargs=uvicorn_kwargs, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
typer.run(main) |
Oops, something went wrong.