Skip to content

Commit dbe81e1

Browse files
authored
Use ModelInfo.model_uri in h2o tests (mlflow#14639)
Signed-off-by: harupy <[email protected]>
1 parent 6a5fc08 commit dbe81e1

File tree

1 file changed

+34
-39
lines changed

1 file changed

+34
-39
lines changed

tests/h2o/test_h2o_model_export.py

+34-39
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,8 @@ def test_model_log(h2o_iris_model):
131131
try:
132132
artifact_path = "gbm_model"
133133
model_info = mlflow.h2o.log_model(h2o_model, artifact_path)
134-
model_uri = f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
135-
assert model_info.model_uri == model_uri
136134
# Load model
137-
h2o_model_loaded = mlflow.h2o.load_model(model_uri=model_uri)
135+
h2o_model_loaded = mlflow.h2o.load_model(model_uri=model_info.model_uri)
138136
assert all(
139137
h2o_model_loaded.predict(h2o_iris_model.inference_data).as_data_frame()
140138
== h2o_model.predict(h2o_iris_model.inference_data).as_data_frame()
@@ -199,29 +197,29 @@ def test_log_model_with_pip_requirements(h2o_iris_model, tmp_path):
199197
req_file = tmp_path.joinpath("requirements.txt")
200198
req_file.write_text("a")
201199
with mlflow.start_run():
202-
mlflow.h2o.log_model(h2o_iris_model.model, "model", pip_requirements=str(req_file))
203-
_assert_pip_requirements(
204-
mlflow.get_artifact_uri("model"), [expected_mlflow_version, "a"], strict=True
200+
model_info = mlflow.h2o.log_model(
201+
h2o_iris_model.model, "model", pip_requirements=str(req_file)
205202
)
203+
_assert_pip_requirements(model_info.model_uri, [expected_mlflow_version, "a"], strict=True)
206204

207205
# List of requirements
208206
with mlflow.start_run():
209-
mlflow.h2o.log_model(
207+
model_info = mlflow.h2o.log_model(
210208
h2o_iris_model.model,
211209
"model",
212210
pip_requirements=[f"-r {req_file}", "b"],
213211
)
214212
_assert_pip_requirements(
215-
mlflow.get_artifact_uri("model"), [expected_mlflow_version, "a", "b"], strict=True
213+
model_info.model_uri, [expected_mlflow_version, "a", "b"], strict=True
216214
)
217215

218216
# Constraints file
219217
with mlflow.start_run():
220-
mlflow.h2o.log_model(
218+
model_info = mlflow.h2o.log_model(
221219
h2o_iris_model.model, "model", pip_requirements=[f"-c {req_file}", "b"]
222220
)
223221
_assert_pip_requirements(
224-
mlflow.get_artifact_uri("model"),
222+
model_info.model_uri,
225223
[expected_mlflow_version, "b", "-c constraints.txt"],
226224
["a"],
227225
strict=True,
@@ -236,27 +234,29 @@ def test_log_model_with_extra_pip_requirements(h2o_iris_model, tmp_path):
236234
req_file = tmp_path.joinpath("requirements.txt")
237235
req_file.write_text("a")
238236
with mlflow.start_run():
239-
mlflow.h2o.log_model(h2o_iris_model.model, "model", extra_pip_requirements=str(req_file))
237+
model_info = mlflow.h2o.log_model(
238+
h2o_iris_model.model, "model", extra_pip_requirements=str(req_file)
239+
)
240240
_assert_pip_requirements(
241-
mlflow.get_artifact_uri("model"), [expected_mlflow_version, *default_reqs, "a"]
241+
model_info.model_uri, [expected_mlflow_version, *default_reqs, "a"]
242242
)
243243

244244
# List of requirements
245245
with mlflow.start_run():
246-
mlflow.h2o.log_model(
246+
model_info = mlflow.h2o.log_model(
247247
h2o_iris_model.model, "model", extra_pip_requirements=[f"-r {req_file}", "b"]
248248
)
249249
_assert_pip_requirements(
250-
mlflow.get_artifact_uri("model"), [expected_mlflow_version, *default_reqs, "a", "b"]
250+
model_info.model_uri, [expected_mlflow_version, *default_reqs, "a", "b"]
251251
)
252252

253253
# Constraints file
254254
with mlflow.start_run():
255-
mlflow.h2o.log_model(
255+
model_info = mlflow.h2o.log_model(
256256
h2o_iris_model.model, "model", extra_pip_requirements=[f"-c {req_file}", "b"]
257257
)
258258
_assert_pip_requirements(
259-
mlflow.get_artifact_uri("model"),
259+
model_info.model_uri,
260260
[expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"],
261261
["a"],
262262
)
@@ -281,11 +281,11 @@ def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
281281
):
282282
artifact_path = "model"
283283
with mlflow.start_run():
284-
mlflow.h2o.log_model(h2o_iris_model.model, artifact_path, conda_env=h2o_custom_env)
285-
model_path = _download_artifact_from_uri(
286-
f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
284+
model_info = mlflow.h2o.log_model(
285+
h2o_iris_model.model, artifact_path, conda_env=h2o_custom_env
287286
)
288287

288+
model_path = _download_artifact_from_uri(model_info.model_uri)
289289
pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
290290
saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
291291
assert os.path.exists(saved_conda_env_path)
@@ -301,10 +301,10 @@ def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
301301
def test_model_log_persists_requirements_in_mlflow_model_directory(h2o_iris_model, h2o_custom_env):
302302
artifact_path = "model"
303303
with mlflow.start_run():
304-
mlflow.h2o.log_model(h2o_iris_model.model, artifact_path, conda_env=h2o_custom_env)
305-
model_path = _download_artifact_from_uri(
306-
f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
304+
model_info = mlflow.h2o.log_model(
305+
h2o_iris_model.model, artifact_path, conda_env=h2o_custom_env
307306
)
307+
model_path = _download_artifact_from_uri(model_info.model_uri)
308308

309309
saved_pip_req_path = os.path.join(model_path, "requirements.txt")
310310
_compare_conda_env_requirements(h2o_custom_env, saved_pip_req_path)
@@ -322,9 +322,8 @@ def test_model_log_without_specified_conda_env_uses_default_env_with_expected_de
322322
):
323323
artifact_path = "model"
324324
with mlflow.start_run():
325-
mlflow.h2o.log_model(h2o_iris_model.model, artifact_path)
326-
model_uri = mlflow.get_artifact_uri(artifact_path)
327-
_assert_pip_requirements(model_uri, mlflow.h2o.get_default_pip_requirements())
325+
model_info = mlflow.h2o.log_model(h2o_iris_model.model, artifact_path)
326+
_assert_pip_requirements(model_info.model_uri, mlflow.h2o.get_default_pip_requirements())
328327

329328

330329
def test_pyfunc_serve_and_score(h2o_iris_model):
@@ -348,15 +347,13 @@ def test_pyfunc_serve_and_score(h2o_iris_model):
348347

349348

350349
def test_log_model_with_code_paths(h2o_iris_model):
351-
artifact_path = "model_uri"
352350
with (
353351
mlflow.start_run(),
354352
mock.patch("mlflow.h2o._add_code_from_conf_to_system_path") as add_mock,
355353
):
356-
mlflow.h2o.log_model(h2o_iris_model.model, artifact_path, code_paths=[__file__])
357-
model_uri = mlflow.get_artifact_uri(artifact_path)
358-
_compare_logged_code_paths(__file__, model_uri, mlflow.h2o.FLAVOR_NAME)
359-
mlflow.h2o.load_model(model_uri)
354+
model_info = mlflow.h2o.log_model(h2o_iris_model.model, "model_uri", code_paths=[__file__])
355+
_compare_logged_code_paths(__file__, model_info.model_uri, mlflow.h2o.FLAVOR_NAME)
356+
mlflow.h2o.load_model(model_info.model_uri)
360357
add_mock.assert_called()
361358

362359

@@ -370,17 +367,14 @@ def test_model_save_load_with_metadata(h2o_iris_model, model_path):
370367

371368

372369
def test_model_log_with_metadata(h2o_iris_model):
373-
artifact_path = "model"
374-
375370
with mlflow.start_run():
376-
mlflow.h2o.log_model(
371+
model_info = mlflow.h2o.log_model(
377372
h2o_iris_model.model,
378-
artifact_path,
373+
"model",
379374
metadata={"metadata_key": "metadata_value"},
380375
)
381-
model_uri = mlflow.get_artifact_uri(artifact_path)
382376

383-
reloaded_model = mlflow.pyfunc.load_model(model_uri=model_uri)
377+
reloaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
384378
assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"
385379

386380

@@ -389,8 +383,9 @@ def test_model_log_with_signature_inference(h2o_iris_model, h2o_iris_model_signa
389383
example = h2o_iris_model.inference_data.as_data_frame().head(3)
390384

391385
with mlflow.start_run():
392-
mlflow.h2o.log_model(h2o_iris_model.model, artifact_path, input_example=example)
393-
model_uri = mlflow.get_artifact_uri(artifact_path)
386+
model_info = mlflow.h2o.log_model(
387+
h2o_iris_model.model, artifact_path, input_example=example
388+
)
394389

395-
mlflow_model = Model.load(model_uri)
390+
mlflow_model = Model.load(model_info.model_uri)
396391
assert mlflow_model.signature == h2o_iris_model_signature

0 commit comments

Comments
 (0)