[PyTorch] Use same API in optimizer zero_grad as PyTorch optimizers (#1466)

Use same API in optimizer zero_grad as PyTorch optimizers

Signed-off-by: Tim Moon <[email protected]>
timmoon10 authored Feb 22, 2025
1 parent 257345a commit b4fbc2b
Showing 2 changed files with 94 additions and 32 deletions.
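
A minimal usage sketch of the new call pattern introduced by this commit (not part of the diff; it assumes a CUDA device, a built Transformer Engine, and that FusedAdam is exported from transformer_engine.pytorch.optimizers):

import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # assumed export path

model = torch.nn.Linear(16, 16).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3)

# Old, now-deprecated pattern: pick the zeroing behavior at construction time
#     optimizer = FusedAdam(model.parameters(), lr=1e-3, set_grad_none=False)
#     optimizer.zero_grad()
# New pattern, matching torch.optim.Optimizer: pick it per call
loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)    # gradients become None (the default)
# optimizer.zero_grad(set_to_none=False) # or zero the gradient buffers in place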
67 changes: 48 additions & 19 deletions transformer_engine/pytorch/optimizers/fused_adam.py
@@ -3,8 +3,12 @@
# See LICENSE for license information.

"""Fused Adam optimizer."""
from __future__ import annotations
from collections.abc import Iterable
from copy import deepcopy
from itertools import chain
from typing import Optional
import warnings

import torch
import transformer_engine_torch as tex
@@ -52,8 +56,6 @@ class FusedAdam(torch.optim.Optimizer):
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
bias_correction (bool, optional): apply correction factor to
moment estimates. (default: True)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
@@ -62,10 +64,10 @@ class FusedAdam(torch.optim.Optimizer):
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
bias_correction (bool, optional): apply correction factor to
moment estimates. (default: True)
adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay (also known as AdamW) (default: True)
set_grad_none (bool, optional): whether set grad to None when zero_grad()
method is called. (default: True)
capturable (bool, optional): whether to use the version of the optimizer
that can be used with CUDA Graphs. (default: False)
master_weights (bool, optional): whether to maintain FP32 master weights
@@ -106,22 +108,23 @@ class FusedAdam(torch.optim.Optimizer):

def __init__(
self,
params,
lr=1e-3,
params: Iterable[torch.nn.Parameter | dict],
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
amsgrad: bool = False,
*,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
adam_w_mode=True,
weight_decay=0.0,
amsgrad=False,
set_grad_none=True,
capturable=False,
master_weights=False,
master_weight_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
use_decoupled_grad=False,
store_param_remainders=False,
set_grad_none: Optional[bool] = None, # deprecated
):

if amsgrad:
@@ -160,7 +163,6 @@ def __init__(
}
super().__init__(params, defaults)
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none

self.capturable = capturable
self.master_weights = master_weights
@@ -204,19 +206,46 @@ def __init__(
store_param_remainders and master_weights and master_weight_dtype == torch.float32
)

def zero_grad(self):
# pylint: disable=missing-function-docstring
if not self.use_decoupled_grad and not self.set_grad_none:
super().zero_grad()
# Deprecated options
self.set_grad_none = set_grad_none
if self.set_grad_none is not None:
warnings.warn(
"set_grad_none kwarg in FusedAdam constructor is deprecated. "
"Use set_to_none kwarg in zero_grad instead.",
DeprecationWarning,
)

def zero_grad(self, set_to_none: Optional[bool] = None) -> None:
"""Reset parameter gradients.
Arguments:
set_to_none (bool, optional): whether to set grads to `None`
instead of zeroing out buffers. (default: True)
"""

# Handle deprecated set_grad_none option
if self.set_grad_none is not None:
if set_to_none is not None and set_to_none != self.set_grad_none:
raise ValueError(
f"Called zero_grad with set_to_none={set_to_none}, "
f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}"
)
set_to_none = self.set_grad_none
if set_to_none is None:
set_to_none = True

if not self.use_decoupled_grad and not set_to_none:
super().zero_grad(set_to_none=set_to_none)
return

for group in self.param_groups:
for p in group["params"]:
if self.use_decoupled_grad and self.set_grad_none:
if self.use_decoupled_grad and set_to_none:
p.decoupled_grad = None
elif self.use_decoupled_grad and not self.set_grad_none:
elif self.use_decoupled_grad and not set_to_none:
p.decoupled_grad.zero_()
elif not self.use_decoupled_grad and self.set_grad_none:
elif not self.use_decoupled_grad and set_to_none:
p.grad = None

def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
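A sketch of how the deprecation path above behaves, based on the hunk shown (hypothetical usage, under the same assumptions as the earlier example: CUDA device and the transformer_engine.pytorch.optimizers export):

import warnings
import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # assumed export path

params = [torch.nn.Parameter(torch.zeros(8, device="cuda"))]

# The deprecated constructor kwarg still works but now emits a DeprecationWarning...
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    opt = FusedAdam(params, lr=1e-3, set_grad_none=False)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# ...and it pins the zeroing behavior: a conflicting per-call request raises.
try:
    opt.zero_grad(set_to_none=True)
except ValueError:
    pass  # set_to_none=True contradicts set_grad_none=False from the constructor

opt.zero_grad()  # OK: falls back to the constructor setting (zero buffers in place)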
59 changes: 46 additions & 13 deletions transformer_engine/pytorch/optimizers/fused_sgd.py
@@ -3,6 +3,11 @@
# See LICENSE for license information.

"""Fused SGD optimizer."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import warnings

import torch
from torch.optim.optimizer import Optimizer, required

@@ -37,8 +42,8 @@ class FusedSGD(Optimizer):
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
@@ -74,15 +79,16 @@ class FusedSGD(Optimizer):

def __init__(
self,
params,
lr=required,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
params: Iterable[torch.nn.Parameter | dict],
lr: float | Any = required,
momentum: float = 0.0,
dampening: float = 0.0,
weight_decay: float = 0.0,
nesterov: bool = False,
*,
wd_after_momentum=False,
materialize_master_grads=True,
set_grad_none=False,
set_grad_none: Optional[bool] = None, # deprecated
):
if lr is not required and lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
@@ -98,30 +104,57 @@ def __init__(
"weight_decay": weight_decay,
"nesterov": nesterov,
}
if nesterov and (momentum <= 0 or dampening != 0):
if nesterov and (momentum <= 0.0 or dampening != 0.0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)

self.wd_after_momentum = wd_after_momentum
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
self.set_grad_none = set_grad_none

# Skip buffer
self._dummy_overflow_buf = torch.tensor(
[0], dtype=torch.int, device=self.param_groups[0]["params"][0].device
)
self.multi_tensor_sgd = tex.multi_tensor_sgd

# Deprecated options
self.set_grad_none = set_grad_none
if self.set_grad_none is not None:
warnings.warn(
"set_grad_none kwarg in FusedAdam constructor is deprecated. "
"Use set_to_none kwarg in zero_grad instead.",
DeprecationWarning,
)

def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)

def zero_grad(self):
# pylint: disable=missing-function-docstring
if self.set_grad_none:
def zero_grad(self, set_to_none: Optional[bool] = None) -> None:
"""Reset parameter gradients.
Arguments:
set_to_none (bool, optional): whether to set grads to `None`
instead of zeroing out buffers. (default: True)
"""

# Handle deprecated set_grad_none option
if self.set_grad_none is not None:
if set_to_none is not None and set_to_none != self.set_grad_none:
raise ValueError(
f"Called zero_grad with set_to_none={set_to_none}, "
f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}"
)
set_to_none = self.set_grad_none
if set_to_none is None:
set_to_none = True

# Reset grads
if set_to_none:
for group in self.param_groups:
for p in group["params"]:
p.grad = None
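And the corresponding FusedSGD sketch. Note that zero_grad() now defaults to set_to_none=True (matching torch.optim.SGD), whereas the old constructor default was set_grad_none=False; the set_to_none=False branch lies below the shown hunk and is assumed to keep zeroing buffers in place. Hypothetical usage under the same assumptions as above:

import torch
from transformer_engine.pytorch.optimizers import FusedSGD  # assumed export path

model = torch.nn.Linear(8, 8).cuda()
opt = FusedSGD(model.parameters(), lr=0.1, momentum=0.9)

model(torch.randn(2, 8, device="cuda")).sum().backward()
opt.zero_grad()  # default: gradients are set to None
assert all(p.grad is None for p in model.parameters())

model(torch.randn(2, 8, device="cuda")).sum().backward()
opt.zero_grad(set_to_none=False)  # expected to zero the buffers in place instead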
