Skip to content

Commit

Permalink
Add GPT and ViT models
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Nov 14, 2022
1 parent d4b320b commit 2e33fc8
Show file tree
Hide file tree
Showing 8 changed files with 1,209 additions and 5 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Our tentative roadmap:
6. ~~[Jul 2022] Implement cross-attention~~[Done].
7. ~~[Jul 2022] Support head dimension 128~~[Done].
8. [Jul 2022] Support SM70 GPUs (V100).
9. [Aug 2022] Fuse rotary embedding.
9. ~~[Aug 2022] Fuse rotary embedding~~[Done].
10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).

## Speedup and Memory Savings
Expand Down Expand Up @@ -154,10 +154,10 @@ and for his thoughtful answers to our questions about CUDA.
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```
@article{dao2022flashattention,
@inproceedings{dao2022flashattention,
title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
journal={arXiv preprint arXiv:2205.14135},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
```
2 changes: 1 addition & 1 deletion csrc/fused_dense_lib/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
This CUDA extensions implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense).
We make it work for bfloat16.
Expand Down
2 changes: 1 addition & 1 deletion csrc/layer_norm/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
This CUDA extensions implements fused dropout + residual + LayerNorm, based on
This CUDA extension implements fused dropout + residual + LayerNorm, based on
Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
```sh
Expand Down
174 changes: 174 additions & 0 deletions flash_attn/models/gpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright (c) 2022, Tri Dao.

import math
from functools import partial

from collections import namedtuple
from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings

try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None

try:
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
FusedDenseSqreluDense = None


def create_mixer_cls(config, layer_idx=None):
head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
if config.scale_attn_by_inverse_layer_idx:
assert layer_idx is not None
softmax_scale /= float(layer_idx + 1)
dwconv = getattr(config, 'attn_dwconv', False)
rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
use_flash_attn = getattr(config, 'use_flash_attn', False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
rotary_emb_dim=rotary_emb_dim,
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
return mixer_cls


def create_mlp_cls(config, layer_idx=None):
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate='tanh'))
else:
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
if fused_dense_gelu_dense:
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl)
elif fused_dense_sqrelu_dense:
assert FusedDenseSqreluDense is not None
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl)
else:
raise RuntimeError('MLP type not supported')
return mlp_cls


def create_block(config, layer_idx=None):
mixer_cls = create_mixer_cls(config, layer_idx)
mlp_cls = create_mlp_cls(config, layer_idx)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=True, resid_dropout=config.resid_pdrop,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
block.layer_idx = layer_idx
return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)

if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


class GPT2Model(nn.Module):

def __init__(self, config: GPT2Config):
super().__init__()
self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
if self.pad_vocab_size_multiple_8:
if config.vocab_size % 8 != 0:
config.vocab_size += 8 - (config.vocab_size % 8)

self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
config.max_position_embeddings)
self.emb_drop = nn.Dropout(config.embd_pdrop)

# We change the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
# the main branch (output of LN). The model definition is unchanged, but the mapping of the
# nn.LayerNorm weights are changed.
# This is for performance reason: we can fuse dropout + add + layer_norm.
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
# self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
# is the final layer norm.
self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

self.layers = nn.ModuleList([create_block(config, layer_idx=i)
for i in range(config.num_hidden_layers)])

self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))

def forward(self, input_ids, position_ids=None):
hidden_states = self.embeddings(input_ids, position_ids=position_ids)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
if not self.fused_dropout_add_ln:
residual = self.emb_drop(hidden_states).float()
hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
else:
hidden_states, residual = dropout_add_layer_norm(
hidden_states, None, self.ln_0.weight, self.ln_0.bias,
self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
residual_in_fp32=True
)
for layer in self.layers:
hidden_states, residual = layer(hidden_states, residual)
return hidden_states


class GPT2LMHeadModel(nn.Module):

def __init__(self, config: GPT2Config):
super().__init__()
self.transformer = GPT2Model(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()

def tie_weights(self):
self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight

def forward(self, input_ids, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids=position_ids)
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(logits=lm_logits)
Loading

0 comments on commit 2e33fc8

Please sign in to comment.