From 32cf5beaa71ab95c764da1c0e8b446ea09311a1a Mon Sep 17 00:00:00 2001 From: Han Zhu <1106766460@qq.com> Date: Fri, 13 Dec 2024 19:10:34 +0800 Subject: [PATCH 1/2] Refactor optimizer --- egs/librispeech/ASR/zipformer/optim.py | 420 +++++++++++-------------- 1 file changed, 187 insertions(+), 233 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8434fab138..c83173920e 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -121,6 +121,139 @@ def batched_params(self, param_group, group_params_names): p.copy_(stacked_params[i]) +def basic_step(group, p, state, grad): + # computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet. + lr = group["lr"] + if p.numel() == p.shape[0]: + lr = lr * group["scalar_lr_scale"] + beta2 = group["betas"][1] + eps = group["eps"] + # p shape: (batch_size,) or (batch_size, 1, [1,..]) + try: + exp_avg_sq = state[ + "exp_avg_sq" + ] # shape: (batch_size,) or (batch_size, 1, [1,..]) + except KeyError: + exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["exp_avg_sq"] = exp_avg_sq + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + denom = exp_avg_sq.sqrt().add_(eps) + + return -lr * grad / denom + + +def scaling_step(group, p, state, grad): + delta = basic_step(group, p, state, grad) + if p.numel() == p.shape[0]: + return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.) + + step = state["step"] + size_update_period = group["size_update_period"] + + try: + param_rms = state["param_rms"] + scale_grads = state["scale_grads"] + scale_exp_avg_sq = state["scale_exp_avg_sq"] + except KeyError: + # we know p.ndim > 1 because we'd have returned above if not, so don't worry + # about the speial case of dim=[] that pytorch treats inconsistently. + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = param_rms.to(torch.float) + scale_exp_avg_sq = torch.zeros_like(param_rms) + scale_grads = torch.zeros( + size_update_period, *param_rms.shape, dtype=torch.float, device=p.device + ) + state["param_rms"] = param_rms + state["scale_grads"] = scale_grads + state["scale_exp_avg_sq"] = scale_exp_avg_sq + + # on every step, update the gradient w.r.t. the scale of the parameter, we + # store these as a batch and periodically update the size (for speed only, to + # avoid too many operations). + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + + # periodically recompute the value of param_rms. + if step % size_update_period == size_update_period - 1: + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()) + + param_min_rms = group["param_min_rms"] + + # scale the step size by param_rms. This is the most important "scaling" part of + # ScaledAdam + delta *= param_rms.clamp(min=param_min_rms) + + if step % size_update_period == size_update_period - 1 and step > 0: + # This block updates the size of parameter by adding a step ("delta") value in + # the direction of either shrinking or growing it. 
+ beta2 = group["betas"][1] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + batch_size = p.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # The following may help prevent instability: don't allow the scale step to be too large in + # either direction. + scale_step.clamp_(min=-0.1, max=0.1) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + + delta.add_(p * scale_step) + + return delta + + +def momentum_step(group, p, state, grad): + delta = scaling_step(group, p, state, grad) + beta1 = group["betas"][0] + try: + stored_delta = state["delta"] + except KeyError: + stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["delta"] = stored_delta + stored_delta.mul_(beta1) + stored_delta.add_(delta, alpha=(1 - beta1)) + # we don't bother doing the "bias correction" part of Adam for beta1 because this is just + # an edge effect that affects the first 10 or so batches; and the effect of not doing it + # is just to do a slower update for the first few batches, which will help stability. + return stored_delta + + class ScaledAdam(BatchedOptimizer): """ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update @@ -352,57 +485,25 @@ def step(self, closure=None): raise RuntimeError( "ScaledAdam optimizer does not support sparse gradients" ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - - self._step_one_batch(group, p, state, clipping_scale) - - return loss - - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. - p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - state["step"] = 0 + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 - kwargs = {"device": p.device, "dtype": p.dtype} + grad = ( + p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale) + ) + p += momentum_step(group, p.detach(), state, grad) - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. 
this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + if p.numel() == p.shape[0]: # scalar parameter + scalar_max = group["scalar_max"] + p.clamp_(min=-scalar_max, max=scalar_max) - batch_size = p.shape[0] - numel = p.numel() // batch_size - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) + state["step"] = cur_step + 1 - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + return loss def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] @@ -484,7 +585,7 @@ def _get_clipping_scale( ) first_state["num_clipped"] = 0 quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.warn( + logging.warning( f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" ) @@ -499,8 +600,8 @@ def _get_clipping_scale( ans = 0.0 if ans < 1.0: first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( + if ans < 0.5: + logging.warning( f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" ) if self.show_dominant_parameters: @@ -508,6 +609,7 @@ def _get_clipping_scale( self._show_gradient_dominating_parameter( tuples, tot_sumsq, group["scalar_lr_scale"] ) + self._show_param_with_unusual_grad(tuples) if ans == 0.0: for (p, state, param_names) in tuples: @@ -515,6 +617,36 @@ def _get_clipping_scale( return ans + def _show_param_with_unusual_grad( + self, + tuples: List[Tuple[Tensor, dict, List[str]]], + ): + """ + Print information about parameter which has the largest ratio of grad-on-this-batch + divided by normal grad size. + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". 
+ """ + largest_ratio = 0.0 + largest_name = "" + for (p, state, batch_param_names) in tuples: + dims = list(range(1, p.ndim)) + grad_ratio = (p.grad**2).mean(dim=dims) / state["exp_avg_sq"].mean( + dim=dims + ) + max_grad_ratio, max_index = grad_ratio.to("cpu").max(dim=0) + if max_grad_ratio.item() > largest_ratio: + largest_ratio = max_grad_ratio.item() + largest_name = batch_param_names[max_index.item()] + logging.warning( + f"Parameter with most larger-than-usual grad is {largest_name}, with ratio (cur_grad / normal_grad) of " + f"{largest_ratio ** 0.5}" + ) + def _show_gradient_dominating_parameter( self, tuples: List[Tuple[Tensor, dict, List[str]]], @@ -572,7 +704,7 @@ def _show_gradient_dominating_parameter( dominant_rms, dominant_grad, ) = sorted_by_proportion[dominant_param_name] - logging.warn( + logging.warning( f"Parameter dominating tot_sumsq {dominant_param_name}" f" with proportion {dominant_proportion:.2f}," f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" @@ -581,183 +713,6 @@ def _show_gradient_dominating_parameter( f" orig_rms_sq={(dominant_rms**2).item():.3e}" ) - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - - grad = p.grad - if clipping_scale != 1.0: - grad *= clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. 
- p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - - # and ensure the parameter rms after update never exceeds param_max_rms. - # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the - # topology or settings. - scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) - - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. 
- bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) - class LRScheduler(object): """ @@ -787,9 +742,9 @@ def state_dict(self): is not the optimizer. """ return { - # the user might try to override the base_lr, so don't include this in the state. - # previously they were included. - # "base_lrs": self.base_lrs, + # the user might try to override the base_lr, so don't include this in the state. + # previously they were included. + # "base_lrs": self.base_lrs, "epoch": self.epoch, "batch": self.batch, } @@ -807,7 +762,6 @@ def load_state_dict(self, state_dict): self.__dict__.update(state_dict) self.base_lrs = base_lrs - def get_last_lr(self) -> List[float]: """Return last computed learning rate by current scheduler. Will be a list of float.""" return self._last_lr @@ -853,7 +807,7 @@ def _set_lrs(self): def print_lr(self, is_verbose, group, lr): """Display the current learning rate.""" if is_verbose: - logging.warn( + logging.warning( f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" f" of group {group} to {lr:.4e}." ) @@ -1184,7 +1138,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() From 06b9138f3348f8841f93654c03778b946a65c3d8 Mon Sep 17 00:00:00 2001 From: Han Zhu <1106766460@qq.com> Date: Mon, 30 Dec 2024 14:00:13 +0800 Subject: [PATCH 2/2] Print indexes of largest grad --- egs/librispeech/ASR/zipformer/optim.py | 41 +++++++++++++++++++++----- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index c83173920e..8a17646513 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -633,18 +633,37 @@ def _show_param_with_unusual_grad( """ largest_ratio = 0.0 largest_name = "" + # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor) + ratios_names = [] for (p, state, batch_param_names) in tuples: dims = list(range(1, p.ndim)) - grad_ratio = (p.grad**2).mean(dim=dims) / state["exp_avg_sq"].mean( - dim=dims + + def mean(x): + # workaround for bad interface of torch's "mean" for when dims is the empty list. 
+ if len(dims) > 0: + return x.mean(dim=dims) + else: + return x + + grad_ratio = ( + (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims)) + .sqrt() + .to("cpu") + ) + + ratios_names += zip( + grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0) ) - max_grad_ratio, max_index = grad_ratio.to("cpu").max(dim=0) - if max_grad_ratio.item() > largest_ratio: - largest_ratio = max_grad_ratio.item() - largest_name = batch_param_names[max_index.item()] + + ratios_names = sorted(ratios_names, reverse=True) + ratios_names = ratios_names[:10] + ratios_names = [ + (ratio, name, largest_index(tensor)) + for (ratio, name, tensor) in ratios_names + ] + logging.warning( - f"Parameter with most larger-than-usual grad is {largest_name}, with ratio (cur_grad / normal_grad) of " - f"{largest_ratio ** 0.5}" + f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}" ) def _show_gradient_dominating_parameter( @@ -714,6 +733,12 @@ def _show_gradient_dominating_parameter( ) +def largest_index(x: Tensor): + x = x.contiguous() + argmax = x.abs().argmax().item() + return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)] + + class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the
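
A minimal sketch of how the refactored update path composes, for reference: momentum_step() wraps scaling_step(), which wraps basic_step(), so the returned delta already includes the Adam-style variance normalization, the param_rms scaling (with the periodic size update), and momentum. The hand-built "group" dict below only mimics a ScaledAdam param group with illustrative values, and the import assumes the patched optim.py is importable as a module.

    import torch

    from optim import momentum_step  # calls scaling_step(), which calls basic_step()

    # Illustrative param-group settings; only the keys read by the step functions are included.
    group = {
        "lr": 0.03,
        "betas": (0.9, 0.98),
        "eps": 1e-8,
        "scalar_lr_scale": 0.1,
        "size_update_period": 4,
        "param_min_rms": 1e-5,
        "param_max_rms": 3.0,
    }

    # A "batch" of 8 stacked parameters, each of shape (16, 32); dim 0 is the batch dim,
    # following the convention used by BatchedOptimizer.
    p = torch.randn(8, 16, 32)
    grad = torch.randn_like(p)
    state = {"step": 0}  # per-tensor optimizer state; the other entries are created lazily

    delta = momentum_step(group, p, state, grad)  # same shape as p
    p += delta
    state["step"] += 1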
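
The updated _test_scaled_adam constructs the optimizer from m.named_parameters() rather than m.parameters(), so ScaledAdam can attach names to the batched parameters for its diagnostics. A minimal training-loop sketch in the same spirit (the model, hyper-parameters, and the per-batch scheduler call are illustrative; it assumes Eden's step_batch() interface from the recipe's training loop):

    import torch
    import torch.nn as nn

    from optim import Eden, ScaledAdam

    m = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 100))
    optimizer = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
    scheduler = Eden(optimizer, lr_batches=200, lr_epochs=5, verbose=False)

    for batch_idx in range(20):
        x = torch.randn(32, 100)
        loss = (m(x) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step_batch(batch_idx)  # assumed per-batch scheduler hook, as in the recipe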
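
Patch 2 reports, for each of the ten parameters with the most unusual gradients, the multi-dimensional index of the largest-magnitude element of that gradient. The new largest_index() helper recovers this index from the flat argmax using the contiguous strides, effectively unraveling the flat index. A quick self-contained check, with the helper copied verbatim from the patch:

    import torch
    from torch import Tensor


    def largest_index(x: Tensor):
        x = x.contiguous()
        argmax = x.abs().argmax().item()
        return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]


    x = torch.randn(4, 5, 6)
    idx = largest_index(x)  # e.g. [2, 0, 5]
    assert (x[tuple(idx)].abs() == x.abs().max()).item()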