Deepspeed integration #4693
base: main
Changes from 7 commits
@@ -0,0 +1,11 @@
from allennlp.training.deepspeed.trainer import DeepspeedTrainer

from allennlp.training.deepspeed.optimizers import (
    FusedAdamOptimizer,
    DeepspeedCPUAdamOptimizer,
    FusedLambOptimizer
)

try:
    from allennlp.training.deepspeed.sparse_transformer_embedder import SparseTransformerEmbedder
except ImportError:
    pass
@@ -0,0 +1,62 @@
from typing import Dict, Any
from enum import IntEnum
from allennlp.common import FromParams
from dataclasses import dataclass, asdict


@dataclass
class DeepspeedFP16Config(FromParams):
    enabled: bool = True
    loss_scale: float = 0.
    initial_scale_power: int = 32
    loss_scale_window: int = 1000
    hysteresis: int = 2
    min_loss_scale: float = 1.


@dataclass
class DeepspeedAMPConfig(FromParams):
    enabled: bool = False
    opt_level: str = "O1"
Review comment: I thought AMP was dead and we now use things built directly into PyTorch?
Reply: Yeah, but it's a required install for deepspeed and you can use it there, so I thought I would keep it in for compatibility. It can be removed if need be.
Reply: Hmm. Surely in the next DeepSpeed version they will make it use PyTorch-native AMP. But if we need it for now, that's cool.
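For context on the thread above, here is a minimal sketch of the PyTorch-native mixed-precision path (torch.cuda.amp) the reviewer is referring to; the model, optimizer, and dummy batches are placeholders for illustration, not part of this PR:

import torch
from torch.cuda.amp import GradScaler, autocast

# Placeholder model and optimizer purely for illustration.
model = torch.nn.Linear(16, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler()

for batch in [torch.randn(8, 16).cuda() for _ in range(3)]:
    optimizer.zero_grad()
    with autocast():                   # run the forward pass in mixed precision
        loss = model(batch).sum()
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()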
@dataclass
class DeepspeedOptimizerConfig(FromParams):
    type: str
    params: Dict[str, Any]


class DeepspeedZeROStage(IntEnum):
    DISABLED = 0
    OPTIMIZER = 1
    GRADIENT = 2


@dataclass
class DeepspeedZeROConfig(FromParams):
    stage: DeepspeedZeROStage = DeepspeedZeROStage.GRADIENT
    allgather_partitions: bool = True
    allgather_bucket_size: int = 500000000
    overlap_comm: bool = False
    reduce_scatter: bool = True
    reduce_bucket_size: int = 500000000
    contiguous_gradients: bool = False
    cpu_offload: bool = False


@dataclass
class DeepspeedConfig(FromParams):
    zero_optimization: DeepspeedZeROConfig
    fp16: DeepspeedFP16Config
    amp: DeepspeedAMPConfig = DeepspeedAMPConfig()
    optimizer: DeepspeedOptimizerConfig = None

    zero_allow_untested_optimizer: bool = True
    wall_clock_breakdown: bool = False

    def to_dict(self):
        return asdict(self)


@dataclass
class DeepspeedArgs(FromParams):
    local_rank: int
    deepspeed: bool = True
    deepspeed_mpi: bool = False
    deepspeed_config: str = None
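To make the shape of the resulting configuration concrete, here is a minimal usage sketch (not part of the diff); it just instantiates the dataclasses above and dumps them to the plain dict that would be handed to DeepSpeed. The particular stage/offload choices are arbitrary examples:

config = DeepspeedConfig(
    zero_optimization=DeepspeedZeROConfig(
        stage=DeepspeedZeROStage.GRADIENT,  # ZeRO stage 2: partition optimizer state and gradients
        cpu_offload=False,
    ),
    fp16=DeepspeedFP16Config(enabled=True),
)

# to_dict() is just dataclasses.asdict, so the nested result mirrors the
# structure of DeepSpeed's JSON config (fp16, amp, zero_optimization, ...).
ds_config = config.to_dict()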
@@ -0,0 +1,87 @@
from typing import List, Tuple, Dict, Any

import torch

from apex.optimizers.fused_adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.lamb import FusedLamb
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam

from allennlp.training.optimizers import Optimizer, make_parameter_groups


@Optimizer.register("fused_adam")
class FusedAdamOptimizer(Optimizer, FusedAdam):
    def __init__(
        self,
        model_parameters: List[Tuple[str, torch.nn.Parameter]],
        parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None,
        lr: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
        amsgrad: bool = False,
        bias_correction: bool = True,
        adam_w_mode: bool = True,
        set_grad_none: bool = True,
    ):
        super().__init__(
            params=make_parameter_groups(model_parameters, parameter_groups),
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            bias_correction=bias_correction,
            adam_w_mode=adam_w_mode,
            set_grad_none=set_grad_none,
        )
# This does not currently work
Review comment: Why not? If it doesn't work and is not necessary, can we remove it?
@Optimizer.register("cpu_adam")
class DeepspeedCPUAdamOptimizer(Optimizer, DeepSpeedCPUAdam):
    def __init__(
        self,
        model_parameters: List[Tuple[str, torch.nn.Parameter]],
        parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None,
        lr: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
        amsgrad: bool = False,
    ):
        super().__init__(
            model_params=make_parameter_groups(model_parameters, parameter_groups),
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad
        )


@Optimizer.register("fused_lamb")
class FusedLambOptimizer(Optimizer, FusedLamb):
    def __init__(
        self,
        model_parameters: List[Tuple[str, torch.nn.Parameter]],
        parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None,
        lr: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        eps_inside_sqrt: bool = False,
        weight_decay: float = 0.0,
        amsgrad: bool = False,
        max_grad_norm: float = 0.,
        max_coeff: float = 10.0,
        min_coeff: float = 0.01
    ):
        super().__init__(
            params=make_parameter_groups(model_parameters, parameter_groups),
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            max_grad_norm=max_grad_norm,
            max_coeff=max_coeff,
            min_coeff=min_coeff,
        )
@@ -0,0 +1,10 @@
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder

from deepspeed.ops.sparse_attention.sparse_attention_utils import SparseAttentionUtils


@TokenEmbedder.register('sparse_transformer')
class SparseTransformerEmbedder(PretrainedTransformerEmbedder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.transformer_model = SparseAttentionUtils.replace_model_self_attention_with_sparse_self_attention(self.transformer_model)
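As a usage illustration (not part of the diff), the embedder registered above could be selected from a model config roughly like this; the fragment is expressed as a Python dict and the model name is a placeholder:

# Hypothetical text_field_embedder fragment.
text_field_embedder = {
    "token_embedders": {
        "tokens": {
            "type": "sparse_transformer",       # the name registered above
            "model_name": "bert-base-uncased",  # placeholder; any model PretrainedTransformerEmbedder accepts
        }
    }
}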
Review comment: Is this leftover debug code?
Reply: Depends on how we want to include this. Based on my experience, I wouldn't recommend making deepspeed a required dependency. If we're doing the pip install allennlp[deepspeed] thing, this could be replaced/updated (not sure offhand how that gets handled, but I can look for some examples).
Reply: If you don't mind doing the work making it optional, then let's make it optional.
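To ground the discussion above: one common way to make an integration opt-in is a setuptools extra plus a guarded import. The snippet below is only a sketch under that assumption; the extra's name and the version pin are placeholders, not something this PR defines:

# setup.py (fragment): declare deepspeed as an optional extra so that
# `pip install allennlp[deepspeed]` pulls it in, while a plain install does not.
from setuptools import setup

setup(
    name="allennlp",
    # ...
    extras_require={
        "deepspeed": ["deepspeed>=0.3.0"],  # placeholder version pin
    },
)

The try/except ImportError around SparseTransformerEmbedder in the package __init__ above is the import-time half of the same pattern.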