60 changes: 53 additions & 7 deletions bytelatent/base_transformer.py
@@ -6,18 +6,19 @@
from typing import Optional, Tuple, Union

import torch

from bytelatent.model.utils import DTYPE_MAP
from bytelatent.tokenizers.constants import EOS_ID
from pydantic import BaseModel, ConfigDict
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
BlockMask,
_mask_mod_signature,
BlockMask,
flex_attention,
)
from xformers.ops import AttentionBias, fmha

from bytelatent.tokenizers.constants import EOS_ID

logger = logging.getLogger()

try:
@@ -68,6 +69,9 @@ class BaseTransformerArgs(BaseModel):
# Special token config
eos_id: int | None = EOS_ID

init_device: str = "cpu"
init_dtype: str = "fp32"


def cross_entropy(pred, target, **kwargs):
return F.nll_loss(
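
The new init_device / init_dtype fields are resolved through DTYPE_MAP, imported from bytelatent.model.utils at the top of this file. The map itself is not shown in this diff; a minimal sketch of what it presumably contains, mirroring the fp32/fp16/bf16 string keys used elsewhere in this PR (see generate.py below), would be:

import torch

# Assumed contents of DTYPE_MAP (not part of this diff); the keys match the
# dtype strings this PR passes around ("fp32", "fp16", "bf16").
DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

The constructors below look up DTYPE_MAP[args.init_dtype] and pass the result, together with args.init_device, to each submodule so that weights are created in the requested place and precision from the start.
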
@@ -95,6 +99,7 @@ def precompute_freqs_cis(
end: int,
theta: float = 10000.0,
rope_use_fp32_in_outer_product: bool = False,
device: str | torch.device = torch.device("cpu"),
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -111,7 +116,9 @@
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
)
t = torch.arange(end, device=freqs.device)
if rope_use_fp32_in_outer_product:
t = t.to(torch.float32)
@@ -258,6 +265,8 @@ def __init__(
head_dim: int,
max_seqlen: int = 1024,
rope_use_fp32_in_outer_product: bool = False,
device: str | torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
):
super().__init__()

@@ -273,7 +282,8 @@ def __init__(
end=max_seqlen,
theta=theta,
rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
),
device=device,
).to(dtype=dtype),
persistent=False,
)

@@ -325,6 +335,8 @@ def __init__(
n_heads: int,
n_kv_heads: int,
rope_theta: float,
device: str | torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
):
super().__init__()

@@ -340,22 +352,30 @@ def __init__(
dim,
n_heads * head_dim,
bias=False,
device=device,
dtype=dtype,
)
self.wk = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
device=device,
dtype=dtype,
)
self.wv = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
device=device,
dtype=dtype,
)

self.wo = nn.Linear(
n_heads * head_dim,
dim,
bias=False,
device=device,
dtype=dtype,
)

def forward(
@@ -368,6 +388,7 @@ def forward(
) -> torch.Tensor:
# B S D
bsz, seq_len, dim = x.shape

xq = self.wq(x.view_as(x))
xk = self.wk(x.view_as(x))
xv = self.wv(x.view_as(x))
@@ -453,6 +474,8 @@ def __init__(
multiple_of: int,
ffn_dim_multiplier: Optional[float],
mp_size: int = 1,
device: str | torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
):
super().__init__()

@@ -469,16 +492,22 @@ def __init__(
dim,
hidden_dim,
bias=False,
device=device,
dtype=dtype,
)
self.w3 = nn.Linear(
dim,
hidden_dim,
bias=False,
device=device,
dtype=dtype,
)
self.w2 = nn.Linear(
hidden_dim,
dim,
bias=False,
device=device,
dtype=dtype,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -535,15 +564,30 @@ def __init__(self, args: BaseTransformerArgs):
n_heads=self.n_heads,
n_kv_heads=self.n_kv_heads,
rope_theta=args.rope_theta,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
# Norms stay in full precision
self.attention_norm = RMSNorm(
args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
self.ffn_norm = RMSNorm(
args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

def forward(
self,
@@ -593,6 +637,8 @@ def __init__(self, args: BaseTransformerArgs):
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=args.max_seqlen,
rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
self.eos_id = args.eos_id

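Taken together, the base_transformer.py changes construct every weight directly on init_device in init_dtype instead of building on CPU in fp32 and converting afterwards; the rotary-embedding buffer is likewise created on the given device and cast with .to(dtype=dtype). A miniature of the pattern the updated constructors now follow (sizes and names here are illustrative, not from the PR):

import torch
from torch import nn

# Stand-ins for args.init_device and DTYPE_MAP[args.init_dtype].
device, dtype = torch.device("cpu"), torch.bfloat16

# nn.Linear accepts device/dtype at construction, which is exactly how the
# updated attention and feed-forward modules create their projections.
wq = nn.Linear(512, 512, bias=False, device=device, dtype=dtype)
assert wq.weight.dtype == dtype and wq.weight.device.type == "cpu"
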
10 changes: 7 additions & 3 deletions bytelatent/entropy_model.py
@@ -10,11 +10,12 @@
logger = logging.getLogger()


def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
def load_entropy_model(
entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype="bf16"
):
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
reloaded = json.loads(fr.read())

torch.set_default_dtype(torch.bfloat16)
model_params = reloaded["entropy_model"]
logger.warning(
"Update checkpoint to load attn and sliding window args from checkpoint"
@@ -29,6 +30,8 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
attn_bias_type="local_block_causal",
attn_impl="xformers",
sliding_window=512,
init_device=device,
init_dtype=dtype,
)
entropy_model = LMTransformer(entropy_model_args)

@@ -38,6 +41,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
entropy_model.to(device)
entropy_model = entropy_model.eval()
# no grads for the model:
for param in entropy_model.parameters():
for n, param in entropy_model.named_parameters():
param.requires_grad = False

return entropy_model, entropy_model_args
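
With the new dtype argument, the entropy model's precision flows through init_dtype instead of the old global torch.set_default_dtype(torch.bfloat16) call. A usage sketch (the checkpoint paths are placeholders, not taken from this PR):

entropy_model, entropy_model_args = load_entropy_model(
    "path/to/entropy_checkpoint",                 # directory containing params.json
    "path/to/entropy_checkpoint/state_dict.pth",  # hypothetical state-dict file name
    device="cuda",
    dtype="bf16",  # default; matches the previously hard-coded bfloat16
)
# The loader freezes all parameters, so the returned model is inference-only.
assert not any(p.requires_grad for p in entropy_model.parameters())
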
34 changes: 19 additions & 15 deletions bytelatent/generate.py
@@ -4,11 +4,6 @@
import time

import torch
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm

from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
from bytelatent.base_transformer import (
@@ -19,9 +14,9 @@
lengths_to_start_ids,
)
from bytelatent.checkpoint import (
consolidate_checkpoints,
CONSOLIDATE_FOLDER,
CONSOLIDATE_NAME,
consolidate_checkpoints,
)
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
@@ -33,6 +28,11 @@
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.transformer import LMTransformer
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm


def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@@ -400,25 +400,29 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa
setup_torch_distributed(distributed_args)
train_args_path = os.path.join(consolidated_path, "params.json")
fs = get_fs(train_args_path)

train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

if train_args.train_entropy_model:
model_args = train_args.entropy_model
model_args.init_device = "cuda"
model_args.init_dtype = train_args.distributed.model_dtype
model = LMTransformer(model_args)
else:
model_args = train_args.model
model = ByteLatentTransformer(model_args)
model_args.init_device = "cuda"
model_args.init_dtype = train_args.distributed.model_dtype
model = ByteLatentTransformer(args=model_args)

model = model.eval()

param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
train_args.distributed.model_dtype
]
tokenizer = train_args.data.tokenizer_args.build()
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
st_dict = torch.load(f, weights_only=True)

with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as fp:
st_dict = torch.load(fp, weights_only=True)

model.load_state_dict(st_dict["model"])
model = model.cuda().eval()
for param in model.parameters():
param.data = param.data.to(dtype=param_dtype)

return model, tokenizer, train_args


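For reference, a minimal call sketch under these changes (the consolidated path is a placeholder): the model now comes back already materialized on GPU in train_args.distributed.model_dtype, so the old post-hoc model.cuda() and per-parameter dtype cast are gone.

model, tokenizer, train_args = load_consolidated_model_and_tokenizer(
    "path/to/consolidated",  # hypothetical directory holding params.json and the consolidated weights
)
p = next(model.parameters())
print(p.dtype, p.device)  # e.g. torch.bfloat16, cuda:0, depending on model_dtype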