-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add prediction logging func to api * Add local deployment for train pipeline * reset python <3.10 type hints for typer * added serve for local deployment * Update readme
- Loading branch information
Showing
14 changed files
with
636 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import json | ||
import os | ||
import pickle | ||
from pathlib import Path | ||
from uuid import uuid4 | ||
|
||
import pandas as pd | ||
from matplotlib.figure import Figure | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
from ml_pipelines.logic.common.feature_eng import FeatureEngineeringParams | ||
from ml_pipelines.pipeline.train_pipeline import TrainArtifacts | ||
|
||
|
||
def make_version(prefix: str) -> str: | ||
version = str(uuid4()) | ||
return f"{prefix}_{version}" | ||
|
||
|
||
def get_raw_data(version: str, root_path: os.PathLike) -> pd.DataFrame: | ||
return pd.read_csv(Path(root_path) / version / "raw_data.csv") | ||
|
||
|
||
class ModelVersionAlreadyExists(Exception): | ||
pass | ||
|
||
|
||
def save_train_artifacts( | ||
version: str, | ||
root_path: os.PathLike, | ||
artifacts: TrainArtifacts, | ||
): | ||
version_dir = Path(root_path) / version | ||
try: | ||
version_dir.mkdir(parents=True) | ||
except FileExistsError: | ||
raise ModelVersionAlreadyExists() | ||
|
||
with open(version_dir / "model.pickle", "wb") as f: | ||
pickle.dump(artifacts["model"], f) | ||
with open(version_dir / "feature_eng_params.json", "w") as f: # type: ignore | ||
f.write(artifacts["feature_eng_params"].json()) | ||
|
||
data_keys = [ | ||
"raw_train_data", | ||
"raw_test_data", | ||
"train_data", | ||
"test_data", | ||
] | ||
for key in data_keys: | ||
if key in artifacts: | ||
dataset = artifacts[key] # type: ignore | ||
dataset.to_csv(version_dir / f"{key}.csv", index=False) | ||
|
||
|
||
def get_train_artifacts(version: str, root_path: os.PathLike, load_data: bool = True): | ||
version_dir = Path(root_path) / version | ||
with open(version_dir / "model.pickle", "rb") as f: # type: ignore | ||
model: LogisticRegression = pickle.load(f) | ||
|
||
with open(version_dir / "feature_eng_params.json") as f: # type: ignore | ||
feature_eng_params = FeatureEngineeringParams(**json.loads(f.read())) | ||
|
||
if not load_data: | ||
return TrainArtifacts( | ||
model=model, | ||
feature_eng_params=feature_eng_params, | ||
) | ||
|
||
raw_train_data = pd.read_csv(version_dir / "raw_train_data.csv") | ||
raw_test_data = pd.read_csv(version_dir / "raw_test_data.csv") | ||
train_data = pd.read_csv(version_dir / "train_data.csv") | ||
test_data = pd.read_csv(version_dir / "test_data.csv") | ||
|
||
return TrainArtifacts( | ||
model=model, | ||
feature_eng_params=feature_eng_params, | ||
raw_train_data=raw_train_data, | ||
raw_test_data=raw_test_data, | ||
train_data=train_data, | ||
test_data=test_data, | ||
) | ||
|
||
|
||
def save_eval_artifacts( | ||
version: str, root_path: os.PathLike, metrics: float, plots: Figure | ||
): | ||
version_dir = Path(root_path) / version | ||
with open(version_dir / "metrics.txt", "w") as f: # type: ignore | ||
f.write(str(metrics)) # type: ignore | ||
plots.savefig(str(version_dir / "calibration_plot.png")) | ||
|
||
|
||
def get_latest_version(root_path: os.PathLike, filename: str) -> str: | ||
root_dir = Path(root_path) | ||
versions: list[tuple[str, float]] = [] | ||
for version_dir in root_dir.iterdir(): | ||
st_mtime = (version_dir / filename).stat().st_mtime | ||
versions.append((version_dir.stem, st_mtime)) | ||
return max(versions, key=lambda t: t[1])[0] |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import logging | ||
import os | ||
import sys | ||
from logging import Logger | ||
from typing import Union | ||
|
||
import typer | ||
import uvicorn | ||
|
||
from ml_pipelines.deployment.local.common import get_latest_version, get_train_artifacts | ||
from ml_pipelines.logic.serve.serve import Point, create_fastapi_app | ||
|
||
|
||
def prediction_logging_func(predictions: list[tuple[Point, float]]): | ||
print("yay") # noqa: T201 | ||
return True, None | ||
|
||
|
||
def run_serve( | ||
train_version: str, | ||
train_artifacts_root_path: os.PathLike, # type: ignore | ||
logger: Logger, | ||
uvicorn_kwargs: dict, | ||
): | ||
logger.info(f"Serving model {train_version}.") | ||
train_artifacts = get_train_artifacts( | ||
train_version, train_artifacts_root_path, load_data=False | ||
) | ||
model = train_artifacts["model"] | ||
feature_eng_params = train_artifacts["feature_eng_params"] | ||
app = create_fastapi_app(model, feature_eng_params, logger, prediction_logging_func) | ||
logger.info("Loaded model, set up endpoint.") | ||
|
||
uvicorn.run(app=app, **uvicorn_kwargs) | ||
|
||
|
||
if __name__ == "__main__": | ||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
TRAIN_ARTIFACTS_ROOT_DIR = os.environ["TRAIN_ARTIFACTS_ROOT_DIR"] | ||
|
||
def main( | ||
train_version: Union[str, None] = None, # noqa: UP007 | ||
train_artifacts_root_path: str = TRAIN_ARTIFACTS_ROOT_DIR, | ||
): | ||
logger = Logger(__file__) | ||
logger.addHandler(logging.StreamHandler(sys.stdout)) | ||
|
||
if train_version is None: | ||
train_version = get_latest_version( | ||
train_artifacts_root_path, # type: ignore | ||
"model.pickle", | ||
) | ||
|
||
uvicorn_kwargs: dict = {} | ||
run_serve( # noqa: PLR0913 | ||
train_version=train_version, | ||
train_artifacts_root_path=train_artifacts_root_path, # type: ignore | ||
logger=logger, | ||
uvicorn_kwargs=uvicorn_kwargs, | ||
) | ||
|
||
typer.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import logging | ||
import os | ||
import sys | ||
from logging import Logger | ||
from typing import Union | ||
|
||
import typer | ||
|
||
from ml_pipelines.deployment.local.common import ( | ||
get_latest_version, | ||
get_raw_data, | ||
make_version, | ||
save_eval_artifacts, | ||
save_train_artifacts, | ||
) | ||
from ml_pipelines.pipeline.eval_pipeline import eval_pipeline | ||
from ml_pipelines.pipeline.train_pipeline import train_pipeline | ||
|
||
|
||
def run_train_pipeline( # noqa: PLR0913 | ||
raw_data_version: str, | ||
raw_data_root_path: os.PathLike, | ||
train_version: str, | ||
train_artifacts_root_path: os.PathLike, | ||
logger: Logger, | ||
split_random_state: int = 3825, | ||
): | ||
logger.info(f"Running full training pipeline version {train_version}.") | ||
logger.info(f"Raw data version {raw_data_version}.") | ||
raw_data = get_raw_data(raw_data_version, raw_data_root_path) | ||
train_artifacts = train_pipeline( | ||
raw_data, split_random_state=split_random_state, logger=logger | ||
) | ||
save_train_artifacts(train_version, train_artifacts_root_path, train_artifacts) | ||
logger.info("Saved train artifacts.") | ||
metrics, plots = eval_pipeline( | ||
train_artifacts["model"], | ||
train_artifacts["feature_eng_params"], | ||
train_artifacts["raw_test_data"], | ||
logger, | ||
) | ||
save_eval_artifacts(train_version, train_artifacts_root_path, metrics, plots) | ||
logger.info("Saved eval artifacts.") | ||
|
||
|
||
if __name__ == "__main__": | ||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
RAW_DATA_ROOT_DIR = os.environ["RAW_DATA_ROOT_DIR"] | ||
TRAIN_ARTIFACTS_ROOT_DIR = os.environ["TRAIN_ARTIFACTS_ROOT_DIR"] | ||
|
||
def main( | ||
raw_data_version: Union[str, None] = None, # noqa: UP007 | ||
train_version: Union[str, None] = None, # noqa: UP007 | ||
raw_data_root_path: str = RAW_DATA_ROOT_DIR, | ||
train_artifacts_root_path: str = TRAIN_ARTIFACTS_ROOT_DIR, | ||
split_random_state: int = 3825, | ||
): | ||
logger = Logger(__file__) | ||
logger.addHandler(logging.StreamHandler(sys.stdout)) | ||
|
||
if raw_data_version is None: | ||
raw_data_version = get_latest_version( | ||
raw_data_root_path, # type: ignore | ||
"raw_data.csv", | ||
) | ||
|
||
if train_version is None: | ||
train_version = make_version(prefix="model") | ||
|
||
run_train_pipeline( # noqa: PLR0913 | ||
raw_data_version=raw_data_version, | ||
raw_data_root_path=raw_data_root_path, # type: ignore | ||
train_version=train_version, | ||
train_artifacts_root_path=train_artifacts_root_path, # type: ignore | ||
logger=logger, | ||
split_random_state=split_random_state, | ||
) | ||
|
||
typer.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from ml_pipelines.logic.common.dgp import generate_raw_data | ||
|
||
data = generate_raw_data(10_000, 813) | ||
data.to_csv("data.csv", index=False) | ||
data.to_csv("raw_data.csv", index=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.