diff --git a/hungarian_net/lightning_modules/hnet_gru_lightning.py b/hungarian_net/lightning_modules/hnet_gru_lightning.py
index f4378a7..c7937b5 100644
--- a/hungarian_net/lightning_modules/hnet_gru_lightning.py
+++ b/hungarian_net/lightning_modules/hnet_gru_lightning.py
@@ -9,6 +9,7 @@
 from torch import optim
 from torchmetrics import MetricCollection
 
+from hungarian_net.loss.balancer import Balancer
 from hungarian_net.torch_modules.hnet_gru import HNetGRU
 
 
@@ -50,7 +51,11 @@ def __init__(
         self.criterion1 = nn.BCEWithLogitsLoss(reduction="sum")
         self.criterion2 = nn.BCEWithLogitsLoss(reduction="sum")
         self.criterion3 = nn.BCEWithLogitsLoss(reduction="sum")
-        self.criterion_wts = [1.0, 1.0, 1.0]
+
+        self.loss: dict = {"crit1": self.criterion1, "crit2": self.criterion2, "crit3": self.criterion3}
+        self.criterion_wts = {"crit1": 1.0, "crit2": 1.0, "crit3": 1.0}
+
+        self.balancer = Balancer(self.criterion_wts)
 
         self.optimizer: torch.optim.Optimizer = optimizer(
             params=self.model.parameters()
@@ -80,14 +85,14 @@ def common_step(
         # forward pass
         output = self.model(data)
 
-        l1 = self.criterion1(output[0], target[0])
-        l2 = self.criterion2(output[1], target[1])
-        l3 = self.criterion3(output[2], target[2])
+        self.loss['crit1'] = self.criterion1(output[0], target[0])
+        self.loss['crit2'] = self.criterion2(output[1], target[1])
+        self.loss['crit3'] = self.criterion3(output[2], target[2])
 
         loss = (
-            self.criterion_wts[0] * l1
-            + self.criterion_wts[1] * l2
-            + self.criterion_wts[2] * l3
+            self.criterion_wts['crit1'] * self.loss['crit1']
+            + self.criterion_wts['crit2'] * self.loss['crit2']
+            + self.criterion_wts['crit3'] * self.loss['crit3']
         )
 
         return loss, output, target
@@ -104,8 +109,14 @@ def training_step(self, batch, batch_idx) -> Dict[str, torch.Tensor]:
            torch.Tensor: Training loss.
        """

+        self.optimizer.zero_grad()
+
        loss, output, target = self.common_step(batch, batch_idx)

+        self.balancer.backward(self.loss, output[0])
+
+        self.optimizer.step()
+
        outputs = {"loss": loss, "output": output, "target": target}

        return outputs
diff --git a/hungarian_net/loss/balancer.py b/hungarian_net/loss/balancer.py
new file mode 100644
index 0000000..65b2723
--- /dev/null
+++ b/hungarian_net/loss/balancer.py
@@ -0,0 +1,176 @@
+"""
+@article{defossez2022highfi,
+    title={High Fidelity Neural Audio Compression},
+    author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},
+    journal={arXiv preprint arXiv:2210.13438},
+    year={2022}
+}
+"""
+
+from collections import defaultdict
+import typing as tp
+
+import torch
+from torch import autograd
+
+
+def world_size():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+    else:
+        return 1
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+    if is_distributed():
+        return torch.distributed.all_reduce(tensor, op)
+
+
+def is_distributed():
+    return world_size() > 1
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.):
+    """Average a dictionary of metrics across all workers, using the optional
+    `count` as unnormalized weight.
+    """
+    if not is_distributed():
+        return metrics
+    keys, values = zip(*metrics.items())
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+    tensor *= count
+    all_reduce(tensor)
+    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+    return dict(zip(keys, averaged))
+
+
+def averager(beta: float = 1):
+    """
+    Exponential Moving Average callback.
+    Returns a single function that can be called repeatedly to update the EMA
+    with a dict of metrics. The callback will return
+    the new averaged dict of metrics.
+
+    Note that for `beta=1`, this is just plain averaging.
+    """
+    fix: tp.Dict[str, float] = defaultdict(float)
+    total: tp.Dict[str, float] = defaultdict(float)
+
+    def _update(metrics: tp.Dict[str, tp.Any], weight: float = 1) -> tp.Dict[str, float]:
+        nonlocal total, fix
+        for key, value in metrics.items():
+            total[key] = total[key] * beta + weight * float(value)
+            fix[key] = fix[key] * beta + weight
+        return {key: tot / fix[key] for key, tot in total.items()}
+    return _update
+
+
+class Balancer:
+    """Loss balancer.
+
+    The loss balancer combines losses together to compute gradients for the backward pass.
+    A call to the balancer will weight the losses according to the specified weight coefficients.
+    A call to the backward method of the balancer will compute the gradients, combining all the losses and
+    potentially rescaling the gradients, which can help stabilize the training and reason
+    about multiple losses with varying scales.
+
+    Expected usage:
+        weights = {'loss_a': 1, 'loss_b': 4}
+        balancer = Balancer(weights, ...)
+        losses: dict = {}
+        losses['loss_a'] = compute_loss_a(x, y)
+        losses['loss_b'] = compute_loss_b(x, y)
+        if model.training:
+            balancer.backward(losses, x)
+
+    .. warning:: It is unclear how this will interact with DistributedDataParallel,
+        in particular if you have some losses not handled by the balancer. In that case
+        you can use `encodec.distrib.sync_grad(model.parameters())` and
+        `encodec.distrib.sync_buffers(model.buffers())` as a safe alternative.
+
+    Args:
+        weights (Dict[str, float]): Weight coefficient for each loss. The balancer expects the loss keys
+            passed to the backward method to match the weights keys, assigning a weight to each provided loss.
+        rescale_grads (bool): Whether to rescale gradients. If False, this is just
+            a regular weighted sum of losses.
+        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
+        ema_decay (float): EMA decay for averaging the norms when `rescale_grads` is True.
+        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
+            when rescaling the gradients.
+        epsilon (float): Epsilon value for numerical stability.
+        monitor (bool): Whether to store the additional ratio for each loss key in the metrics.
+    """
+
+    def __init__(self, weights: tp.Dict[str, float], rescale_grads: bool = True, total_norm: float = 1.,
+                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
+                 monitor: bool = False):
+        self.weights = weights
+        self.per_batch_item = per_batch_item
+        self.total_norm = total_norm
+        self.averager = averager(ema_decay)
+        self.epsilon = epsilon
+        self.monitor = monitor
+        self.rescale_grads = rescale_grads
+        self._metrics: tp.Dict[str, tp.Any] = {}
+
+    @property
+    def metrics(self):
+        return self._metrics
+
+    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
+        norms = {}
+        grads = {}
+        for name, loss in losses.items():
+            grad, = autograd.grad(loss, [input], retain_graph=True)
+            if self.per_batch_item:
+                dims = tuple(range(1, grad.dim()))
+                norm = grad.norm(dim=dims).mean()
+            else:
+                norm = grad.norm()
+            norms[name] = norm
+            grads[name] = grad
+
+        count = 1
+        if self.per_batch_item:
+            count = len(grad)
+        avg_norms = average_metrics(self.averager(norms), count)
+        total = sum(avg_norms.values())
+
+        self._metrics = {}
+        if self.monitor:
+            for k, v in avg_norms.items():
+                self._metrics[f'ratio_{k}'] = v / total
+
+        total_weights = sum([self.weights[k] for k in avg_norms])
+        ratios = {k: w / total_weights for k, w in self.weights.items()}
+
+        out_grad: tp.Any = 0
+        for name, avg_norm in avg_norms.items():
+            if self.rescale_grads:
+                scale = ratios[name] * self.total_norm / (self.epsilon + avg_norm)
+                grad = grads[name] * scale
+            else:
+                grad = self.weights[name] * grads[name]
+            out_grad += grad
+        input.backward(out_grad)
+
+
+def test():
+    from torch.nn import functional as F
+    x = torch.zeros(1, requires_grad=True)
+    one = torch.ones_like(x)
+    loss_1 = F.l1_loss(x, one)
+    loss_2 = 100 * F.l1_loss(x, -one)
+    losses = {'1': loss_1, '2': loss_2}
+
+    balancer = Balancer(weights={'1': 1, '2': 1}, rescale_grads=False)
+    balancer.backward(losses, x)
+    assert torch.allclose(x.grad, torch.tensor(99.)), x.grad
+
+    loss_1 = F.l1_loss(x, one)
+    loss_2 = 100 * F.l1_loss(x, -one)
+    losses = {'1': loss_1, '2': loss_2}
+    x.grad = None
+    balancer = Balancer(weights={'1': 1, '2': 1}, rescale_grads=True)
+    balancer.backward({'1': loss_1, '2': loss_2}, x)
+    assert torch.allclose(x.grad, torch.tensor(0.)), x.grad
+
+
+if __name__ == '__main__':
+    test()
\ No newline at end of file
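
Usage note (not part of the patch): below is a minimal, self-contained sketch of how the new Balancer is meant to be driven. The per-criterion losses go into a dict keyed like the weights ("crit1"/"crit2"/"crit3" as in the patch), and Balancer.backward rescales each gradient at a shared reference tensor before accumulating it. The three linear heads, tensor shapes, and the `feat` tensor are hypothetical stand-ins for the HNetGRU outputs; only the Balancer class and the weight keys come from the patch.

    import torch
    from torch import nn

    from hungarian_net.loss.balancer import Balancer

    # Hypothetical setup: three BCE-with-logits heads sharing one activation.
    torch.manual_seed(0)
    feat = torch.randn(8, 16, requires_grad=True)  # shared activation (illustrative stand-in)
    heads = {k: nn.Linear(16, 4) for k in ("crit1", "crit2", "crit3")}
    targets = {k: torch.randint(0, 2, (8, 4)).float() for k in heads}
    bce = nn.BCEWithLogitsLoss(reduction="sum")

    # One loss per key, matching the weight keys used in the patch.
    losses = {k: bce(head(feat), targets[k]) for k, head in heads.items()}

    balancer = Balancer({"crit1": 1.0, "crit2": 1.0, "crit3": 1.0})
    balancer.backward(losses, feat)  # accumulates the balanced gradient into feat.grad (and the heads' params)
    print(feat.grad.norm())

With rescale_grads=True (the default), each loss contributes a gradient whose average per-item norm at `feat` is driven toward its weight share of total_norm, so a loss with a very large raw scale no longer dominates the update.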