54 changes: 24 additions & 30 deletions pvnet/optimizers.py
@@ -65,7 +65,7 @@ class AbstractOptimizer(ABC):
"""

@abstractmethod
def __call__(self):
def __call__(self, model: Module):
"""Abstract call"""
pass

@@ -129,19 +129,18 @@ def __call__(self, model):
{"params": decay, "weight_decay": self.weight_decay},
{"params": no_decay, "weight_decay": 0.0},
]
monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)

sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
opt,
factor=self.factor,
patience=self.patience,
threshold=self.threshold,
)
sch = {
"scheduler": sch,
"monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
return {
"optimizer": opt,
"lr_scheduler": {"scheduler": sch, "monitor": monitor},
}
return [opt], [sch]


class AdamWReduceLROnPlateau(AbstractOptimizer):
@@ -157,7 +156,7 @@ def __init__(
         **opt_kwargs,
     ):
         """AdamW optimizer and reduce on plateau scheduler"""
-        self._lr = lr
+        self.lr = lr
         self.patience = patience
         self.factor = factor
         self.threshold = threshold
@@ -169,7 +168,7 @@ def _call_multi(self, model):

         group_args = []

-        for key in self._lr.keys():
+        for key in self.lr.keys():
             if key == "default":
                 continue

@@ -178,43 +177,38 @@ def _call_multi(self, model):
                 if param_name.startswith(key):
                     submodule_params += [remaining_params.pop(param_name)]

-            group_args += [{"params": submodule_params, "lr": self._lr[key]}]
+            group_args += [{"params": submodule_params, "lr": self.lr[key]}]

         remaining_params = [p for k, p in remaining_params.items()]
         group_args += [{"params": remaining_params}]

-        opt = torch.optim.AdamW(
-            group_args,
-            lr=self._lr["default"] if model.lr is None else model.lr,
-            **self.opt_kwargs,
+        monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
+        opt = torch.optim.AdamW(group_args, lr=self.lr["default"], **self.opt_kwargs)
+        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            opt,
+            factor=self.factor,
+            patience=self.patience,
+            threshold=self.threshold,
         )
-        sch = {
-            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
-                opt,
-                factor=self.factor,
-                patience=self.patience,
-                threshold=self.threshold,
-            ),
-            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+        return {
+            "optimizer": opt,
+            "lr_scheduler": {"scheduler": sch, "monitor": monitor},
         }
-
-        return [opt], [sch]

     def __call__(self, model):
         """Return optimizer"""
-        if not isinstance(self._lr, float):
+        if not isinstance(self.lr, float):
             return self._call_multi(model)
         else:
-            default_lr = self._lr if model.lr is None else model.lr
-            opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
+            monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
+            opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs)
             sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
                 opt,
                 factor=self.factor,
                 patience=self.patience,
                 threshold=self.threshold,
             )
-            sch = {
-                "scheduler": sch,
-                "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+            return {
+                "optimizer": opt,
+                "lr_scheduler": {"scheduler": sch, "monitor": monitor},
            }
-            return [opt], [sch]
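Reviewer note: both optimizer classes now return the dict format that PyTorch Lightning accepts from `configure_optimizers`, instead of the older `([opt], [sch])` tuple, and the duplicated monitor ternary is hoisted into a single `monitor` variable. A minimal sketch of that contract, assuming plain PyTorch plus a hypothetical `SimpleNet` model (the `lr`, `factor`, `patience`, and `"MAE/val"` values here are illustrative, not from this PR):

```python
import torch
from torch import nn


class SimpleNet(nn.Module):
    """Hypothetical model, for illustration only."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)


def configure_optimizers(model: nn.Module, lr: float = 1e-3):
    """Sketch of the dict contract Lightning reads from configure_optimizers."""
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=3)
    # "monitor" names a logged metric; Lightning steps the plateau scheduler
    # against it at the end of each validation epoch.
    return {
        "optimizer": opt,
        "lr_scheduler": {"scheduler": sch, "monitor": monitor_key},
    }


monitor_key = "MAE/val"
print(configure_optimizers(SimpleNet()))
```

Keeping the monitor string next to the scheduler it configures is the main readability win of this refactor.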
10 changes: 6 additions & 4 deletions pvnet/training/lightning_module.py
@@ -39,6 +39,7 @@ def __init__(
         self.model = model
         self._optimizer = optimizer
         self.save_all_validation_results = save_all_validation_results
+        self.save_hyperparameters(ignore=["model", "optimizer"])

         # Model must have lr to allow tuning
         # This setting is only used when lr is tuned with callback
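For context on the added `save_hyperparameters` call, a small sketch, assuming the `pytorch_lightning` package (`DemoModule` and its arguments are hypothetical):

```python
import pytorch_lightning as pl


class DemoModule(pl.LightningModule):
    """Hypothetical module, for illustration only."""

    def __init__(self, model=None, optimizer=None, save_all_validation_results=False):
        super().__init__()
        # Stores the non-ignored __init__ arguments in self.hparams and in
        # checkpoints, without trying to serialize the large model and
        # optimizer objects.
        self.save_hyperparameters(ignore=["model", "optimizer"])


module = DemoModule(save_all_validation_results=True)
print(module.hparams)  # AttributeDict containing save_all_validation_results only
```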
@@ -109,7 +110,7 @@ def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
         losses = self._calculate_common_losses(y, y_hat)
         losses = {f"{k}/train": v for k, v in losses.items()}

-        self.log_dict(losses, on_step=True, on_epoch=True)
+        self.log_dict(losses, on_step=True, on_epoch=True, sync_dist=True, batch_size=y.size(0))

         if self.model.use_quantile_regression:
             opt_target = losses["quantile_loss/train"]
@@ -276,8 +277,9 @@ def validation_step(self, batch: TensorBatch, batch_idx: int) -> None:
             }
         )

-        # Log the metrics
-        self.log_dict(losses, on_step=False, on_epoch=True)
+        # Log the metrics (DDP/epoch-correct for scheduler monitors)
+        self.log_dict(losses, on_step=False, on_epoch=True, sync_dist=True, batch_size=y.size(0))
+

     def on_validation_epoch_end(self) -> None:
         """Run on epoch end"""
@@ -322,7 +324,7 @@ def on_validation_epoch_end(self) -> None:
            filepath = f"{wandb_log_dir}/validation_results.netcdf"
            ds_val_results.to_netcdf(filepath)

-            # Uplodad to wandb
+            # Upload to wandb
            self.logger.experiment.save(filepath, base_path=wandb_log_dir, policy="now")

            # Create the horizon accuracy curve
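Reviewer note on the logging changes: the plateau scheduler configured in `optimizers.py` monitors these logged validation values, so they need to be aggregated consistently. `sync_dist=True` averages each metric across DDP ranks before the epoch value is computed, and `batch_size` lets Lightning weight the epoch mean correctly when batches vary in size. A minimal sketch of the pattern, again assuming `pytorch_lightning` (`TinyModule` and the metric name are illustrative):

```python
import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    """Hypothetical module, for illustration only."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        losses = {"MAE/val": torch.nn.functional.l1_loss(self.layer(x), y)}
        self.log_dict(
            losses,
            on_step=False,
            on_epoch=True,
            sync_dist=True,  # average across DDP ranks before aggregation
            batch_size=y.size(0),  # weight the epoch mean by true batch size
        )

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
```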