Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved docstring for optax.centralize with explanation and example #1220

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/api/transformations.rst
Original file line number Diff line number Diff line change
@@ -138,6 +138,31 @@ Transformations and states

.. autofunction:: centralize

Centralizes gradients by subtracting their mean value, making them zero-centered.
This helps stabilize training and improve convergence in deep learning models.

**Example Usage:**

.. code-block:: python
import optax
import jax.numpy as jnp
grads = {
'w': jnp.array([1.0, 2.0, 3.0]),
'b': jnp.array([0.5, -0.5])
}
centralizer = optax.centralize()
updates, _ = centralizer.init(grads)
updates, _ = centralizer.update(grads, None)
print(updates) # Gradients will be zero-centered
**Reference:**
Yong et al., *Gradient Centralization: A New Optimization Technique for Deep
Neural Networks* (2020). Available at: `<https://arxiv.org/abs/2004.01461>`_.

.. autofunction:: conditionally_mask
.. autoclass:: ConditionallyMaskState

29 changes: 25 additions & 4 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
@@ -1083,14 +1083,35 @@ def _subtract_mean(g):


def centralize() -> base.GradientTransformation:
"""Centralize gradients.
"""
Centralizes gradients by subtracting their mean along the feature dimension.
Gradient centralization re-scales gradients such that their mean across the feature
dimension is zero. This technique has been shown to improve convergence stability
and generalization in deep learning models.
Returns:
A :class:`optax.GradientTransformation` object.
optax.GradientTransformation: A transformation that modifies gradients to be centralized.
Example:
```python
import jax.numpy as jnp
import optax
# Define a dummy gradient
grads = {'param': jnp.array([[1.0, 2.0], [3.0, 4.0]])}
# Apply centralization
centralizer = optax.centralize()
centralized_grads, _ = centralizer.update(grads, optax.EmptyState())
print(centralized_grads)
```
References:
Yong et al, `Gradient Centralization: A New Optimization Technique for Deep
Neural Networks <https://arxiv.org/abs/2004.01461>`_, 2020.
Yong et al., "Gradient Centralization: A New Optimization Technique for Deep
Neural Networks" (2020). Available at:
`<https://arxiv.org/abs/2004.01461>`_.
"""

def init_fn(params):