@@ -26,12 +26,15 @@ def _read_yaml(filename):
26
26
"tag:yaml.org,2002:float" ,
27
27
re .compile (
28
28
"""^(?:
29
- [-+]?(?:[0-9][0-9_]*)\\ .[0-9_]*(?:[eE][-+]?[0-9]+)?
30
- |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
31
- |\\ .[0-9_]+(?:[eE][-+][0-9]+)?
32
- |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\ .[0-9_]*
33
- |[-+]?\\ .(?:inf|Inf|INF)
34
- |\\ .(?:nan|NaN|NAN))$""" ,
29
+
30
+ [-+]?(?:[0-9][0-9_]*)\\ .[0-9_]*(?:[eE][-+]?[0-9]+)?
31
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
32
+ |\\ .[0-9_]+(?:[eE][-+][0-9]+)?
33
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\ .[0-9_]*
34
+ |[-+]?\\ .(?:inf|Inf|INF)
35
+ |\\ .(?:nan|NaN|NAN))$
36
+
37
+ """ ,
35
38
re .X ,
36
39
),
37
40
list ("-+0123456789." ),
@@ -192,9 +195,9 @@ class DataConfig:
192
195
)
193
196
194
197
def __post_init__ (self ):
195
- assert (
196
- len ( self . categorical_cols ) + len ( self . continuous_cols ) + len ( self . date_columns ) > 0
197
- ), "There should be at-least one feature defined in categorical, continuous, or date columns"
198
+ assert len ( self . categorical_cols ) + len ( self . continuous_cols ) + len ( self . date_columns ) > 0 , (
199
+ "There should be at-least one feature defined in categorical, continuous, or date columns"
200
+ )
198
201
_validate_choices (self )
199
202
if os .name == "nt" and self .num_workers != 0 :
200
203
print ("Windows does not support num_workers > 0. Setting num_workers to 0" )
@@ -255,9 +258,9 @@ class InferredConfig:
255
258
256
259
def __post_init__ (self ):
257
260
if self .embedding_dims is not None :
258
- assert all (
259
- ( isinstance ( t , Iterable ) and len ( t ) == 2 ) for t in self . embedding_dims
260
- ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
261
+ assert all (( isinstance ( t , Iterable ) and len ( t ) == 2 ) for t in self . embedding_dims ), (
262
+ "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
263
+ )
261
264
self .embedded_cat_dim = sum ([t [1 ] for t in self .embedding_dims ])
262
265
else :
263
266
self .embedded_cat_dim = 0
@@ -581,24 +584,25 @@ def __post_init__(self):
581
584
582
585
@dataclass
583
586
class ExperimentConfig :
584
- """Experiment configuration. Experiment Tracking with WandB and Tensorboard.
587
+ """Experiment configuration.
585
588
586
- Args:
587
- project_name (str): The name of the project under which all runs will be logged. For Tensorboard
588
- this defines the folder under which the logs will be saved and for W&B it defines the project name
589
+ Experiment Tracking with WandB and Tensorboard.
590
+ Args:
591
+ project_name (str): The name of the project under which all runs will be logged. For Tensorboard
592
+ this defines the folder under which the logs will be saved and for W&B it defines the project name
589
593
590
- run_name (Optional[str]): The name of the run; a specific identifier to recognize the run. If left
591
- blank, will be assigned an auto-generated name
594
+ run_name (Optional[str]): The name of the run; a specific identifier to recognize the run. If left
595
+ blank, will be assigned an auto-generated name
592
596
593
- exp_watch (Optional[str]): The level of logging required. Can be `gradients`, `parameters`, `all`
594
- or `None`. Defaults to None. Choices are: [`gradients`,`parameters`,`all`,`None`].
597
+ exp_watch (Optional[str]): The level of logging required. Can be `gradients`, `parameters`, `all`
598
+ or `None`. Defaults to None. Choices are: [`gradients`,`parameters`,`all`,`None`].
595
599
596
- log_target (str): Determines where logging happens - Tensorboard or W&B. Choices are:
597
- [`wandb`,`tensorboard`].
600
+ log_target (str): Determines where logging happens - Tensorboard or W&B. Choices are:
601
+ [`wandb`,`tensorboard`].
598
602
599
- log_logits (bool): Turn this on to log the logits as a histogram in W&B
603
+ log_logits (bool): Turn this on to log the logits as a histogram in W&B
600
604
601
- exp_log_freq (int): step count between logging of gradients and parameters.
605
+ exp_log_freq (int): step count between logging of gradients and parameters.
602
606
603
607
"""
604
608
@@ -730,8 +734,8 @@ def __init__(
730
734
self ,
731
735
exp_version_manager : str = ".pt_tmp/exp_version_manager.yml" ,
732
736
) -> None :
733
- """The manages the versions of the experiments based on the name. It is a simple dictionary(yaml) based lookup.
734
- Primary purpose is to avoid overwriting of saved models while running the training without changing the
737
+ """The manages the versions of the experiments based on the name. Primary purpose is to avoid overwriting of
738
+ saved models while running the training without changing the It is a simple dictionary(yaml) based lookup.
735
739
experiment name.
736
740
737
741
Args:
0 commit comments