Commit 4c5e51e

Add on-device initialization

1 parent 4ae7a62 commit 4c5e51e

8 files changed: +224 -58 lines

bytelatent/base_transformer.py

Lines changed: 47 additions & 8 deletions

@@ -6,18 +6,18 @@
 from typing import Optional, Tuple, Union
 
 import torch
+
+from bytelatent.tokenizers.constants import EOS_ID
 from pydantic import BaseModel, ConfigDict
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import (
-    BlockMask,
     _mask_mod_signature,
+    BlockMask,
     flex_attention,
 )
 from xformers.ops import AttentionBias, fmha
 
-from bytelatent.tokenizers.constants import EOS_ID
-
 logger = logging.getLogger()
 
 
 try:
@@ -42,7 +42,7 @@ class InitStdFactor(str, Enum):
 
 
 class BaseTransformerArgs(BaseModel):
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
     dim: int = 512
     n_layers: int = 8
     head_dim: int | None = None
@@ -68,6 +68,9 @@ class BaseTransformerArgs(BaseModel):
     # Special token config
     eos_id: int | None = EOS_ID
 
+    init_device: str = "cpu"
+    init_dtype: torch.dtype = torch.float32
+
 
 def cross_entropy(pred, target, **kwargs):
     return F.nll_loss(
@@ -95,6 +98,7 @@ def precompute_freqs_cis(
     end: int,
     theta: float = 10000.0,
     rope_use_fp32_in_outer_product: bool = False,
+    device: str | torch.device = torch.device("cpu"),
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -111,7 +115,9 @@ def precompute_freqs_cis(
     Returns:
         torch.Tensor: Precomputed frequency tensor with complex exponentials.
     """
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs = 1.0 / (
+        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
+    )
     t = torch.arange(end, device=freqs.device)
     if rope_use_fp32_in_outer_product:
         t = t.to(torch.float32)
@@ -258,6 +264,8 @@ def __init__(
         head_dim: int,
         max_seqlen: int = 1024,
         rope_use_fp32_in_outer_product: bool = False,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
 
@@ -273,7 +281,8 @@ def __init__(
                 end=max_seqlen,
                 theta=theta,
                 rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
-            ),
+                device=device,
+            ).to(dtype=dtype),
             persistent=False,
         )
 
@@ -325,6 +334,8 @@ def __init__(
         n_heads: int,
         n_kv_heads: int,
         rope_theta: float,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
 
@@ -340,22 +351,30 @@ def __init__(
             dim,
             n_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.wk = nn.Linear(
             dim,
             n_kv_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
        )
         self.wv = nn.Linear(
             dim,
             n_kv_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
 
         self.wo = nn.Linear(
             n_heads * head_dim,
             dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
 
     def forward(
@@ -368,6 +387,7 @@ def forward(
     ) -> torch.Tensor:
         # B S D
         bsz, seq_len, dim = x.shape
+
         xq = self.wq(x.view_as(x))
         xk = self.wk(x.view_as(x))
         xv = self.wv(x.view_as(x))
@@ -453,6 +473,8 @@ def __init__(
         multiple_of: int,
         ffn_dim_multiplier: Optional[float],
         mp_size: int = 1,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
 
@@ -469,16 +491,22 @@ def __init__(
             dim,
             hidden_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.w3 = nn.Linear(
             dim,
             hidden_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.w2 = nn.Linear(
             hidden_dim,
             dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -535,15 +563,24 @@ def __init__(self, args: BaseTransformerArgs):
             n_heads=self.n_heads,
             n_kv_heads=self.n_kv_heads,
             rope_theta=args.rope_theta,
+            device=args.init_device,
+            dtype=args.init_dtype,
         )
         self.feed_forward = FeedForward(
             dim=args.dim,
             hidden_dim=4 * args.dim,
             multiple_of=args.multiple_of,
             ffn_dim_multiplier=args.ffn_dim_multiplier,
+            device=args.init_device,
+            dtype=args.init_dtype,
+        )
+        # Norms stay in full precision
+        self.attention_norm = RMSNorm(
+            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+        )
+        self.ffn_norm = RMSNorm(
+            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
         )
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
     def forward(
         self,
@@ -593,6 +630,8 @@ def __init__(self, args: BaseTransformerArgs):
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
             rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
+            device=args.init_device,
+            dtype=args.init_dtype,
         )
         self.eos_id = args.eos_id
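For illustration only (not part of the diff above): a minimal sketch of how the new init_device / init_dtype fields and the device argument on precompute_freqs_cis might be used. It assumes a CUDA device is available and that the remaining BaseTransformerArgs defaults are acceptable.

    import torch

    from bytelatent.base_transformer import BaseTransformerArgs, precompute_freqs_cis

    # Request that weights be allocated on the GPU in bf16 at construction time,
    # instead of being built on CPU in fp32 and moved/cast afterwards.
    args = BaseTransformerArgs(
        dim=512,
        n_layers=8,
        init_device="cuda",
        init_dtype=torch.bfloat16,
    )

    # The RoPE table can likewise be precomputed directly on the target device.
    freqs_cis = precompute_freqs_cis(64, 1024, device="cuda")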

bytelatent/entropy_model.py

Lines changed: 8 additions & 3 deletions

@@ -10,11 +10,13 @@
 logger = logging.getLogger()
 
 
-def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
+def load_entropy_model(
+    entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype=torch.bfloat16
+):
     with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
         reloaded = json.loads(fr.read())
 
-    torch.set_default_dtype(torch.bfloat16)
+    # torch.set_default_dtype(dtype)
     model_params = reloaded["entropy_model"]
     logger.warning(
         "Update checkpoint to load attn and sliding window args from checkpoint"
@@ -29,6 +31,8 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
         attn_bias_type="local_block_causal",
         attn_impl="xformers",
         sliding_window=512,
+        init_device=device,
+        init_dtype=dtype,
     )
     entropy_model = LMTransformer(entropy_model_args)
 
@@ -38,6 +42,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
     entropy_model.to(device)
     entropy_model = entropy_model.eval()
     # no grads for the model:
-    for param in entropy_model.parameters():
+    for n, param in entropy_model.named_parameters():
         param.requires_grad = False
+
     return entropy_model, entropy_model_args
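For illustration only (not part of the diff above): a minimal sketch of calling the updated loader so the entropy model is materialized on the GPU in bf16. The checkpoint paths are placeholders.

    import torch

    from bytelatent.entropy_model import load_entropy_model

    # Placeholder paths: an entropy-model checkpoint directory (containing
    # params.json) and its consolidated state-dict file.
    entropy_model, entropy_model_args = load_entropy_model(
        "checkpoints/entropy_model",
        "checkpoints/entropy_model/consolidated.pth",
        device="cuda",
        dtype=torch.bfloat16,
    )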

bytelatent/generate.py

Lines changed: 21 additions & 13 deletions

@@ -4,11 +4,6 @@
 import time
 
 import torch
-from omegaconf import OmegaConf
-from torch import nn
-from torch.nn import functional as F
-from torch.nn.attention.flex_attention import create_block_mask
-from tqdm import tqdm
 
 from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
@@ -19,9 +14,9 @@
     lengths_to_start_ids,
 )
 from bytelatent.checkpoint import (
+    consolidate_checkpoints,
     CONSOLIDATE_FOLDER,
     CONSOLIDATE_NAME,
-    consolidate_checkpoints,
 )
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
@@ -33,6 +28,11 @@
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
+from omegaconf import OmegaConf
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import create_block_mask
+from tqdm import tqdm
 
 
 def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@@ -400,25 +400,33 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=False):
         setup_torch_distributed(distributed_args)
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
+
     train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
 
+    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
+        train_args.distributed.model_dtype
+    ]
+
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
+        model_args.init_device = "cuda"
+        model_args.init_dtype = param_dtype
         model = LMTransformer(model_args)
     else:
         model_args = train_args.model
-        model = ByteLatentTransformer(model_args)
+        model_args.init_device = "cuda"
+        model_args.init_dtype = param_dtype
+        model = ByteLatentTransformer(args=model_args)
+
+    model = model.eval()
 
-    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        train_args.distributed.model_dtype
-    ]
     tokenizer = train_args.data.tokenizer_args.build()
+
     with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
         st_dict = torch.load(f, weights_only=True)
+
     model.load_state_dict(st_dict["model"])
-    model = model.cuda().eval()
-    for param in model.parameters():
-        param.data = param.data.to(dtype=param_dtype)
+
     return model, tokenizer, train_args
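For illustration only (not part of the diff above): a rough sketch of the updated loading path, assuming a CUDA device is available; the checkpoint path is a placeholder.

    from bytelatent.generate import load_consolidated_model_and_tokenizer

    # Placeholder path: a consolidated checkpoint directory containing params.json
    # and the consolidated weights file.
    model, tokenizer, train_args = load_consolidated_model_and_tokenizer(
        "checkpoints/blt/consolidated"
    )
    # The model is now built directly on CUDA in train_args.distributed.model_dtype,
    # so the old post-hoc .cuda() call and per-parameter dtype cast are gone.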
