From c0a426c2e3f0b1ae1e29f92cdc807ba4abeb9100 Mon Sep 17 00:00:00 2001 From: dbftdiyoeywga Date: Fri, 6 Jul 2018 01:34:29 +0900 Subject: [PATCH 1/3] add multi metrics monitor --- hyperdash/experiment.py | 12 ++++++------ tests/test_sdk.py | 34 ++++++++++++++++++++-------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/hyperdash/experiment.py b/hyperdash/experiment.py index 79f022e..e27c362 100644 --- a/hyperdash/experiment.py +++ b/hyperdash/experiment.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import numpy as np import sys import uuid import threading @@ -217,13 +218,12 @@ def __init__(self, exp): def on_epoch_end(self, epoch, logs=None): if not logs: logs = {} - val_acc = logs.get("val_acc") - val_loss = logs.get("val_loss") + for k, v in logs.items(): + if isinstance(v, (np.ndarray, np.generic)): + self._exp.metric(k, v.item()) + else: + self._exp.metric(k, v) - if val_acc is not None: - self._exp.metric("val_acc", val_acc) - if val_loss is not None: - self._exp.metric("val_loss", val_loss) cb = _KerasCallback(self._exp) self._callbacks[KERAS] = cb return cb diff --git a/tests/test_sdk.py b/tests/test_sdk.py index 15a662d..fc612b8 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -131,7 +131,8 @@ def test_job(): # Make sure correct API name / version headers are sent assert server_sdk_headers[0][API_KEY_NAME] == API_NAME_MONITOR - assert server_sdk_headers[0][VERSION_KEY_NAME] == get_hyperdash_version() + assert server_sdk_headers[0][VERSION_KEY_NAME] == get_hyperdash_version( + ) def test_monitor_raises_exceptions(self): exception_raised = True @@ -276,11 +277,13 @@ def test_job(exp): expected_metrics = [ {"is_internal": False, "name": "acc", "value": 99.0}, {"is_internal": False, "name": "loss", "value": 0.00000000041}, - {"is_internal": False, "name": "val_loss", "value": 4324320984309284328743827432.0}, + {"is_internal": False, "name": "val_loss", + "value": 4324320984309284328743827432.0}, {"is_internal": False, "name": "mse", "value": -431.321}, {"is_internal": False, "name": "acc", "value": 97.0}, {"is_internal": False, "name": "loss", "value": -1.99999999959}, - {"is_internal": False, "name": "val_loss", "value": 4324320984309284328743827430.0}, + {"is_internal": False, "name": "val_loss", + "value": 4324320984309284328743827430.0}, {"is_internal": False, "name": "mse", "value": -433.321} ] for i, message in enumerate(sent_vals): @@ -334,7 +337,7 @@ def test_experiment(self): exp.metric("accuracy", i*0.2) time.sleep(0.1) exp.end() - + # Test params match what is expected params_messages = [] for msg in server_sdk_messages: @@ -355,7 +358,7 @@ def test_experiment(self): }, "is_internal": True, }, - ] + ] assert len(expect_params) == len(params_messages) for i, message in enumerate(params_messages): assert message == expect_params[i] @@ -372,20 +375,21 @@ def test_experiment(self): {"is_internal": False, "name": "accuracy", "value": 0}, {"is_internal": True, "name": "hd_iter_0", "value": 1}, {"is_internal": False, "name": "accuracy", "value": 0.2}, - ] + ] assert len(expect_metrics) == len(metrics_messages) for i, message in enumerate(metrics_messages): assert message["is_internal"] == expect_metrics[i]["is_internal"] assert message["name"] == expect_metrics[i]["name"] assert message["value"] == expect_metrics[i]["value"] - + captured_out = faked_out.getvalue() assert "error" not in captured_out # Make sure correct API name / version headers are sent assert server_sdk_headers[0][API_KEY_NAME] == API_NAME_EXPERIMENT - assert server_sdk_headers[0][VERSION_KEY_NAME] == get_hyperdash_version() - + assert server_sdk_headers[0][VERSION_KEY_NAME] == get_hyperdash_version( + ) + # Make sure logs were persisted expect_logs = [ "{ batch size: 32 }", @@ -410,10 +414,10 @@ def test_experiment_keras_callback(self): with patch("sys.stdout", new=StringIO()) as faked_out: exp = Experiment("MNIST") keras_cb = exp.callbacks.keras - keras_cb.on_epoch_end(0, {"val_acc": 1, "val_loss": 2}) + keras_cb.on_epoch_end(0, {"val_acc": 1, "val_loss": 2, "any_metrics": 10}) # Sleep 1 second due to client sampling time.sleep(1) - keras_cb.on_epoch_end(1, {"val_acc": 3, "val_loss": 4}) + keras_cb.on_epoch_end(1, {"val_acc": 3, "val_loss": 4, "any_metrics": 20}) exp.end() # Test metrics match what is expected @@ -425,15 +429,17 @@ def test_experiment_keras_callback(self): expect_metrics = [ {"is_internal": False, "name": "val_acc", "value": 1}, {"is_internal": False, "name": "val_loss", "value": 2}, + {"is_internal": False, "name": "any_metrics", "value": 10}, {"is_internal": False, "name": "val_acc", "value": 3}, {"is_internal": False, "name": "val_loss", "value": 4}, + {"is_internal": False, "name": "any_metrics", "value": 20}, ] assert len(expect_metrics) == len(metrics_messages) for i, message in enumerate(metrics_messages): assert message["is_internal"] == expect_metrics[i]["is_internal"] assert message["name"] == expect_metrics[i]["name"] assert message["value"] == expect_metrics[i]["value"] - + captured_out = faked_out.getvalue() assert "error" not in captured_out @@ -460,7 +466,7 @@ def test_experiment_handles_numpy_numbers(self): exp.metric("test_metric_{}".format(name), num) exp.param("test_param_{}".format(name), num) exp.end() - + # Test params match what is expected params_messages = [] for msg in server_sdk_messages: @@ -502,7 +508,7 @@ def test_experiment_handles_numpy_numbers(self): assert message["is_internal"] == expected_metrics[i]["is_internal"] assert message["name"] == expected_metrics[i]["name"] assert message["value"] == expected_metrics[i]["value"] - + def experiment_raises_exceptions(self): exception_raised = True expected_exception = "some_exception_b" From 6ea24f91e262c29651d53a2a6f7ba329d18cfee5 Mon Sep 17 00:00:00 2001 From: dbftdiyoeywga Date: Fri, 6 Jul 2018 01:42:55 +0900 Subject: [PATCH 2/3] fix doc format --- hyperdash/experiment.py | 4 +++- tests/test_sdk.py | 12 ++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/hyperdash/experiment.py b/hyperdash/experiment.py index e27c362..cf9f6e5 100644 --- a/hyperdash/experiment.py +++ b/hyperdash/experiment.py @@ -1,9 +1,11 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import numpy as np import sys import uuid import threading + +import numpy as np + from threading import Lock from datetime import datetime diff --git a/tests/test_sdk.py b/tests/test_sdk.py index fc612b8..ae9235e 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -131,8 +131,7 @@ def test_job(): # Make sure correct API name / version headers are sent assert server_sdk_headers[0][API_KEY_NAME] == API_NAME_MONITOR - assert server_sdk_headers[0][VERSION_KEY_NAME] == get_hyperdash_version( - ) + assert server_sdk_headers[0][VERSION_KEY_NAME] == get_hyperdash_version() def test_monitor_raises_exceptions(self): exception_raised = True @@ -277,13 +276,11 @@ def test_job(exp): expected_metrics = [ {"is_internal": False, "name": "acc", "value": 99.0}, {"is_internal": False, "name": "loss", "value": 0.00000000041}, - {"is_internal": False, "name": "val_loss", - "value": 4324320984309284328743827432.0}, + {"is_internal": False, "name": "val_loss", "value": 4324320984309284328743827432.0}, {"is_internal": False, "name": "mse", "value": -431.321}, {"is_internal": False, "name": "acc", "value": 97.0}, {"is_internal": False, "name": "loss", "value": -1.99999999959}, - {"is_internal": False, "name": "val_loss", - "value": 4324320984309284328743827430.0}, + {"is_internal": False, "name": "val_loss", "value": 4324320984309284328743827430.0}, {"is_internal": False, "name": "mse", "value": -433.321} ] for i, message in enumerate(sent_vals): @@ -387,8 +384,7 @@ def test_experiment(self): # Make sure correct API name / version headers are sent assert server_sdk_headers[0][API_KEY_NAME] == API_NAME_EXPERIMENT - assert server_sdk_headers[0][VERSION_KEY_NAME] == get_hyperdash_version( - ) + assert server_sdk_headers[0][VERSION_KEY_NAME] == get_hyperdash_version() # Make sure logs were persisted expect_logs = [ From c7567f30820f9736b2da045949f839af40b8ff57 Mon Sep 17 00:00:00 2001 From: dbftdiyoeywga Date: Fri, 6 Jul 2018 22:38:28 +0900 Subject: [PATCH 3/3] update for