Skip to content

Commit 184f1bb

Browse files
committed
feat: added various training loggers using confit.Draft
1 parent 30ec72a commit 184f1bb

File tree

4 files changed

+155
-98
lines changed

4 files changed

+155
-98
lines changed

edsnlp/training/loggers.py

Lines changed: 110 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,22 @@
1010
import edsnlp
1111

1212

13+
def flatten_dict(d, path=""):
    """
    Flatten nested dicts/lists into a single-level dict with "/"-joined keys.

    Scalars (anything that is neither a list nor a dict) are returned as a
    single-entry dict mapping the accumulated path to the value. List items
    are keyed by their integer index.

    Parameters
    ----------
    d: Any
        The (possibly nested) structure to flatten.
    path: str
        Prefix accumulated during recursion; leave empty at the top level.

    Returns
    -------
    Dict
        A flat mapping from "/"-separated paths to leaf values.
    """
    if not isinstance(d, (list, dict)):
        return {path: d}

    # Lists are treated like dicts keyed by position.
    pairs = enumerate(d) if isinstance(d, list) else d.items()

    flat = {}
    for key, val in pairs:
        # At the top level (empty path) the key is used as-is, so list
        # indices stay integers there — same behavior as the original.
        sub_path = f"{path}/{key}" if path else key
        flat.update(flatten_dict(val, sub_path))
    return flat
27+
28+
1329
@edsnlp.registry.loggers.register("csv", auto_draft_in_config=True)
1430
class CSVLogger(accelerate.tracking.GeneralTracker):
1531
name = "csv"
@@ -53,7 +69,7 @@ def __init__(
5369
self._has_header = False
5470

5571
@property
56-
def tracker(self):
72+
def tracker(self): # pragma: no cover
5773
return None
5874

5975
@accelerate.tracking.on_main_process
@@ -70,11 +86,9 @@ def log(self, values: Dict[str, Any], step: Optional[int] = None):
7086
- All subsequent calls must use the same columns. Any missing columns get
7187
written as empty, any new columns generate a warning.
7288
"""
73-
# Ensure we have columns set
74-
print("LOGGING TO CSV", self.file_path)
89+
values = flatten_dict(values)
90+
7591
if self._columns is None:
76-
# Create the list of columns. We'll always reserve "step" as first if step
77-
# is provided.
7892
self._columns = list({**{"step": None}, **values}.keys())
7993
self._writer.writerow(self._columns)
8094
self._has_header = True
@@ -142,7 +156,7 @@ def __init__(
142156
self._logs = []
143157

144158
@property
145-
def tracker(self):
159+
def tracker(self): # pragma: no cover
146160
return None
147161

148162
@accelerate.tracking.on_main_process
@@ -267,8 +281,7 @@ def log(self, values: Dict[str, Any], step: Optional[int] = None):
267281
Logs values in the Rich table. If `step` is provided, we include it in the
268282
logged data.
269283
"""
270-
print("LOGGING WITH RICH")
271-
combined = {"step": step, **values}
284+
combined = {"step": step, **flatten_dict(values)}
272285
self.printer.log_metrics(combined)
273286

274287
@accelerate.tracking.on_main_process
@@ -280,36 +293,40 @@ def finish(self):
280293

281294

282295
@edsnlp.registry.loggers.register("tensorboard", auto_draft_in_config=True)
283-
def TensorBoardLogger(
284-
project_name: str,
285-
logging_dir: Optional[Union[str, os.PathLike]] = None,
286-
**kwargs,
287-
) -> "accelerate.tracking.TensorBoardTracker": # pragma: no cover
288-
"""
289-
Logger for [TensorBoard](https://github.com/tensorflow/tensorboard).
290-
This logger is also available via the loggers registry as `tensorboard`.
296+
class TensorBoardLogger(accelerate.tracking.TensorBoardTracker):
297+
def __init__(
298+
self,
299+
project_name: str,
300+
logging_dir: Optional[Union[str, os.PathLike]] = None,
301+
):
302+
"""
303+
Logger for [TensorBoard](https://github.com/tensorflow/tensorboard).
304+
This logger is also available via the loggers registry as `tensorboard`.
291305
292-
Parameters
293-
----------
294-
project_name: str
295-
Name of the project.
296-
logging_dir: Union[str, os.PathLike]
297-
Directory in which to store the TensorBoard logs. Logs of different runs
298-
will be stored in `logging_dir/project_name`. If not provided, the
299-
environment variable `TENSORBOARD_LOGGING_DIR` will be used.
300-
kwargs: Dict
301-
Additional keyword arguments to pass to `tensorboard.SummaryWriter`.
306+
Parameters
307+
----------
308+
project_name: str
309+
Name of the project.
310+
logging_dir: Union[str, os.PathLike]
311+
Directory in which to store the TensorBoard logs. Logs of different runs
312+
will be stored in `logging_dir/project_name`. If not provided, the
313+
environment variable `TENSORBOARD_LOGGING_DIR` will be used.
314+
kwargs: Dict
315+
Additional keyword arguments to pass to `tensorboard.SummaryWriter`.
316+
"""
317+
logging_dir = logging_dir or os.environ.get("TENSORBOARD_LOGGING_DIR", None)
318+
assert logging_dir is not None, (
319+
"Please provide a logging directory or set TENSORBOARD_LOGGING_DIR"
320+
)
321+
super().__init__(project_name, logging_dir)
302322

303-
Returns
304-
-------
305-
accelerate.tracking.TensorBoardTracker
306-
"""
307-
logging_dir = logging_dir or os.environ.get("TENSORBOARD_LOGGING_DIR", None)
308-
assert logging_dir is not None, (
309-
"Please provide a logging directory or set TENSORBOARD_LOGGING_DIR"
310-
)
323+
def store_init_configuration(self, values: Dict[str, Any]):
324+
values = json.loads(json.dumps(flatten_dict(values), default=str))
325+
return super().store_init_configuration(values)
311326

312-
return accelerate.tracking.TensorBoardTracker(project_name, logging_dir, **kwargs)
327+
def log(self, values: dict, step: Optional[int] = None, **kwargs):
328+
values = flatten_dict(values)
329+
return super().log(values, step, **kwargs)
313330

314331

315332
@edsnlp.registry.loggers.register("aim", auto_draft_in_config=True)
@@ -365,31 +382,6 @@ def WandBLogger(
365382
return accelerate.tracking.WandBTracker(project_name, **kwargs)
366383

367384

368-
@edsnlp.registry.loggers.register("clearml", auto_draft_in_config=True)
369-
def ClearMLLogger(
370-
project_name: str,
371-
**kwargs,
372-
) -> "accelerate.tracking.ClearMLTracker": # pragma: no cover
373-
"""
374-
Logger for
375-
[ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps/).
376-
This logger is also available via the loggers registry as `clearml`.
377-
378-
Parameters
379-
----------
380-
project_name: str
381-
Name of the experiment. Environment variables `CLEARML_PROJECT` and
382-
`CLEARML_TASK` have priority over this argument.
383-
kwargs: Dict
384-
Additional keyword arguments to pass to the ClearML Task object.
385-
386-
Returns
387-
-------
388-
accelerate.tracking.ClearMLTracker
389-
"""
390-
return accelerate.tracking.ClearMLTracker(project_name, **kwargs)
391-
392-
393385
@edsnlp.registry.loggers.register("mlflow", auto_draft_in_config=True)
394386
def MLflowLogger(
395387
project_name: str,
@@ -471,24 +463,63 @@ def CometMLLogger(
471463
return accelerate.tracking.CometMLTracker(project_name, **kwargs)
472464

473465

474-
@edsnlp.registry.loggers.register("dvclive", auto_draft_in_config=True)
475-
def DVCLiveLogger(
476-
live: Any = None,
477-
**kwargs,
478-
) -> "accelerate.tracking.DVCLiveTracker":
479-
"""
480-
Logger for [DVC Live](https://dvc.org/doc/dvclive).
481-
This logger is also available via the loggers registry as `dvclive`.
466+
try:
467+
from accelerate.tracking import ClearMLTracker as _ClearMLTracker
482468

483-
Parameters
484-
----------
485-
live: dvclive.Live
486-
An instance of `dvclive.Live` to use for logging.
487-
kwargs: Dict
488-
Additional keyword arguments to pass to the `dvclive.Live` constructor.
469+
@edsnlp.registry.loggers.register("clearml", auto_draft_in_config=True)
470+
def ClearMLLogger(
471+
project_name: str,
472+
**kwargs,
473+
) -> "accelerate.tracking.ClearMLTracker": # pragma: no cover
474+
"""
475+
Logger for
476+
[ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps/).
477+
This logger is also available via the loggers registry as `clearml`.
489478
490-
Returns
491-
-------
492-
accelerate.tracking.DVCLiveTracker
493-
"""
494-
return accelerate.tracking.DVCLiveTracker(None, live=live, **kwargs)
479+
Parameters
480+
----------
481+
project_name: str
482+
Name of the experiment. Environment variables `CLEARML_PROJECT` and
483+
`CLEARML_TASK` have priority over this argument.
484+
kwargs: Dict
485+
Additional keyword arguments to pass to the ClearML Task object.
486+
487+
Returns
488+
-------
489+
accelerate.tracking.ClearMLTracker
490+
"""
491+
return _ClearMLTracker(project_name, **kwargs)
492+
except ImportError: # pragma: no cover
493+
494+
def ClearMLLogger(*args, **kwargs):
495+
raise ImportError("ClearMLLogger is not available.")
496+
497+
498+
try:
499+
from accelerate.tracking import DVCLiveTracker as _DVCLiveTracker
500+
501+
@edsnlp.registry.loggers.register("dvclive", auto_draft_in_config=True)
502+
def DVCLiveLogger(
503+
live: Any = None,
504+
**kwargs,
505+
) -> "accelerate.tracking.DVCLiveTracker": # pragma: no cover
506+
"""
507+
Logger for [DVC Live](https://dvc.org/doc/dvclive).
508+
This logger is also available via the loggers registry as `dvclive`.
509+
510+
Parameters
511+
----------
512+
live: dvclive.Live
513+
An instance of `dvclive.Live` to use for logging.
514+
kwargs: Dict
515+
Additional keyword arguments to pass to the `dvclive.Live` constructor.
516+
517+
Returns
518+
-------
519+
accelerate.tracking.DVCLiveTracker
520+
"""
521+
return _DVCLiveTracker(None, live=live, **kwargs)
522+
except ImportError: # pragma: no cover
523+
524+
def DVCLiveLogger(*args, **kwargs):
525+
raise ImportError("DVCLiveLogger is not available.")

edsnlp/training/trainer.py

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -75,17 +75,6 @@
7575
}
7676

7777

78-
def flatten_dict(d, path=""):
79-
if not isinstance(d, dict):
80-
return {path: d}
81-
82-
return {
83-
k: v
84-
for key, val in d.items()
85-
for k, v in flatten_dict(val, f"{path}/{key}" if path else key).items()
86-
}
87-
88-
8978
def fill_flat_stats(x, result, path=()):
9079
if result is None:
9180
result = {}
@@ -580,7 +569,6 @@ def train(
580569
logging_dir=output_dir,
581570
),
582571
)
583-
accelerator.init_trackers(project_name) # in theory project name shouldn't be used
584572
# accelerator.register_for_checkpointing(dataset)
585573
is_main_process = accelerator.is_main_process
586574
device = accelerator.device
@@ -593,15 +581,18 @@ def train(
593581

594582
output_dir = Path(output_dir or Path.cwd() / "artifacts")
595583
output_model_dir = Path(output_model_dir or output_dir / "model-last")
584+
unresolved_config = None
596585
if is_main_process:
597586
os.makedirs(output_dir, exist_ok=True)
598587
os.makedirs(output_model_dir, exist_ok=True)
599-
unresolved_config = {}
600588
if config_meta is not None: # pragma: no cover
601589
unresolved_config = config_meta["unresolved_config"]
602590
print(unresolved_config.to_yaml_str())
603591
unresolved_config.to_disk(output_dir / "train_config.yml")
604592
# TODO: handle config_meta is None
593+
accelerator.init_trackers(
594+
project_name, config=unresolved_config
595+
) # in theory project name shouldn't be used
605596

606597
validation_interval = validation_interval or max_steps // 10
607598
checkpoint_interval = checkpoint_interval or validation_interval
@@ -726,7 +717,7 @@ def train(
726717
**scores,
727718
}
728719
cumulated_data = defaultdict(lambda: 0, **default_metrics)
729-
accelerator.log(flatten_dict(metrics), step=step)
720+
accelerator.log(metrics, step=step)
730721

731722
if on_validation_callback:
732723
on_validation_callback(metrics)
@@ -787,7 +778,7 @@ def train(
787778
del all_res
788779
if isinstance(loss, torch.Tensor) and loss.requires_grad:
789780
accelerator.backward(loss)
790-
except torch.cuda.OutOfMemoryError:
781+
except torch.cuda.OutOfMemoryError: # pragma: no cover
791782
print(
792783
"Out of memory error encountered when processing a "
793784
"batch with the following statistics:"
@@ -858,4 +849,7 @@ def train(
858849

859850
del iterator
860851

852+
# Should we put this in a finally block?
853+
accelerator.end_training()
854+
861855
return nlp

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -296,7 +296,7 @@ where = ["."]
296296
"csv" = "edsnlp.training.loggers:CSVLogger"
297297
"json" = "edsnlp.training.loggers:JSONLogger"
298298
"rich" = "edsnlp.training.loggers:RichLogger"
299-
"tensorboard" = "edsnlp.training.loggers:TensorboardLogger"
299+
"tensorboard" = "edsnlp.training.loggers:TensorBoardLogger"
300300
"aim" = "edsnlp.training.loggers:AimLogger"
301301
"wandb" = "edsnlp.training.loggers:WandBLogger"
302302
"clearml" = "edsnlp.training.loggers:ClearMLLogger"

tests/training/test_train.py

Lines changed: 35 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,16 @@ def test_ner_qualif_train_diff_bert(run_in_test_dir, tmp_path):
9999
config = Config.from_disk("ner_qlf_diff_bert_config.yml")
100100
shutil.rmtree(tmp_path, ignore_errors=True)
101101
kwargs = Config.resolve(config["train"], registry=registry, root=config)
102-
nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
102+
nlp = train(
103+
**kwargs,
104+
output_dir=tmp_path,
105+
cpu=True,
106+
config_meta={
107+
"config_path": "dep_parser_config.yml",
108+
"resolved_config": kwargs,
109+
"unresolved_config": config,
110+
},
111+
)
103112
scorer = GenericScorer(**kwargs["scorer"])
104113
val_data = kwargs["val_data"]
105114
last_scores = scorer(nlp, val_data)
@@ -120,7 +129,16 @@ def test_ner_qualif_train_same_bert(run_in_test_dir, tmp_path):
120129
config = Config.from_disk("ner_qlf_same_bert_config.yml")
121130
shutil.rmtree(tmp_path, ignore_errors=True)
122131
kwargs = Config.resolve(config["train"], registry=registry, root=config)
123-
nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
132+
nlp = train(
133+
**kwargs,
134+
output_dir=tmp_path,
135+
cpu=True,
136+
config_meta={
137+
"config_path": "dep_parser_config.yml",
138+
"resolved_config": kwargs,
139+
"unresolved_config": config,
140+
},
141+
)
124142
scorer = GenericScorer(**kwargs["scorer"])
125143
val_data = kwargs["val_data"]
126144
last_scores = scorer(nlp, val_data)
@@ -137,7 +155,16 @@ def test_qualif_train(run_in_test_dir, tmp_path):
137155
config = Config.from_disk("qlf_config.yml")
138156
shutil.rmtree(tmp_path, ignore_errors=True)
139157
kwargs = Config.resolve(config["train"], registry=registry, root=config)
140-
nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
158+
nlp = train(
159+
**kwargs,
160+
output_dir=tmp_path,
161+
cpu=True,
162+
config_meta={
163+
"config_path": "dep_parser_config.yml",
164+
"resolved_config": kwargs,
165+
"unresolved_config": config,
166+
},
167+
)
141168
scorer = GenericScorer(**kwargs["scorer"])
142169
val_data = kwargs["val_data"]
143170
last_scores = scorer(nlp, val_data)
@@ -158,6 +185,11 @@ def test_dep_parser_train(run_in_test_dir, tmp_path):
158185
logger=CSVLogger.draft(),
159186
output_dir=tmp_path,
160187
cpu=True,
188+
config_meta={
189+
"config_path": "dep_parser_config.yml",
190+
"resolved_config": kwargs,
191+
"unresolved_config": config,
192+
},
161193
)
162194
scorer = GenericScorer(**kwargs["scorer"])
163195
val_data = list(kwargs["val_data"])

0 commit comments

Comments
 (0)