
Commit 66ca3e1

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent bf600f7 commit 66ca3e1

25 files changed: +178 -187 lines changed

examples/__only_for_dev__/to_test_regression_custom_models.py

Lines changed: 9 additions & 9 deletions
@@ -83,15 +83,15 @@ class MultiStageModelConfig(ModelConfig):
     threshold_init_beta: float = field(
         default=1.0,
         metadata={
-            "help": """
-                Used in the Data-aware initialization of thresholds where the threshold is initialized randomly
-                (with a beta distribution) to feature values in the first batch.
-                It initializes threshold to a q-th quantile of data points.
-                where q ~ Beta(:threshold_init_beta:, :threshold_init_beta:)
-                If this param is set to 1, initial thresholds will have the same distribution as data points
-                If greater than 1 (e.g. 10), thresholds will be closer to median data value
-                If less than 1 (e.g. 0.1), thresholds will approach min/max data values.
-            """
+            "help": """Used in the Data-aware initialization of thresholds where the threshold is initialized randomly
+            (with a beta distribution) to feature values in the first batch.
+
+            It initializes threshold to a q-th quantile of data points. where q ~ Beta(:threshold_init_beta:,
+            :threshold_init_beta:) If this param is set to 1, initial thresholds will have the same distribution
+            as data points If greater than 1 (e.g. 10), thresholds will be closer to median data value If less
+            than 1 (e.g. 0.1), thresholds will approach min/max data values.
+
+            """
         },
     )
     threshold_init_cutoff: float = field(
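
For reference, the behavior the reflowed help text describes can be sketched in a few lines: draw a quantile level q from Beta(threshold_init_beta, threshold_init_beta) and set each threshold to that quantile of the first batch. A minimal illustration (the helper name and shapes are hypothetical, not the library's actual init code):

import torch

def data_aware_threshold_init(first_batch: torch.Tensor, threshold_init_beta: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: one threshold per feature column of `first_batch`.

    q ~ Beta(threshold_init_beta, threshold_init_beta); threshold = q-th quantile
    of that feature's values. beta = 1: thresholds distributed like the data;
    beta > 1: pulled toward the median; beta < 1: pushed toward min/max.
    """
    num_features = first_batch.shape[1]
    beta = torch.distributions.Beta(threshold_init_beta, threshold_init_beta)
    q = beta.sample((num_features,))  # one quantile level per feature
    return torch.stack(
        [torch.quantile(first_batch[:, i], q[i]) for i in range(num_features)]
    )

thresholds = data_aware_threshold_init(torch.randn(256, 8), threshold_init_beta=10.0)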

src/pytorch_tabular/categorical_encoders.py

Lines changed: 6 additions & 6 deletions
@@ -58,9 +58,9 @@ def transform(self, X):
             raise ValueError("`fit` method must be called before `transform`.")
         assert all(c in X.columns for c in self.cols)
         if self.handle_missing == "error":
-            assert (
-                not X[self.cols].isnull().any().any()
-            ), "`handle_missing` = `error` and missing values found in columns to encode."
+            assert not X[self.cols].isnull().any().any(), (
+                "`handle_missing` = `error` and missing values found in columns to encode."
+            )
         X_encoded = X.copy(deep=True)
         category_cols = X_encoded.select_dtypes(include="category").columns
         X_encoded[category_cols] = X_encoded[category_cols].astype("object")

@@ -153,9 +153,9 @@ def fit(self, X, y=None):
         """
         self._before_fit_check(X, y)
         if self.handle_missing == "error":
-            assert (
-                not X[self.cols].isnull().any().any()
-            ), "`handle_missing` = `error` and missing values found in columns to encode."
+            assert not X[self.cols].isnull().any().any(), (
+                "`handle_missing` = `error` and missing values found in columns to encode."
+            )
         for col in self.cols:
             map = Series(unique(X[col].fillna(NAN_CATEGORY)), name=col).reset_index().rename(columns={"index": "value"})
             map["value"] += 1

src/pytorch_tabular/config/config.py

Lines changed: 30 additions & 26 deletions
@@ -26,12 +26,15 @@ def _read_yaml(filename):
     "tag:yaml.org,2002:float",
     re.compile(
         """^(?:
-         [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
-        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
-        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
-        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
-        |[-+]?\\.(?:inf|Inf|INF)
-        |\\.(?:nan|NaN|NAN))$""",
+
+        [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+        |[-+]?\\.(?:inf|Inf|INF)
+        |\\.(?:nan|NaN|NAN))$
+
+        """,
         re.X,
     ),
     list("-+0123456789."),
@@ -192,9 +195,9 @@ class DataConfig:
     )

     def __post_init__(self):
-        assert (
-            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
-        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
+        assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
+            "There should be at-least one feature defined in categorical, continuous, or date columns"
+        )
         _validate_choices(self)
         if os.name == "nt" and self.num_workers != 0:
             print("Windows does not support num_workers > 0. Setting num_workers to 0")
@@ -255,9 +258,9 @@ class InferredConfig:

     def __post_init__(self):
         if self.embedding_dims is not None:
-            assert all(
-                (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
-            ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
+                "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            )
             self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
         else:
             self.embedded_cat_dim = 0
@@ -581,24 +584,25 @@ def __post_init__(self):

 @dataclass
 class ExperimentConfig:
-    """Experiment configuration. Experiment Tracking with WandB and Tensorboard.
+    """Experiment configuration.

-    Args:
-        project_name (str): The name of the project under which all runs will be logged. For Tensorboard
-            this defines the folder under which the logs will be saved and for W&B it defines the project name
+    Experiment Tracking with WandB and Tensorboard.
+    Args:
+        project_name (str): The name of the project under which all runs will be logged. For Tensorboard
+            this defines the folder under which the logs will be saved and for W&B it defines the project name

-        run_name (Optional[str]): The name of the run; a specific identifier to recognize the run. If left
-            blank, will be assigned an auto-generated name
+        run_name (Optional[str]): The name of the run; a specific identifier to recognize the run. If left
+            blank, will be assigned an auto-generated name

-        exp_watch (Optional[str]): The level of logging required. Can be `gradients`, `parameters`, `all`
-            or `None`. Defaults to None. Choices are: [`gradients`,`parameters`,`all`,`None`].
+        exp_watch (Optional[str]): The level of logging required. Can be `gradients`, `parameters`, `all`
+            or `None`. Defaults to None. Choices are: [`gradients`,`parameters`,`all`,`None`].

-        log_target (str): Determines where logging happens - Tensorboard or W&B. Choices are:
-            [`wandb`,`tensorboard`].
+        log_target (str): Determines where logging happens - Tensorboard or W&B. Choices are:
+            [`wandb`,`tensorboard`].

-        log_logits (bool): Turn this on to log the logits as a histogram in W&B
+        log_logits (bool): Turn this on to log the logits as a histogram in W&B

-        exp_log_freq (int): step count between logging of gradients and parameters.
+        exp_log_freq (int): step count between logging of gradients and parameters.

     """

@@ -730,8 +734,8 @@ def __init__(
         self,
         exp_version_manager: str = ".pt_tmp/exp_version_manager.yml",
     ) -> None:
-        """The manages the versions of the experiments based on the name. It is a simple dictionary(yaml) based lookup.
-        Primary purpose is to avoid overwriting of saved models while running the training without changing the
+        """The manages the versions of the experiments based on the name. Primary purpose is to avoid overwriting of
+        saved models while running the training without changing the It is a simple dictionary(yaml) based lookup.
        experiment name.

        Args:

src/pytorch_tabular/feature_extractor.py

Lines changed: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
             pd.DataFrame: The encoded dataframe

         """
-
         X_encoded = X.copy(deep=True)
         orig_features = X_encoded.columns
         self.tabular_model.model.eval()

src/pytorch_tabular/models/category_embedding/config.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ class CategoryEmbeddingModelConfig(ModelConfig):
     )
     use_batch_norm: bool = field(
         default=False,
-        metadata={"help": ("Flag to include a BatchNorm layer after each Linear Layer+DropOut." " Defaults to False")},
+        metadata={"help": ("Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False")},
     )
     initialization: str = field(
         default="kaiming",

src/pytorch_tabular/models/common/heads/config.py

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@

 @dataclass
 class LinearHeadConfig:
-    """A model class for Linear Head configuration; serves as a template and documentation. The models take a
-    dictionary as input, but if there are keys which are not present in this model class, it'll throw an exception.
+    """A model class for Linear Head configuration; serves as a template and documentation. dictionary as input, but if
+    there are keys which are not present in this model class, it'll throw an exception. The models take a.

     Args:
         layers (str): Hyphen-separated number of layers and units in the classification/regression head.

src/pytorch_tabular/models/common/layers/activations.py

Lines changed: 5 additions & 6 deletions
@@ -151,8 +151,8 @@ def _sparsemax_threshold_and_support(X, dim=-1, k=None):
         the threshold value for each vector
     support_size : torch LongTensor, shape like `tau`
         the number of nonzeros in each vector.
-    """

+    """
     if k is None or k >= X.shape[dim]:  # do full sort
         topk, _ = torch.sort(X, dim=dim, descending=True)
     else:

@@ -204,7 +204,6 @@ def _entmax_threshold_and_support(X, dim=-1, k=None):
         the number of nonzeros in each vector.

     """
-
     if k is None or k >= X.shape[dim]:  # do full sort
         Xsrt, _ = torch.sort(X, dim=dim, descending=True)
     else:

@@ -288,7 +287,7 @@ def backward(cls, ctx, dY):


 def sparsemax(X, dim=-1, k=None):
-    """sparsemax: normalizing sparse transform (a la softmax).
+    """Sparsemax: normalizing sparse transform (a la softmax).

     Solves the projection:

@@ -313,8 +312,8 @@ def sparsemax(X, dim=-1, k=None):
     -------
     P : torch tensor, same shape as X
         The projection result, such that P.sum(dim=dim) == 1 elementwise.
-    """

+    """
     return SparsemaxFunction.apply(X, dim, k)


@@ -347,13 +346,12 @@ def entmax15(X, dim=-1, k=None):
     P : torch tensor, same shape as X
         The projection result, such that P.sum(dim=dim) == 1 elementwise.
     """
-
     return Entmax15Function.apply(X, dim, k)


 class Sparsemax(nn.Module):
     def __init__(self, dim=-1, k=None):
-        """sparsemax: normalizing sparse transform (a la softmax).
+        """Sparsemax: normalizing sparse transform (a la softmax).

         Solves the projection:

@@ -370,6 +368,7 @@ def __init__(self, dim=-1, k=None):
             nonzeros in the solution. If the solution is more than k-sparse,
             this function is recursively called with a 2*k schedule.
             If `None`, full sorting is performed from the beginning.
+
         """
         self.dim = dim
         self.k = k
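
Both functions map scores to a sparse probability distribution over `dim`; a quick usage sketch, assuming the import path matches the file above:

import torch

from pytorch_tabular.models.common.layers.activations import entmax15, sparsemax

scores = torch.tensor([[2.0, 1.0, -1.0, -2.0]])
p = sparsemax(scores, dim=-1)  # assigns exact zeros to low-scoring entries
q = entmax15(scores, dim=-1)   # between softmax (dense) and sparsemax (sparsest)
print(p, p.sum(dim=-1))        # each row sums to 1
print(q, q.sum(dim=-1))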

src/pytorch_tabular/models/common/layers/embeddings.py

Lines changed: 18 additions & 18 deletions
@@ -84,12 +84,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert (
-            categorical_data.shape[1] == self.categorical_dim
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == self.categorical_dim, (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:

@@ -141,12 +141,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert categorical_data.shape[1] == len(
-            self.cat_embedding_layers
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == len(self.cat_embedding_layers), (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             if self.batch_norm_continuous_input:

@@ -273,12 +273,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
             x.get("continuous", torch.empty(0, 0)),
             x.get("categorical", torch.empty(0, 0)),
         )
-        assert categorical_data.shape[1] == len(
-            self.cat_embedding_layers
-        ), "categorical_data must have same number of columns as categorical embedding layers"
-        assert (
-            continuous_data.shape[1] == self.continuous_dim
-        ), "continuous_data must have same number of columns as continuous dim"
+        assert categorical_data.shape[1] == len(self.cat_embedding_layers), (
+            "categorical_data must have same number of columns as categorical embedding layers"
+        )
+        assert continuous_data.shape[1] == self.continuous_dim, (
+            "continuous_data must have same number of columns as continuous dim"
+        )
         embed = None
         if continuous_data.shape[1] > 0:
             cont_idx = torch.arange(self.continuous_dim, device=continuous_data.device).expand(
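
All three refactored hunks enforce the same dict input contract on `forward`; a self-contained sketch of that contract (batch size and dims are illustrative):

import torch

continuous_dim, categorical_dim = 5, 3  # illustrative sizes
batch = {
    "continuous": torch.randn(32, continuous_dim),
    "categorical": torch.randint(0, 10, (32, categorical_dim)),
}

# Mirrors the guards in forward(): missing keys default to empty tensors,
# then column counts are checked against the configured dims.
continuous_data = batch.get("continuous", torch.empty(0, 0))
categorical_data = batch.get("categorical", torch.empty(0, 0))
assert categorical_data.shape[1] == categorical_dim, (
    "categorical_data must have same number of columns as categorical embedding layers"
)
assert continuous_data.shape[1] == continuous_dim, (
    "continuous_data must have same number of columns as continuous dim"
)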

src/pytorch_tabular/models/gate/config.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def __post_init__(self):
         assert self.tree_depth > 0, "tree_depth should be greater than 0"
         # Either gflu_stages or num_trees should be greater than 0
         assert self.num_trees > 0, (
-            "`num_trees` must be greater than 0." "If you want a lighter model which performs better, use GANDALF."
+            "`num_trees` must be greater than 0.If you want a lighter model which performs better, use GANDALF."
         )
         super().__post_init__()

src/pytorch_tabular/models/gate/gate_model.py

Lines changed: 6 additions & 6 deletions
@@ -51,12 +51,12 @@ def __init__(
         embedding_dropout: float = 0.0,
     ):
         super().__init__()
-        assert (
-            binning_activation in self.BINARY_ACTIVATION_MAP.keys()
-        ), f"`binning_activation should be one of {self.BINARY_ACTIVATION_MAP.keys()}"
-        assert (
-            feature_mask_function in self.ACTIVATION_MAP.keys()
-        ), f"`feature_mask_function should be one of {self.ACTIVATION_MAP.keys()}"
+        assert binning_activation in self.BINARY_ACTIVATION_MAP.keys(), (
+            f"`binning_activation should be one of {self.BINARY_ACTIVATION_MAP.keys()}"
+        )
+        assert feature_mask_function in self.ACTIVATION_MAP.keys(), (
+            f"`feature_mask_function should be one of {self.ACTIVATION_MAP.keys()}"
+        )

         self.gflu_stages = gflu_stages
         self.num_trees = num_trees
