Skip to content

Commit 6a5fc08

Browse files
authored
Update examples to use model_info (mlflow#14636)
Signed-off-by: serena-ruan <[email protected]>
1 parent bf024b9 commit 6a5fc08

9 files changed

+35
-51
lines changed

examples/catboost/train.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@
2828
with mlflow.start_run() as run:
2929
signature = infer_signature(eval_data, model.predict(eval_data))
3030
mlflow.log_params(params)
31-
mlflow.catboost.log_model(model, artifact_path="model", signature=signature)
32-
model_uri = mlflow.get_artifact_uri("model")
31+
model_info = mlflow.catboost.log_model(model, artifact_path="model", signature=signature)
3332

3433
# Load model
35-
loaded_model = mlflow.catboost.load_model(model_uri)
34+
loaded_model = mlflow.catboost.load_model(model_info.model_uri)
3635

3736
# Get predictions
3837
preds = loaded_model.predict(eval_data)

examples/diviner/train.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def generate_data(location_data, start_dt) -> pd.DataFrame:
4747
return pd.concat(generated_listing).reset_index().drop("index", axis=1)
4848

4949

50-
def grouped_prophet_example(locations, start_dt, artifact_path):
50+
def grouped_prophet_example(locations, start_dt):
5151
print("Generating data...\n")
5252
data = generate_data(location_data=locations, start_dt=start_dt)
5353
grouping_keys = ["country", "city"]
@@ -73,7 +73,7 @@ def grouped_prophet_example(locations, start_dt, artifact_path):
7373
)
7474
print(f"Cross Validation Metrics: \n{metrics.to_string()}")
7575

76-
mlflow.diviner.log_model(diviner_model=model, artifact_path=artifact_path)
76+
model_info = mlflow.diviner.log_model(diviner_model=model, artifact_path="diviner_model")
7777

7878
# As an Alternative to saving metrics and params directly with a `log_dict()` function call,
7979
# Serializing the DataFrames to local as a .csv can be done as well, without requiring
@@ -99,7 +99,7 @@ def grouped_prophet_example(locations, start_dt, artifact_path):
9999

100100
mlflow.log_dict(metrics.to_dict(), "metrics.json")
101101

102-
return mlflow.get_artifact_uri(artifact_path=artifact_path)
102+
return model_info.model_uri
103103

104104

105105
if __name__ == "__main__":
@@ -112,10 +112,9 @@ def grouped_prophet_example(locations, start_dt, artifact_path):
112112
("MX", "MexicoCity"),
113113
]
114114
start_dt = "2022-02-01 04:11:35"
115-
artifact_path = "diviner_model"
116115

117116
with mlflow.start_run():
118-
uri = grouped_prophet_example(locations, start_dt, artifact_path)
117+
uri = grouped_prophet_example(locations, start_dt)
119118

120119
loaded_model = mlflow.diviner.load_model(model_uri=uri)
121120

examples/evaluation/evaluate_on_binary_classifier.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,11 @@
2424

2525
with mlflow.start_run() as run:
2626
# Log the XGBoost binary classifier model to MLflow
27-
mlflow.sklearn.log_model(model, "model", signature=signature)
28-
model_uri = mlflow.get_artifact_uri("model")
27+
model_info = mlflow.sklearn.log_model(model, "model", signature=signature)
2928

3029
# Evaluate the logged model
3130
result = mlflow.evaluate(
32-
model_uri,
31+
model_info.model_uri,
3332
eval_data,
3433
targets="label",
3534
model_type="classifier",

examples/evaluation/evaluate_on_multiclass_classifier.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,15 @@
44

55
import mlflow
66

7-
mlflow.sklearn.autolog()
8-
97
X, y = make_classification(n_samples=10000, n_classes=10, n_informative=5, random_state=1)
108

119
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
1210

1311
with mlflow.start_run() as run:
1412
model = LogisticRegression(solver="liblinear").fit(X_train, y_train)
15-
model_uri = mlflow.get_artifact_uri("model")
13+
model_info = mlflow.sklearn.log_model(model, "model")
1614
result = mlflow.evaluate(
17-
model_uri,
15+
model_info.model_uri,
1816
X_test,
1917
targets=y_test,
2018
model_type="classifier",

examples/evaluation/evaluate_on_regressor.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
import mlflow
66

7-
mlflow.sklearn.autolog()
8-
97
california_housing_data = fetch_california_housing()
108

119
X_train, X_test, y_train, y_test = train_test_split(
@@ -14,10 +12,10 @@
1412

1513
with mlflow.start_run() as run:
1614
model = LinearRegression().fit(X_train, y_train)
17-
model_uri = mlflow.get_artifact_uri("model")
15+
model_info = mlflow.sklearn.log_model(model, "model")
1816

1917
result = mlflow.evaluate(
20-
model_uri,
18+
model_info.model_uri,
2119
X_test,
2220
targets=y_test,
2321
model_type="regressor",

examples/evaluation/evaluate_with_custom_metrics.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,9 @@ def prediction_target_scatter(eval_df, _builtin_metrics, artifacts_dir):
6161

6262

6363
with mlflow.start_run() as run:
64-
mlflow.sklearn.log_model(lin_reg, "model", signature=signature)
65-
model_uri = mlflow.get_artifact_uri("model")
64+
model_info = mlflow.sklearn.log_model(lin_reg, "model", signature=signature)
6665
result = mlflow.evaluate(
67-
model=model_uri,
66+
model=model_info.model_uri,
6867
data=eval_data,
6968
targets="target",
7069
model_type="regressor",

examples/evaluation/evaluate_with_custom_metrics_comprehensive.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,9 @@ def custom_artifact(eval_df, builtin_metrics, _artifacts_dir):
6262

6363

6464
with mlflow.start_run() as run:
65-
mlflow.sklearn.log_model(lin_reg, "model", signature=signature)
66-
model_uri = mlflow.get_artifact_uri("model")
65+
model_info = mlflow.sklearn.log_model(lin_reg, "model", signature=signature)
6766
result = mlflow.evaluate(
68-
model=model_uri,
67+
model=model_info.model_uri,
6968
data=eval_data,
7069
targets="target",
7170
model_type="regressor",

examples/pip_requirements/pip_requirements.py

+16-22
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@ def read_lines(path):
1919
return f.read().splitlines()
2020

2121

22-
def get_pip_requirements(run_id, artifact_path, return_constraints=False):
23-
req_path = download_artifacts(run_id=run_id, artifact_path=f"{artifact_path}/requirements.txt")
22+
def get_pip_requirements(artifact_uri, return_constraints=False):
23+
req_path = download_artifacts(artifact_uri=f"{artifact_uri}/requirements.txt")
2424
reqs = read_lines(req_path)
2525

2626
if return_constraints:
27-
con_path = download_artifacts(
28-
run_id=run_id, artifact_path=f"{artifact_path}/constraints.txt"
29-
)
27+
con_path = download_artifacts(artifact_uri=f"{artifact_uri}/constraints.txt")
3028
cons = read_lines(con_path)
3129
return set(reqs), set(cons)
3230

@@ -43,30 +41,28 @@ def main():
4341
xgb_req = f"xgboost=={xgb.__version__}"
4442
sklearn_req = f"scikit-learn=={sklearn.__version__}"
4543

46-
with mlflow.start_run() as run:
47-
run_id = run.info.run_id
48-
44+
with mlflow.start_run():
4945
# Default (both `pip_requirements` and `extra_pip_requirements` are unspecified)
5046
artifact_path = "default"
51-
mlflow.xgboost.log_model(model, artifact_path, signature=signature)
52-
pip_reqs = get_pip_requirements(run_id, artifact_path)
47+
model_info = mlflow.xgboost.log_model(model, artifact_path, signature=signature)
48+
pip_reqs = get_pip_requirements(model_info.model_uri)
5349
assert xgb_req in pip_reqs, pip_reqs
5450

5551
# Overwrite the default set of pip requirements using `pip_requirements`
5652
artifact_path = "pip_requirements"
57-
mlflow.xgboost.log_model(
53+
model_info = mlflow.xgboost.log_model(
5854
model, artifact_path, pip_requirements=[sklearn_req], signature=signature
5955
)
60-
pip_reqs = get_pip_requirements(run_id, artifact_path)
56+
pip_reqs = get_pip_requirements(model_info.model_uri)
6157
assert sklearn_req in pip_reqs, pip_reqs
6258

6359
# Add extra pip requirements on top of the default set of pip requirements
6460
# using `extra_pip_requirements`
6561
artifact_path = "extra_pip_requirements"
66-
mlflow.xgboost.log_model(
62+
model_info = mlflow.xgboost.log_model(
6763
model, artifact_path, extra_pip_requirements=[sklearn_req], signature=signature
6864
)
69-
pip_reqs = get_pip_requirements(run_id, artifact_path)
65+
pip_reqs = get_pip_requirements(model_info.model_uri)
7066
assert pip_reqs.issuperset({xgb_req, sklearn_req}), pip_reqs
7167

7268
# Specify pip requirements using a requirements file
@@ -76,21 +72,21 @@ def main():
7672

7773
# Path to a pip requirements file
7874
artifact_path = "requirements_file_path"
79-
mlflow.xgboost.log_model(
75+
model_info = mlflow.xgboost.log_model(
8076
model, artifact_path, pip_requirements=f.name, signature=signature
8177
)
82-
pip_reqs = get_pip_requirements(run_id, artifact_path)
78+
pip_reqs = get_pip_requirements(model_info.model_uri)
8379
assert sklearn_req in pip_reqs, pip_reqs
8480

8581
# List of pip requirement strings
8682
artifact_path = "requirements_file_list"
87-
mlflow.xgboost.log_model(
83+
model_info = mlflow.xgboost.log_model(
8884
model,
8985
artifact_path,
9086
pip_requirements=[xgb_req, f"-r {f.name}"],
9187
signature=signature,
9288
)
93-
pip_reqs = get_pip_requirements(run_id, artifact_path)
89+
pip_reqs = get_pip_requirements(model_info.model_uri)
9490
assert pip_reqs.issuperset({xgb_req, sklearn_req}), pip_reqs
9591

9692
# Using a constraints file
@@ -99,15 +95,13 @@ def main():
9995
f.flush()
10096

10197
artifact_path = "constraints_file"
102-
mlflow.xgboost.log_model(
98+
model_info = mlflow.xgboost.log_model(
10399
model,
104100
artifact_path,
105101
pip_requirements=[xgb_req, f"-c {f.name}"],
106102
signature=signature,
107103
)
108-
pip_reqs, pip_cons = get_pip_requirements(
109-
run_id, artifact_path, return_constraints=True
110-
)
104+
pip_reqs, pip_cons = get_pip_requirements(model_info.model_uri, return_constraints=True)
111105
assert pip_reqs.issuperset({xgb_req, "-c constraints.txt"}), pip_reqs
112106
assert pip_cons == {sklearn_req}, pip_cons
113107

examples/pmdarima/train.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,15 @@ def calculate_cv_metrics(model, endog, metric, cv):
4848
predictions = arima.predict(n_periods=30, return_conf_int=False)
4949
signature = infer_signature(train, predictions)
5050

51-
mlflow.pmdarima.log_model(
51+
model_info = mlflow.pmdarima.log_model(
5252
pmdarima_model=arima, artifact_path=ARTIFACT_PATH, signature=signature
5353
)
5454
mlflow.log_params(parameters)
5555
mlflow.log_metrics(metrics)
56-
model_uri = mlflow.get_artifact_uri(ARTIFACT_PATH)
5756

58-
print(f"Model artifact logged to: {model_uri}")
57+
print(f"Model artifact logged to: {model_info.model_uri}")
5958

60-
loaded_model = mlflow.pmdarima.load_model(model_uri)
59+
loaded_model = mlflow.pmdarima.load_model(model_info.model_uri)
6160

6261
forecast = loaded_model.predict(30)
6362

0 commit comments

Comments
 (0)