diff --git a/praxis/optimizers.py b/praxis/optimizers.py
index 5a6982b6..d7e25d73 100644
--- a/praxis/optimizers.py
+++ b/praxis/optimizers.py
@@ -2691,6 +2691,8 @@ def _get_raw_grad_transformation(self, lr: optax.Schedule):
 
 def sharded_static_accumulation(
     num_sub_batches: int,
+    clip_gradient_norm_to_value: float,
+    clip_gradient_single_norm_to_value: float,
     base_tx: ShardedGradientTransformation,
 ) -> ShardedGradientTransformation:
   """Gradient transformation for ShardedStaticAccumulator optimizer."""
@@ -2759,10 +2761,52 @@ def update_fn(updates: NestedJTensor,
                            lambda: new_count)
 
     def _run_base_tx():
+
+      def _compute_grad_norm(grads: NestedMap) -> JTensor:
+        """Computes total grad norm."""
+        grad_norms_squared = jax.tree_map(lambda x: jnp.sum(x * x), grads)
+        grad_norms_squared, _ = jax.tree_util.tree_flatten(grad_norms_squared)
+        return jnp.sqrt(jnp.sum(jnp.stack(grad_norms_squared)))
+
+
+      def scale_gradients(
+          raw_grads: NestedMap,
+          clip_grad_norm_to_value: float = 0.0,
+          clip_grad_single_norm_to_value: float = 0.0):
+
+        def clip_grads(grads):
+          assert not (clip_grad_norm_to_value and clip_grad_single_norm_to_value)
+          if clip_grad_norm_to_value:
+            grad_norm = _compute_grad_norm(raw_grads)
+
+            grad_scale = jnp.minimum(
+                jnp.array(1, grad_norm.dtype),
+                jnp.array(clip_grad_norm_to_value, grad_norm.dtype)
+                / grad_norm)
+            grads = jax.tree_map(lambda g: g * grad_scale, grads)
+          elif clip_grad_single_norm_to_value:
+            grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)),
+                                            grads)
+
+            def scale_gradient(grad, norm):
+              return grad * jnp.minimum(
+                  jnp.array(1, norm.dtype),
+                  jnp.array(clip_grad_single_norm_to_value,
+                            norm.dtype) / norm)
+            grads = jax.tree_map(scale_gradient, grads, grad_single_norm)
+
+          return grads
+
+        grads = clip_grads(raw_grads)
+        return grads
+
       averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches,
                                       new_accumulated_update)
+      scaled_updated = scale_gradients(averaged_updated,
+                                       clip_gradient_norm_to_value,
+                                       clip_gradient_single_norm_to_value)
       emission_updates, emission_base_state = base_tx.update(
-          averaged_updated, state.base_state, params)  # pytype: disable=attribute-error  # jax-ndarray
+          scaled_updated, state.base_state, params)  # pytype: disable=attribute-error  # jax-ndarray
       return (emission_updates,
               jax.tree_map(lambda u: jnp.zeros_like(u, dtype=jnp.float32),
                            updates), emission_base_state)
@@ -2830,4 +2874,5 @@ def _get_raw_grad_transformation(
       self, lr: optax.Schedule) -> GeneralGradientTransformation:
     p = self._hparams
     base_tx = self.base_optimizer._get_raw_grad_transformation(lr)  # pylint: disable=protected-access
-    return sharded_static_accumulation(p.num_sub_batches, base_tx)
+    return sharded_static_accumulation(p.num_sub_batches, p.clip_gradient_norm_to_value,
+                                       p.clip_gradient_single_norm_to_value, base_tx)
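
Note: the snippet below is a minimal standalone sketch of the global-norm branch that scale_gradients applies to the averaged updates; it is not part of the patch. The helper name demo_clip_by_global_norm and the example values are illustrative only. It restates the rule: every leaf is multiplied by min(1, clip_value / global_norm), so the tree's global norm never exceeds clip_value.

    # Illustrative sketch, not patch code: mirrors the clip_grad_norm_to_value
    # branch of scale_gradients() added above.
    import jax
    import jax.numpy as jnp

    def demo_clip_by_global_norm(grads, clip_value: float):
      """Scales every leaf by min(1, clip_value / global_norm)."""
      leaf_sq = [jnp.sum(g * g) for g in jax.tree_util.tree_leaves(grads)]
      global_norm = jnp.sqrt(jnp.sum(jnp.stack(leaf_sq)))
      scale = jnp.minimum(1.0, clip_value / global_norm)
      return jax.tree_util.tree_map(lambda g: g * scale, grads)

    grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([12.0])}  # global norm = 13
    clipped = demo_clip_by_global_norm(grads, clip_value=1.0)
    # Each leaf is scaled by 1/13; a clip_value above 13 would leave grads unchanged.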