Skip to content

Commit e2f8704

Browse files
committed
Make offline mode work
1 parent e02f6db commit e2f8704

File tree

6 files changed

+93
-88
lines changed

6 files changed

+93
-88
lines changed

examples/detailed_integration.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
import simvue_tensorflow.plugin as sv_tf
2424

2525
def tensorflow_example(run_folder, offline=False):
26+
# Delete results from previous run, if they exist:
27+
pathlib.Path(__file__).parent.joinpath("results").unlink(missing_ok=True)
28+
2629
# Load the training and test data
2730
(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
2831

@@ -43,8 +46,8 @@ def tensorflow_example(run_folder, offline=False):
4346

4447
# Can use the ModelCheckpoint callback, which is built into Tensorflow, to save a model after each Epoch
4548
# Providing the model_checkpoint_filepath in the TensorVue callback means it will automatically upload checkpoints to the Epoch runs
46-
temp_dir = tempfile.TemporaryDirectory()
47-
checkpoint_filepath = str(pathlib.Path(temp_dir.name).joinpath("checkpoint.model.keras"))
49+
results_dir = pathlib.Path(__file__).parent.joinpath("results")
50+
checkpoint_filepath = str(pathlib.Path(results_dir.name).joinpath("checkpoint.model.keras"))
4851
model_checkpoint_callback = ModelCheckpoint(
4952
filepath=checkpoint_filepath, save_best_only=False, verbose=1
5053
)
@@ -85,7 +88,7 @@ def tensorflow_example(run_folder, offline=False):
8588
evaluation_target=0.99,
8689

8790
# Choose where the final model is saved
88-
model_final_filepath=str(pathlib.Path(temp_dir.name).joinpath("tf_fashion_model.keras"))
91+
model_final_filepath=str(pathlib.Path(results_dir.name).joinpath("tf_fashion_model.keras"))
8992
)
9093

9194
# Fit and evaluate the model, including the tensorvue callback:

examples/example.py

Lines changed: 0 additions & 41 deletions
This file was deleted.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Create Alerts.
2+
3+
Function for creating alerts based on definitions provided by the user.
4+
5+
"""
6+
7+
import typing
8+
9+
import simvue
10+
11+
12+
def create_alerts(
    alert_name: str, alert_definition: dict[str, typing.Any], run: simvue.Run
) -> None:
    """Create a single alert on a run from its definition provided to TensorVue.

    The ``source`` entry of the definition selects which ``Run`` method is
    called; every other entry is forwarded to that method as a keyword argument.

    Parameters
    ----------
    alert_name : str
        Name of the alert to create
    alert_definition : dict[str, typing.Any]
        Definition of the alert. Must contain a ``source`` entry; the remaining
        entries are passed into the relevant Run method as kwargs. The caller's
        dictionary is left unmodified.
    run : simvue.Run
        The run to add the alert to

    Raises
    ------
    KeyError
        Raised if the alert definition has no ``source`` entry
    RuntimeError
        Raised if a valid source could not be deduced from alert definition

    """
    # Work on a shallow copy so the same definition can be reused for other runs
    _kwargs = dict(alert_definition)
    _source = _kwargs.pop("source")

    if _source == "events":
        _create = run.create_event_alert
    elif _source == "metrics":
        # Compare against None rather than relying on truthiness, so that a
        # threshold of 0 still selects the threshold alert
        if _kwargs.get("threshold") is not None:
            _create = run.create_metric_threshold_alert
        else:
            _create = run.create_metric_range_alert
    elif _source == "user":
        _create = run.create_user_alert
    else:
        raise RuntimeError(f"{alert_name} has unknown source type '{_source}'")

    _create(name=alert_name, **_kwargs)

simvue_tensorflow/extras/operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
"""Common.
1+
"""Operators.
22
3-
Classses which could be used in the construction of multiple adapters.
3+
Operators used in the definition of a TensorVue instance
44
"""
55
# ruff: noqa: DOC201
66

simvue_tensorflow/plugin.py

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from tensorflow.keras.callbacks import Callback
1818

1919
import simvue_tensorflow.extras.operators as operators
20+
from simvue_tensorflow.extras.create_alerts import create_alerts
2021

2122

2223
class TensorVue(Callback):
@@ -110,8 +111,6 @@ def __init__(
110111
Raised if the ML Optimisation framework is not enabled and no run name was provided
111112
KeyError
112113
Raised if attempted to add an alert to a run which was not defined
113-
RuntimeError
114-
Raised if a valid source could not be deduced from alert definition
115114
116115
"""
117116
if not optimisation_framework and not run_name:
@@ -125,6 +124,7 @@ def __init__(
125124
self.run_tags = run_tags or []
126125
self.run_metadata = run_metadata or {}
127126
self.run_mode = run_mode
127+
self.alert_definitions = alert_definitions
128128
self.script_filepath = script_filepath
129129
self.model_checkpoint_filepath = model_checkpoint_filepath
130130
self.model_final_filepath = model_final_filepath
@@ -136,32 +136,14 @@ def __init__(
136136
self.simulation_run = simulation_run
137137
self.eval_run = evaluation_run
138138

139-
self.alerts = {}
139+
# Create alerts in a disabled run up front, to validate that they have been defined accurately
140140
if alert_definitions:
141141
with simvue.Run(mode="disabled") as temp_run:
142-
for alert_name, alert_definition in alert_definitions.items():
143-
_source = alert_definition.pop("source")
144-
if _source == "events":
145-
_alert_id = temp_run.create_event_alert(
146-
name=alert_name, **alert_definition, attach_to_run=False
147-
)
148-
elif _source == "metrics" and alert_definition.get("threshold"):
149-
_alert_id = temp_run.create_metric_threshold_alert(
150-
name=alert_name, **alert_definition, attach_to_run=False
151-
)
152-
elif _source == "metrics":
153-
_alert_id = temp_run.create_metric_range_alert(
154-
name=alert_name, **alert_definition, attach_to_run=False
155-
)
156-
elif _source == "user":
157-
_alert_id = temp_run.create_user_alert(
158-
name=alert_name, **alert_definition, attach_to_run=False
159-
)
160-
else:
161-
raise RuntimeError(
162-
f"{alert_name} has unknown source type '{_source}'"
163-
)
164-
self.alerts[alert_name] = _alert_id
142+
temp_run._user_config.run.mode == self.run_mode
143+
[
144+
create_alerts(alert_name, alert_definition, temp_run)
145+
for alert_name, alert_definition in alert_definitions.items()
146+
]
165147

166148
self.manifest_alerts = manifest_alerts or []
167149
self.simulation_alerts = simulation_alerts or []
@@ -175,7 +157,7 @@ def __init__(
175157
+ self.evaluation_alerts
176158
+ self.manifest_alerts
177159
):
178-
if alert_name not in self.alerts.keys():
160+
if alert_name not in self.alert_definitions.keys():
179161
raise KeyError(
180162
f"Alert name {alert_name} not present in alert definitions."
181163
)
@@ -202,10 +184,10 @@ def create_manifest_run(self) -> simvue.Run:
202184
description=self.run_description,
203185
metadata=self.run_metadata,
204186
)
205-
if self.manifest_alerts:
206-
manifest_run.add_alerts(
207-
[self.alerts[alert_name] for alert_name in self.manifest_alerts]
208-
)
187+
[
188+
create_alerts(alert_name, self.alert_definitions[alert_name], manifest_run)
189+
for alert_name in self.manifest_alerts
190+
]
209191

210192
if self.script_filepath:
211193
manifest_run.save_file(
@@ -262,10 +244,12 @@ def on_train_begin(self, logs: dict):
262244

263245
self.simulation_run.update_metadata(self.params)
264246

265-
if self.simulation_alerts:
266-
self.simulation_run.add_alerts(
267-
[self.alerts[alert_name] for alert_name in self.simulation_alerts]
247+
[
248+
create_alerts(
249+
alert_name, self.alert_definitions[alert_name], self.simulation_run
268250
)
251+
for alert_name in self.simulation_alerts
252+
]
269253

270254
if self.script_filepath:
271255
self.simulation_run.save_file(
@@ -336,9 +320,12 @@ def on_epoch_begin(self, epoch: int, logs: dict) -> None:
336320
)
337321

338322
if self.epoch_alerts and epoch + 1 >= self.start_alerts_from_epoch:
339-
self.epoch_run.add_alerts(
340-
[self.alerts[alert_name] for alert_name in self.epoch_alerts]
341-
)
323+
[
324+
create_alerts(
325+
alert_name, self.alert_definitions[alert_name], self.epoch_run
326+
)
327+
for alert_name in self.epoch_alerts
328+
]
342329

343330
if epoch > 0:
344331
self.epoch_run.log_event("Accuracy and Loss values before epoch training:")
@@ -534,10 +521,12 @@ def on_test_begin(self, logs: dict):
534521
"evaluation",
535522
]
536523
)
537-
if self.evaluation_alerts:
538-
self.eval_run.add_alerts(
539-
[self.alerts[alert_name] for alert_name in self.evaluation_alerts]
524+
[
525+
create_alerts(
526+
alert_name, self.alert_definitions[alert_name], self.eval_run
540527
)
528+
for alert_name in self.evaluation_alerts
529+
]
541530

542531
if self.script_filepath:
543532
self.eval_run.save_file(

tests/integration/test_tensorflow.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
@pytest.mark.parametrize("offline", (True, False), ids=("offline", "online"))
99
def test_tensorflow_connector(folder_setup, offline):
1010

11-
run_name = tensorflow_example(folder_setup)
12-
11+
run_name = tensorflow_example(folder_setup, offline)
1312
if offline:
1413
_id_mapping = sender()
1514

0 commit comments

Comments
 (0)