diff --git a/pvnet/optimizers.py b/pvnet/optimizers.py
index adfd7b93..7d9d22e0 100644
--- a/pvnet/optimizers.py
+++ b/pvnet/optimizers.py
@@ -65,7 +65,7 @@ class AbstractOptimizer(ABC):
     """
 
     @abstractmethod
-    def __call__(self):
+    def __call__(self, model: torch.nn.Module):
         """Abstract call"""
         pass
 
@@ -129,19 +129,18 @@ def __call__(self, model):
             {"params": decay, "weight_decay": self.weight_decay},
             {"params": no_decay, "weight_decay": 0.0},
         ]
+        monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
         opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)
-
         sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
             opt,
             factor=self.factor,
             patience=self.patience,
             threshold=self.threshold,
         )
-        sch = {
-            "scheduler": sch,
-            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+        return {
+            "optimizer": opt,
+            "lr_scheduler": {"scheduler": sch, "monitor": monitor},
         }
-        return [opt], [sch]
 
 
 class AdamWReduceLROnPlateau(AbstractOptimizer):
@@ -157,7 +156,7 @@ def __init__(
         **opt_kwargs,
     ):
         """AdamW optimizer and reduce on plateau scheduler"""
-        self._lr = lr
+        self.lr = lr
         self.patience = patience
         self.factor = factor
         self.threshold = threshold
@@ -169,7 +168,7 @@ def _call_multi(self, model):
 
         group_args = []
 
-        for key in self._lr.keys():
+        for key in self.lr.keys():
             if key == "default":
                 continue
 
@@ -178,43 +177,38 @@ def _call_multi(self, model):
                 if param_name.startswith(key):
                     submodule_params += [remaining_params.pop(param_name)]
 
-            group_args += [{"params": submodule_params, "lr": self._lr[key]}]
+            group_args += [{"params": submodule_params, "lr": self.lr[key]}]
 
         remaining_params = [p for k, p in remaining_params.items()]
         group_args += [{"params": remaining_params}]
-
-        opt = torch.optim.AdamW(
-            group_args,
-            lr=self._lr["default"] if model.lr is None else model.lr,
-            **self.opt_kwargs,
+        monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
+        opt = torch.optim.AdamW(group_args, lr=self.lr["default"], **self.opt_kwargs)
+        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            opt,
+            factor=self.factor,
+            patience=self.patience,
+            threshold=self.threshold,
         )
-        sch = {
-            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
-                opt,
-                factor=self.factor,
-                patience=self.patience,
-                threshold=self.threshold,
-            ),
-            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+        return {
+            "optimizer": opt,
+            "lr_scheduler": {"scheduler": sch, "monitor": monitor},
         }
-        return [opt], [sch]
 
     def __call__(self, model):
        """Return optimizer"""
-        if not isinstance(self._lr, float):
+        if not isinstance(self.lr, float):
            return self._call_multi(model)
        else:
-            default_lr = self._lr if model.lr is None else model.lr
-            opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
+            monitor = "quantile_loss/val" if model.use_quantile_regression else "MAE/val"
+            opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs)
             sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
                 opt,
                 factor=self.factor,
                 patience=self.patience,
                 threshold=self.threshold,
             )
-            sch = {
-                "scheduler": sch,
-                "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
+            return {
+                "optimizer": opt,
+                "lr_scheduler": {"scheduler": sch, "monitor": monitor},
             }
-            return [opt], [sch]
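The hunks above replace the legacy `return [opt], [sch]` tuple with the dict format that Lightning's `configure_optimizers` accepts natively, and hoist the duplicated monitor expression into a single `monitor` variable. A minimal sketch of how the returned dict is consumed, assuming the `lightning.pytorch` namespace of recent Lightning releases; `DemoModule` and its attributes are illustrative, not part of this PR:

```python
# Sketch only: DemoModule is hypothetical; AdamWReduceLROnPlateau is the class
# modified above, with keyword arguments matching the attributes its __init__ sets.
import lightning.pytorch as pl
from torch import nn

from pvnet.optimizers import AdamWReduceLROnPlateau


class DemoModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 1)
        # Attribute the optimizer inspects to choose its scheduler monitor
        # ("quantile_loss/val" vs "MAE/val").
        self.use_quantile_regression = False
        self._optimizer = AdamWReduceLROnPlateau(
            lr=1e-3, patience=5, factor=0.1, threshold=2e-4
        )

    def configure_optimizers(self):
        # The optimizer callable now returns
        # {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "monitor": ...}},
        # which Lightning accepts directly, so no unpacking of the old
        # ([opt], [sch]) tuple is needed.
        return self._optimizer(self)
```

Note the new code also drops the `model.lr is None` override: the learning rate now comes solely from the optimizer's own `lr`, rather than being overridable via an attribute on the model.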
diff --git a/pvnet/training/lightning_module.py b/pvnet/training/lightning_module.py
index 252d1906..b5d5030d 100644
--- a/pvnet/training/lightning_module.py
+++ b/pvnet/training/lightning_module.py
@@ -39,6 +39,7 @@ def __init__(
         self.model = model
         self._optimizer = optimizer
         self.save_all_validation_results = save_all_validation_results
+        self.save_hyperparameters(ignore=["model", "optimizer"])
 
         # Model must have lr to allow tuning
         # This setting is only used when lr is tuned with callback
@@ -109,7 +110,7 @@ def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
         losses = self._calculate_common_losses(y, y_hat)
         losses = {f"{k}/train": v for k, v in losses.items()}
 
-        self.log_dict(losses, on_step=True, on_epoch=True)
+        self.log_dict(losses, on_step=True, on_epoch=True, sync_dist=True, batch_size=y.size(0))
 
         if self.model.use_quantile_regression:
             opt_target = losses["quantile_loss/train"]
@@ -276,8 +277,8 @@ def validation_step(self, batch: TensorBatch, batch_idx: int) -> None:
             }
         )
 
-        # Log the metrics
-        self.log_dict(losses, on_step=False, on_epoch=True)
+        # Log the metrics (DDP/epoch-correct for scheduler monitors)
+        self.log_dict(losses, on_step=False, on_epoch=True, sync_dist=True, batch_size=y.size(0))
 
     def on_validation_epoch_end(self) -> None:
         """Run on epoch end"""
@@ -322,7 +323,7 @@ def on_validation_epoch_end(self) -> None:
             filepath = f"{wandb_log_dir}/validation_results.netcdf"
             ds_val_results.to_netcdf(filepath)
 
-            # Uplodad to wandb
+            # Upload to wandb
             self.logger.experiment.save(filepath, base_path=wandb_log_dir, policy="now")
 
         # Create the horizon accuracy curve
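The logging changes matter once training runs on multiple GPUs: `sync_dist=True` reduces the epoch-level metric across DDP ranks so the values monitored by `ReduceLROnPlateau` and checkpointing are consistent on every rank, and `batch_size` makes the epoch average exact when the final batch is short. A self-contained sketch of the same pattern, again assuming the `lightning.pytorch` namespace; `TinyModule` is illustrative, not PVNet code:

```python
# Sketch only: TinyModule is illustrative, not PVNet code.
import torch
from torch import nn
import lightning.pytorch as pl


class TinyModule(pl.LightningModule):
    def __init__(self, hidden: int = 4):
        super().__init__()
        # Persist constructor args to the checkpoint, mirroring the
        # save_hyperparameters(ignore=["model", "optimizer"]) call added above,
        # which skips the two arguments that are not plain hyperparameters.
        self.save_hyperparameters()
        self.net = nn.Linear(hidden, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        mae = nn.functional.l1_loss(self.net(x), y)
        # on_epoch=True accumulates a running mean; batch_size weights it
        # correctly for a ragged final batch, and sync_dist=True averages the
        # epoch value across DDP ranks so a ReduceLROnPlateau monitor sees the
        # same number everywhere.
        self.log("MAE/val", mae, on_step=False, on_epoch=True,
                 sync_dist=True, batch_size=y.size(0))

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
```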