10
10
import edsnlp
11
11
12
12
13
+ def flatten_dict (d , path = "" ):
14
+ if not isinstance (d , (list , dict )):
15
+ return {path : d }
16
+
17
+ if isinstance (d , list ):
18
+ items = enumerate (d )
19
+ else :
20
+ items = d .items ()
21
+
22
+ return {
23
+ k : v
24
+ for key , val in items
25
+ for k , v in flatten_dict (val , f"{ path } /{ key } " if path else key ).items ()
26
+ }
27
+
28
+
13
29
@edsnlp .registry .loggers .register ("csv" , auto_draft_in_config = True )
14
30
class CSVLogger (accelerate .tracking .GeneralTracker ):
15
31
name = "csv"
@@ -53,7 +69,7 @@ def __init__(
53
69
self ._has_header = False
54
70
55
71
@property
56
- def tracker (self ):
72
+ def tracker (self ): # pragma: no cover
57
73
return None
58
74
59
75
@accelerate .tracking .on_main_process
@@ -70,11 +86,9 @@ def log(self, values: Dict[str, Any], step: Optional[int] = None):
70
86
- All subsequent calls must use the same columns. Any missing columns get
71
87
written as empty, any new columns generate a warning.
72
88
"""
73
- # Ensure we have columns set
74
- print ( "LOGGING TO CSV" , self . file_path )
89
+ values = flatten_dict ( values )
90
+
75
91
if self ._columns is None :
76
- # Create the list of columns. We'll always reserve "step" as first if step
77
- # is provided.
78
92
self ._columns = list ({** {"step" : None }, ** values }.keys ())
79
93
self ._writer .writerow (self ._columns )
80
94
self ._has_header = True
@@ -142,7 +156,7 @@ def __init__(
142
156
self ._logs = []
143
157
144
158
@property
145
- def tracker (self ):
159
+ def tracker (self ): # pragma: no cover
146
160
return None
147
161
148
162
@accelerate .tracking .on_main_process
@@ -267,8 +281,7 @@ def log(self, values: Dict[str, Any], step: Optional[int] = None):
267
281
Logs values in the Rich table. If `step` is provided, we include it in the
268
282
logged data.
269
283
"""
270
- print ("LOGGING WITH RICH" )
271
- combined = {"step" : step , ** values }
284
+ combined = {"step" : step , ** flatten_dict (values )}
272
285
self .printer .log_metrics (combined )
273
286
274
287
@accelerate .tracking .on_main_process
@@ -280,36 +293,40 @@ def finish(self):
280
293
281
294
282
295
@edsnlp .registry .loggers .register ("tensorboard" , auto_draft_in_config = True )
283
- def TensorBoardLogger (
284
- project_name : str ,
285
- logging_dir : Optional [Union [str , os .PathLike ]] = None ,
286
- ** kwargs ,
287
- ) -> "accelerate.tracking.TensorBoardTracker" : # pragma: no cover
288
- """
289
- Logger for [TensorBoard](https://github.com/tensorflow/tensorboard).
290
- This logger is also available via the loggers registry as `tensorboard`.
296
+ class TensorBoardLogger (accelerate .tracking .TensorBoardTracker ):
297
+ def __init__ (
298
+ self ,
299
+ project_name : str ,
300
+ logging_dir : Optional [Union [str , os .PathLike ]] = None ,
301
+ ):
302
+ """
303
+ Logger for [TensorBoard](https://github.com/tensorflow/tensorboard).
304
+ This logger is also available via the loggers registry as `tensorboard`.
291
305
292
- Parameters
293
- ----------
294
- project_name: str
295
- Name of the project.
296
- logging_dir: Union[str, os.PathLike]
297
- Directory in which to store the TensorBoard logs. Logs of different runs
298
- will be stored in `logging_dir/project_name`. If not provided, the
299
- environment variable `TENSORBOARD_LOGGING_DIR` will be used.
300
- kwargs: Dict
301
- Additional keyword arguments to pass to `tensorboard.SummaryWriter`.
306
+ Parameters
307
+ ----------
308
+ project_name: str
309
+ Name of the project.
310
+ logging_dir: Union[str, os.PathLike]
311
+ Directory in which to store the TensorBoard logs. Logs of different runs
312
+ will be stored in `logging_dir/project_name`. If not provided, the
313
+ environment variable `TENSORBOARD_LOGGING_DIR` will be used.
314
+ kwargs: Dict
315
+ Additional keyword arguments to pass to `tensorboard.SummaryWriter`.
316
+ """
317
+ logging_dir = logging_dir or os .environ .get ("TENSORBOARD_LOGGING_DIR" , None )
318
+ assert logging_dir is not None , (
319
+ "Please provide a logging directory or set TENSORBOARD_LOGGING_DIR"
320
+ )
321
+ super ().__init__ (project_name , logging_dir )
302
322
303
- Returns
304
- -------
305
- accelerate.tracking.TensorBoardTracker
306
- """
307
- logging_dir = logging_dir or os .environ .get ("TENSORBOARD_LOGGING_DIR" , None )
308
- assert logging_dir is not None , (
309
- "Please provide a logging directory or set TENSORBOARD_LOGGING_DIR"
310
- )
323
+ def store_init_configuration (self , values : Dict [str , Any ]):
324
+ values = json .loads (json .dumps (flatten_dict (values ), default = str ))
325
+ return super ().store_init_configuration (values )
311
326
312
- return accelerate .tracking .TensorBoardTracker (project_name , logging_dir , ** kwargs )
327
+ def log (self , values : dict , step : Optional [int ] = None , ** kwargs ):
328
+ values = flatten_dict (values )
329
+ return super ().log (values , step , ** kwargs )
313
330
314
331
315
332
@edsnlp .registry .loggers .register ("aim" , auto_draft_in_config = True )
@@ -365,31 +382,6 @@ def WandBLogger(
365
382
return accelerate .tracking .WandBTracker (project_name , ** kwargs )
366
383
367
384
368
- @edsnlp .registry .loggers .register ("clearml" , auto_draft_in_config = True )
369
- def ClearMLLogger (
370
- project_name : str ,
371
- ** kwargs ,
372
- ) -> "accelerate.tracking.ClearMLTracker" : # pragma: no cover
373
- """
374
- Logger for
375
- [ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps/).
376
- This logger is also available via the loggers registry as `clearml`.
377
-
378
- Parameters
379
- ----------
380
- project_name: str
381
- Name of the experiment. Environment variables `CLEARML_PROJECT` and
382
- `CLEARML_TASK` have priority over this argument.
383
- kwargs: Dict
384
- Additional keyword arguments to pass to the ClearML Task object.
385
-
386
- Returns
387
- -------
388
- accelerate.tracking.ClearMLTracker
389
- """
390
- return accelerate .tracking .ClearMLTracker (project_name , ** kwargs )
391
-
392
-
393
385
@edsnlp .registry .loggers .register ("mlflow" , auto_draft_in_config = True )
394
386
def MLflowLogger (
395
387
project_name : str ,
@@ -471,24 +463,63 @@ def CometMLLogger(
471
463
return accelerate .tracking .CometMLTracker (project_name , ** kwargs )
472
464
473
465
474
- @edsnlp .registry .loggers .register ("dvclive" , auto_draft_in_config = True )
475
- def DVCLiveLogger (
476
- live : Any = None ,
477
- ** kwargs ,
478
- ) -> "accelerate.tracking.DVCLiveTracker" :
479
- """
480
- Logger for [DVC Live](https://dvc.org/doc/dvclive).
481
- This logger is also available via the loggers registry as `dvclive`.
466
+ try :
467
+ from accelerate .tracking import ClearMLTracker as _ClearMLTracker
482
468
483
- Parameters
484
- ----------
485
- live: dvclive.Live
486
- An instance of `dvclive.Live` to use for logging.
487
- kwargs: Dict
488
- Additional keyword arguments to pass to the `dvclive.Live` constructor.
469
+ @edsnlp .registry .loggers .register ("clearml" , auto_draft_in_config = True )
470
+ def ClearMLLogger (
471
+ project_name : str ,
472
+ ** kwargs ,
473
+ ) -> "accelerate.tracking.ClearMLTracker" : # pragma: no cover
474
+ """
475
+ Logger for
476
+ [ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps/).
477
+ This logger is also available via the loggers registry as `clearml`.
489
478
490
- Returns
491
- -------
492
- accelerate.tracking.DVCLiveTracker
493
- """
494
- return accelerate .tracking .DVCLiveTracker (None , live = live , ** kwargs )
479
+ Parameters
480
+ ----------
481
+ project_name: str
482
+ Name of the experiment. Environment variables `CLEARML_PROJECT` and
483
+ `CLEARML_TASK` have priority over this argument.
484
+ kwargs: Dict
485
+ Additional keyword arguments to pass to the ClearML Task object.
486
+
487
+ Returns
488
+ -------
489
+ accelerate.tracking.ClearMLTracker
490
+ """
491
+ return _ClearMLTracker (project_name , ** kwargs )
492
+ except ImportError : # pragma: no cover
493
+
494
+ def ClearMLLogger (* args , ** kwargs ):
495
+ raise ImportError ("ClearMLLogger is not available." )
496
+
497
+
498
+ try :
499
+ from accelerate .tracking import DVCLiveTracker as _DVCLiveTracker
500
+
501
+ @edsnlp .registry .loggers .register ("dvclive" , auto_draft_in_config = True )
502
+ def DVCLiveLogger (
503
+ live : Any = None ,
504
+ ** kwargs ,
505
+ ) -> "accelerate.tracking.DVCLiveTracker" : # pragma: no cover
506
+ """
507
+ Logger for [DVC Live](https://dvc.org/doc/dvclive).
508
+ This logger is also available via the loggers registry as `dvclive`.
509
+
510
+ Parameters
511
+ ----------
512
+ live: dvclive.Live
513
+ An instance of `dvclive.Live` to use for logging.
514
+ kwargs: Dict
515
+ Additional keyword arguments to pass to the `dvclive.Live` constructor.
516
+
517
+ Returns
518
+ -------
519
+ accelerate.tracking.DVCLiveTracker
520
+ """
521
+ return _DVCLiveTracker (None , live = live , ** kwargs )
522
+ except ImportError : # pragma: no cover
523
+
524
+ def DVCLiveLogger (* args , ** kwargs ):
525
+ raise ImportError ("DVCLiveLogger is not available." )
0 commit comments