Implement kv cache (#74)
* Allow parsing parameters from a config file (#141)

* Allow parsing parameters from a config file

* Address nits

* Support for beam search and kv_cache

* Added tiny test + black

* separate tiny generate test from generate test

* Pin the datasets version to avoid the error: AttributeError: module 'threading' has no attribute '_Condition'.

* Use a compatible datasets version and unify the Python version to 3.10

* Using xformers LowerTriangularFromBottomRightMask instead of custom mask

* Revert to using the custom mask because LowerTriangularFromBottomRightMask is not compatible with the xformers version that we need

* added tiny beam generation test and fixed cache reorder (thanks Rui)

---------

Co-authored-by: Achal Dave <[email protected]>
jmercat and achalddave authored Dec 13, 2023
1 parent 81aeb87 commit e016855
Showing 21 changed files with 734 additions and 185 deletions.
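The central change below is a per-layer key/value cache threaded through `Attention`, `Block`, and `Transformer` in open_lm/model.py: `forward` now accepts `past_key_values`/`use_cache` and returns the updated cache alongside the logits. As a rough illustration only (this helper is not part of the commit; `greedy_generate` and the prompt handling are hypothetical), incremental decoding with that interface could look like:

import torch

@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens=32):
    # Prefill: run the full prompt once and keep the per-layer key/value cache.
    logits, _, past_key_values = model(input_ids, use_cache=True)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    generated = [next_token]
    for _ in range(max_new_tokens - 1):
        # Decode step: feed only the newest token; the cache covers earlier positions.
        logits, _, past_key_values = model(next_token, past_key_values=past_key_values, use_cache=True)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_token)
    return torch.cat([input_ids] + generated, dim=1)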
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
3 changes: 2 additions & 1 deletion .gitignore
@@ -138,4 +138,5 @@ eval/*.jsonl_tmp
weights*
out*
tests/assets/*
-.vscode/
+.vscode/
+checkpoints/
2 changes: 1 addition & 1 deletion environment-tests.yml
@@ -2,7 +2,7 @@ name: open_lm_tests
channels:
- defaults
dependencies:
- python=3.8
- python=3.10
- pip
- pip:
- -r requirements.txt
2 changes: 1 addition & 1 deletion environment.yml
@@ -2,7 +2,7 @@ name: open_lm
channels:
- defaults
dependencies:
- python=3.8
- python=3.10
- pip
- pip:
- -r requirements.txt
8 changes: 5 additions & 3 deletions open_lm/main.py
@@ -29,9 +29,9 @@
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

-from .data import proc_token
-from .model import Block
-from .losses import CrossEntropyLossWithZLoss
+from open_lm.data import proc_token
+from open_lm.model import Block
+from open_lm.losses import CrossEntropyLossWithZLoss

try:
import wandb
@@ -44,13 +44,15 @@
tensorboard = None

from open_lm.model import create_model

from open_lm.utils.transformers.hf_wrapper import create_wrapped_hf_model
from open_lm.data import get_data, get_wds_dataset
from open_lm.distributed import is_master, init_distributed_device, broadcast_object
from open_lm.logger import setup_logging
from open_lm.params import parse_args
from open_lm.scheduler import cosine_lr
from open_lm.train import train_one_epoch, evaluate_loop

from open_lm.file_utils import (
pt_load,
check_exists,
114 changes: 94 additions & 20 deletions open_lm/model.py
@@ -81,19 +81,69 @@ class Params:
    ffn_type: str = "swiglu"


+def get_rectangular_mask(shape, q_seq_len, k_seq_len, device, dtype):
+    # xformers requires the mask to be built with a shape that is a multiple of 8
+    # probably because of the way it is implemented in CUDA
+    next_multiple_8 = (k_seq_len + 7) // 8 * 8
+    mask = torch.ones((q_seq_len, next_multiple_8), device=device, dtype=bool)
+    mask[:, -q_seq_len:] = torch.tril(mask[:, -q_seq_len:], diagonal=0)
+    return torch.zeros((*shape, q_seq_len, next_multiple_8), device=device, dtype=dtype).masked_fill(
+        ~mask, float("-inf")
+    )[:, :, :, :k_seq_len]

def xformers_attn(queries, keys, values, is_causal):
    # xformers assumes q, k, v are [batch, seq_len, heads, embed_dim]
+    # We assume that queries match the last part of the key / value sequences
+    # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
+    # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
+    # sadly we cannot use this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1

    mask = None
-    if is_causal:
+    # If queries have shape [batch, 1, heads, dim] it means there is only one query in the sequence.
+    # In this case, there is no notion of causal masking, so we can just set the mask to None.
+    # This is actually needed to get the desired behavior with seq_len=1.
+    if is_causal and queries.shape[1] == keys.shape[1]:
        mask = xops.LowerTriangularMask()
+    elif is_causal and queries.shape[1] > 1:
+        # Build causal mask that assumes queries are in the end of the sequence.
+        batch, q_seq_len, heads, _ = queries.shape
+        k_seq_len = keys.shape[1]
+        mask = get_rectangular_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
    return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask)


def torch_attn(queries, keys, values, is_causal):
    # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail.
    # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention
    # changed between 2.0 and 2.1
-    return F.scaled_dot_product_attention(queries, keys, values, is_causal=is_causal).contiguous()
+    if is_causal and keys.shape[1] > queries.shape[1] > 1:
+        q_seq_len = queries.shape[1]
+        k_seq_len = keys.shape[1]
+        # Same as above, we would like to use:
+        # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device)
+        mask = get_rectangular_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype)
+        return (
+            F.scaled_dot_product_attention(
+                queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask
+            )
+            .transpose(1, 2)
+            .contiguous()
+        )
+    elif queries.shape[1] == 1:
+        return (
+            F.scaled_dot_product_attention(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2))
+            .transpose(1, 2)
+            .contiguous()
+        )
+    else:
+        return (
+            F.scaled_dot_product_attention(
+                queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), is_causal=is_causal
+            )
+            .transpose(1, 2)
+            .contiguous()
+        )


def get_pos_embed(args: Params):
@@ -149,24 +199,37 @@ def reset_parameters(self):
        std = std / math.sqrt(2 * (self.layer_id + 1))
        torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)

-    def forward(self, x: torch.Tensor, is_causal=True):
-        batchsize, seqlen, _ = x.shape
+    def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False):
+        batchsize, q_len, _ = x.shape
        queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

-        queries = queries.view(batchsize, seqlen, self.n_heads, self.head_dim)
-        keys = keys.view(batchsize, seqlen, self.n_heads, self.head_dim)
-        vals = vals.view(batchsize, seqlen, self.n_heads, self.head_dim)
+        queries = queries.view(batchsize, q_len, self.n_heads, self.head_dim)
+        keys = keys.view(batchsize, q_len, self.n_heads, self.head_dim)
+        vals = vals.view(batchsize, q_len, self.n_heads, self.head_dim)

-        queries, keys, vals = self.pos_embed(queries, keys, vals)
+        past_length = 0 if past_key_value is None else past_key_value[0].shape[1]
+        queries, keys, vals = self.pos_embed(queries, keys, vals, offset=past_length)

+        if past_key_value is not None and use_cache:
+            keys = torch.cat([past_key_value[0], keys], dim=1)
+            vals = torch.cat([past_key_value[1], vals], dim=1)

-        output = self.attn_fn(queries, keys, vals, is_causal=is_causal)
+        if use_cache:
+            past_key_value = [keys, vals]

+        output = self.attn_fn(
+            queries,
+            keys,
+            vals,
+            is_causal=is_causal,
+        )

-        output = output.view(batchsize, seqlen, -1)
+        output = output.view(batchsize, q_len, -1)

-        return self.out_proj(output)
+        return self.out_proj(output), past_key_value


class Block(nn.Module):
@@ -218,10 +281,16 @@ def reset_parameters(self):
        std = std / math.sqrt(2 * (self._layer_id + 1))
        torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)

-    def forward(self, x):
-        h = x + self.attention(self.attention_norm(x), is_causal=True)
+    def forward(self, x, past_key_value=None, use_cache=False):
+        h, past_key_value = self.attention(
+            self.attention_norm(x),
+            is_causal=True,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+        )
+        h = x + h
        out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out, past_key_value


class Transformer(nn.Module, PyTorchModelHubMixin):
@@ -271,20 +340,25 @@ def reset_parameters(self):
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

-    def forward(self, input):
+    def forward(self, input, past_key_values=None, use_cache=False):
        x = self.tok_embeddings(input)
        x = self.post_embed_norm(x)

-        for layer in self.layers:
+        if past_key_values is None:
+            past_key_values = [None] * self.n_layers
+        elif isinstance(past_key_values, tuple):
+            past_key_values = list(past_key_values)
+        for i, layer in enumerate(self.layers):
            if self.grad_checkpointing:
-                x = checkpoint(layer, x)
+                x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache)
            else:
-                x = layer(x)
+                x, past_key_values[i] = layer(x, past_key_values[i], use_cache=use_cache)

+        if past_key_values[0] is None:
+            past_key_values = None
        x = self.norm(x)
        output = self.output(x)
        # follow llama in casting this to float.
-        return output.float(), x
+        return output.float(), x, past_key_values

    def get_input_embeddings(self):
        return self.tok_embeddings
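The new `get_rectangular_mask` above handles the cached case where the q_len incoming queries correspond to the last q_len positions of a k_len-long key sequence: every query may attend to all cached keys, and causality only applies within the new block (xformers additionally wants the mask built at a width that is a multiple of 8 before slicing back to k_len). A standalone sketch of the same masking rule, for illustration only and not the open_lm code itself:

import torch

def rectangular_causal_bias(q_len, k_len, dtype=torch.float32):
    # Queries sit at the end of the key sequence, so only the trailing q_len x q_len
    # block is lower-triangular; all cached positions stay fully visible.
    allowed = torch.ones(q_len, k_len, dtype=torch.bool)
    allowed[:, -q_len:] = torch.tril(allowed[:, -q_len:])
    return torch.zeros(q_len, k_len, dtype=dtype).masked_fill(~allowed, float("-inf"))

print(rectangular_causal_bias(2, 5))
# tensor([[0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])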
13 changes: 13 additions & 0 deletions open_lm/model_configs/open_lm_1b_old.json
@@ -0,0 +1,13 @@
{
"hidden_dim": 2048,
"n_layers": 24,
"n_heads": 16,
"seq_len": 2048,
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false,
"qk_norm": false,
"ffn_type": "swiglu",
"model_norm": "default_layer_norm",
"positional_embedding_type": "head_rotary"
}
1 change: 1 addition & 0 deletions open_lm/params.py
@@ -52,6 +52,7 @@ def add_model_args(parser):
    parser.add_argument(
        "--positional-embedding-type",
        type=str,
+        choices=["rotary", "head_rotary", "llama_rotary"],
        default="rotary",
        help="Type of positional embedding to use. This might be overridden by the model config.",
    )
79 changes: 17 additions & 62 deletions open_lm/positional_embedding/head_rotary.py
@@ -16,78 +16,33 @@

import torch

-def rotate_half(x):
-    x1, x2 = x.chunk(2, dim=-1)
-    return torch.cat((-x2, x1), dim=-1)

-@torch.jit.script
-def apply_rotary_pos_emb(x, cos, sin):
-    # NOTE: This could probably be moved to Triton
-
-    # Handle a possible sequence length mismatch in between q and k
-    cos = cos[:, :, : x.shape[-2], :]
-    sin = sin[:, :, : x.shape[-2], :]
-
-    return (x * cos) + (rotate_half(x) * sin)
+from open_lm.positional_embedding.rotary import apply_rotary_pos_emb, RotaryEmbedding


-class HeadRotaryEmbedding(torch.nn.Module):
+class HeadRotaryEmbedding(RotaryEmbedding):
    """
-    The rotary position embeddings from RoFormer_ (Su et. al).
-    A crucial insight from the method is that the query and keys are
-    transformed by rotation matrices which depend on the relative positions.
-    Other implementations are available in the Rotary Transformer repo_ and in
-    GPT-NeoX_, GPT-NeoX was an inspiration
-    .. _RoFormer: https://arxiv.org/abs/2104.09864
-    .. _repo: https://github.com/ZhuiyiTechnology/roformer
-    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
-    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
-    (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
+    The rotary position embeddings used in the first version of OpenLM.
+    It is only kept for compatibility, RotaryEmbedding should be used instead.
    """

-    def __init__(self, dim_model: int, *_, **__):
-        super().__init__()
-        # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
-        self.register_buffer("inv_freq", inv_freq)

-        self._seq_len_cached = None
-        self._cos_cached = None
-        self._sin_cached = None

-    def _update_cos_sin_tables(self, x, seq_dimension=1):
-        seq_len = x.shape[seq_dimension]

-        # Reset the tables if the sequence length has changed,
-        # or if we're on a new device (possibly due to tracing for instance)
-        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
-            self._seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=torch.float32)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

-            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
-            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

-        return self._cos_cached, self._sin_cached

-    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

-        return (
-            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
-            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
-        )
+    def __init__(self, dim_model: int, seq_len: int, *_, **__):
+        super().__init__(dim_model, seq_len)
+        self._has_warned = False

+    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+        self._update_cos_sin_tables(k.shape[2], device=k.device, dtype=k.dtype)

+        if not self._has_warned and (offset != 0):
+            print("Warning. HeadRotaryEmbedding does not support offset, I am not applying it.")
+            self._has_warned = True

+        out_q = apply_rotary_pos_emb(q.transpose(1, 2), self._cos_cached, self._sin_cached).transpose(1, 2)
+        out_k = apply_rotary_pos_emb(k.transpose(1, 2), self._cos_cached, self._sin_cached).transpose(1, 2)
+        return out_q, out_k


class HeadRotaryWithCast(HeadRotaryEmbedding):
    # NOTE: this version has the bug, but we trained the 7B model with it so it's default
-    def forward(self, q, k, v):
-        q, k = super().forward(q, k)
+    def forward(self, q, k, v, offset: int = 0):
+        q, k = super().forward(q, k, offset)
        return q.to(v.dtype), k.to(v.dtype), v
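With the cache in place, `Attention.forward` now calls `self.pos_embed(queries, keys, vals, offset=past_length)`, so new tokens are rotated with their absolute positions rather than starting from zero; the legacy `HeadRotaryEmbedding` above simply warns and ignores the offset. A generic RoFormer-style sketch (not the open_lm rotary.py implementation) of what honoring the offset means:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, offset=0, base=10000):
    # x: [batch, seq_len, heads, head_dim]; positions start at `offset`, which is
    # the number of tokens already held in the key/value cache.
    _, seq_len, _, head_dim = x.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(offset, offset + seq_len, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", positions, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)[None, :, None, :]  # broadcast over batch and heads
    return x * emb.cos() + rotate_half(x) * emb.sin()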