17
17
from tensorflow .keras .callbacks import Callback
18
18
19
19
import simvue_tensorflow .extras .operators as operators
20
+ from simvue_tensorflow .extras .create_alerts import create_alerts
20
21
21
22
22
23
class TensorVue (Callback ):
@@ -110,8 +111,6 @@ def __init__(
110
111
Raised if the ML Optimisation framework is not enabled and no run name was provided
111
112
KeyError
112
113
Raised if attempted to add an alert to a run which was not defined
113
- RuntimeError
114
- Raised if a valid source could not be deduced from alert definition
115
114
116
115
"""
117
116
if not optimisation_framework and not run_name :
@@ -125,6 +124,7 @@ def __init__(
125
124
self .run_tags = run_tags or []
126
125
self .run_metadata = run_metadata or {}
127
126
self .run_mode = run_mode
127
+ self .alert_definitions = alert_definitions
128
128
self .script_filepath = script_filepath
129
129
self .model_checkpoint_filepath = model_checkpoint_filepath
130
130
self .model_final_filepath = model_final_filepath
@@ -136,32 +136,14 @@ def __init__(
136
136
self .simulation_run = simulation_run
137
137
self .eval_run = evaluation_run
138
138
139
- self . alerts = {}
139
+ # Create alerts in a disabled run up front, to validate that they have been defined accurately
140
140
if alert_definitions :
141
141
with simvue .Run (mode = "disabled" ) as temp_run :
142
- for alert_name , alert_definition in alert_definitions .items ():
143
- _source = alert_definition .pop ("source" )
144
- if _source == "events" :
145
- _alert_id = temp_run .create_event_alert (
146
- name = alert_name , ** alert_definition , attach_to_run = False
147
- )
148
- elif _source == "metrics" and alert_definition .get ("threshold" ):
149
- _alert_id = temp_run .create_metric_threshold_alert (
150
- name = alert_name , ** alert_definition , attach_to_run = False
151
- )
152
- elif _source == "metrics" :
153
- _alert_id = temp_run .create_metric_range_alert (
154
- name = alert_name , ** alert_definition , attach_to_run = False
155
- )
156
- elif _source == "user" :
157
- _alert_id = temp_run .create_user_alert (
158
- name = alert_name , ** alert_definition , attach_to_run = False
159
- )
160
- else :
161
- raise RuntimeError (
162
- f"{ alert_name } has unknown source type '{ _source } '"
163
- )
164
- self .alerts [alert_name ] = _alert_id
142
+ temp_run ._user_config .run .mode == self .run_mode
143
+ [
144
+ create_alerts (alert_name , alert_definition , temp_run )
145
+ for alert_name , alert_definition in alert_definitions .items ()
146
+ ]
165
147
166
148
self .manifest_alerts = manifest_alerts or []
167
149
self .simulation_alerts = simulation_alerts or []
@@ -175,7 +157,7 @@ def __init__(
175
157
+ self .evaluation_alerts
176
158
+ self .manifest_alerts
177
159
):
178
- if alert_name not in self .alerts .keys ():
160
+ if alert_name not in self .alert_definitions .keys ():
179
161
raise KeyError (
180
162
f"Alert name { alert_name } not present in alert definitions."
181
163
)
@@ -202,10 +184,10 @@ def create_manifest_run(self) -> simvue.Run:
202
184
description = self .run_description ,
203
185
metadata = self .run_metadata ,
204
186
)
205
- if self . manifest_alerts :
206
- manifest_run . add_alerts (
207
- [ self . alerts [ alert_name ] for alert_name in self .manifest_alerts ]
208
- )
187
+ [
188
+ create_alerts ( alert_name , self . alert_definitions [ alert_name ], manifest_run )
189
+ for alert_name in self .manifest_alerts
190
+ ]
209
191
210
192
if self .script_filepath :
211
193
manifest_run .save_file (
@@ -262,10 +244,12 @@ def on_train_begin(self, logs: dict):
262
244
263
245
self .simulation_run .update_metadata (self .params )
264
246
265
- if self . simulation_alerts :
266
- self . simulation_run . add_alerts (
267
- [ self .alerts [alert_name ] for alert_name in self .simulation_alerts ]
247
+ [
248
+ create_alerts (
249
+ alert_name , self .alert_definitions [alert_name ], self .simulation_run
268
250
)
251
+ for alert_name in self .simulation_alerts
252
+ ]
269
253
270
254
if self .script_filepath :
271
255
self .simulation_run .save_file (
@@ -336,9 +320,12 @@ def on_epoch_begin(self, epoch: int, logs: dict) -> None:
336
320
)
337
321
338
322
if self .epoch_alerts and epoch + 1 >= self .start_alerts_from_epoch :
339
- self .epoch_run .add_alerts (
340
- [self .alerts [alert_name ] for alert_name in self .epoch_alerts ]
341
- )
323
+ [
324
+ create_alerts (
325
+ alert_name , self .alert_definitions [alert_name ], self .epoch_run
326
+ )
327
+ for alert_name in self .epoch_alerts
328
+ ]
342
329
343
330
if epoch > 0 :
344
331
self .epoch_run .log_event ("Accuracy and Loss values before epoch training:" )
@@ -534,10 +521,12 @@ def on_test_begin(self, logs: dict):
534
521
"evaluation" ,
535
522
]
536
523
)
537
- if self . evaluation_alerts :
538
- self . eval_run . add_alerts (
539
- [ self .alerts [alert_name ] for alert_name in self .evaluation_alerts ]
524
+ [
525
+ create_alerts (
526
+ alert_name , self .alert_definitions [alert_name ], self .eval_run
540
527
)
528
+ for alert_name in self .evaluation_alerts
529
+ ]
541
530
542
531
if self .script_filepath :
543
532
self .eval_run .save_file (
0 commit comments