 from typing import Optional, Tuple, Union

 import torch
+
+from bytelatent.tokenizers.constants import EOS_ID
 from pydantic import BaseModel, ConfigDict
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import (
-    BlockMask,
     _mask_mod_signature,
+    BlockMask,
     flex_attention,
 )
 from xformers.ops import AttentionBias, fmha

-from bytelatent.tokenizers.constants import EOS_ID
-
 logger = logging.getLogger()

 try:
@@ -42,7 +42,7 @@ class InitStdFactor(str, Enum):


 class BaseTransformerArgs(BaseModel):
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
     dim: int = 512
     n_layers: int = 8
     head_dim: int | None = None
@@ -68,6 +68,9 @@ class BaseTransformerArgs(BaseModel):
     # Special token config
     eos_id: int | None = EOS_ID

+    init_device: str = "cpu"
+    init_dtype: torch.dtype = torch.float32
+

 def cross_entropy(pred, target, **kwargs):
     return F.nll_loss(
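
The `arbitrary_types_allowed=True` change above exists to support the new `init_dtype: torch.dtype` field: Pydantic v2 cannot generate a schema for `torch.dtype`, so declaring the field would otherwise fail when the class is defined. A minimal sketch of the resulting behaviour, assuming this module's `BaseTransformerArgs` is importable (the cuda/bfloat16 values are illustrative, not defaults from this diff):

import torch

# With arbitrary_types_allowed=True the dtype field is validated by isinstance
# and stored as-is.
args = BaseTransformerArgs(init_device="cuda", init_dtype=torch.bfloat16)
assert args.init_dtype is torch.bfloat16

# extra="forbid" still rejects unknown fields at construction time.
try:
    BaseTransformerArgs(unknown_field=1)
except Exception as err:
    print(type(err).__name__)  # ValidationError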
@@ -95,6 +98,7 @@ def precompute_freqs_cis(
     end: int,
     theta: float = 10000.0,
     rope_use_fp32_in_outer_product: bool = False,
+    device: str | torch.device = torch.device("cpu"),
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -111,7 +115,9 @@ def precompute_freqs_cis(
     Returns:
         torch.Tensor: Precomputed frequency tensor with complex exponentials.
     """
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs = 1.0 / (
+        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
+    )
     t = torch.arange(end, device=freqs.device)
     if rope_use_fp32_in_outer_product:
         t = t.to(torch.float32)
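
Passing `device` here means the inverse-frequency vector, and everything derived from it via `freqs.device`, is created directly on the target device instead of being built on CPU and copied later. A reduced, self-contained sketch of the same computation pattern (not the repo's exact `precompute_freqs_cis`, just its core under that assumption):

import torch

def rope_angles(dim: int, end: int, theta: float = 10000.0,
                device: str | torch.device = "cpu") -> torch.Tensor:
    # Per-pair inverse frequencies, created on the requested device.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device)[: dim // 2].float() / dim))
    # Positions inherit the device from freqs, as in the patched function.
    t = torch.arange(end, device=freqs.device).float()
    angles = torch.outer(t, freqs)                      # (end, dim // 2)
    return torch.stack([angles.cos(), angles.sin()], dim=-1)

table = rope_angles(64, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
print(table.device, table.shape)                        # built where requested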
@@ -258,6 +264,8 @@ def __init__(
         head_dim: int,
         max_seqlen: int = 1024,
         rope_use_fp32_in_outer_product: bool = False,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()

@@ -273,7 +281,8 @@ def __init__(
                 end=max_seqlen,
                 theta=theta,
                 rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
-            ),
+                device=device,
+            ).to(dtype=dtype),
             persistent=False,
         )

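
The trailing `.to(dtype=dtype)` casts the precomputed RoPE table to the module's init dtype before it is registered, so a model built in bf16 does not keep an fp32 buffer around. A small self-contained sketch of the register-and-cast pattern (a toy module, not this repo's `RotaryEmbedding`):

import torch
from torch import nn

class TinyRope(nn.Module):
    def __init__(self, n: int, device: str | torch.device = "cpu",
                 dtype: torch.dtype = torch.float32):
        super().__init__()
        table = torch.arange(n, device=device).float()   # computed in fp32
        # persistent=False keeps the table out of checkpoints; it is simply
        # recomputed at construction time.
        self.register_buffer("table", table.to(dtype=dtype), persistent=False)

m = TinyRope(8, dtype=torch.bfloat16)
print(m.table.dtype)               # torch.bfloat16
print("table" in m.state_dict())   # False, thanks to persistent=False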
@@ -325,6 +334,8 @@ def __init__(
         n_heads: int,
         n_kv_heads: int,
         rope_theta: float,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()

@@ -340,22 +351,30 @@ def __init__(
             dim,
             n_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.wk = nn.Linear(
             dim,
             n_kv_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.wv = nn.Linear(
             dim,
             n_kv_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )

         self.wo = nn.Linear(
             n_heads * head_dim,
             dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )

     def forward(
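
Forwarding `device`/`dtype` into `nn.Linear` allocates each projection once, in the requested precision and location, instead of materializing default fp32 weights and converting them afterwards; with `device="meta"` no real storage is allocated at all. A short illustrative comparison (the sizes and the meta-device use are assumptions, not taken from this diff):

import torch
from torch import nn

dim, n_heads, head_dim = 512, 8, 64

# Old pattern: fp32 allocation first, then a cast (a second full copy).
wq_old = nn.Linear(dim, n_heads * head_dim, bias=False).to(dtype=torch.bfloat16)

# New pattern: allocate directly in the target dtype (and device).
wq_new = nn.Linear(dim, n_heads * head_dim, bias=False,
                   device="cpu", dtype=torch.bfloat16)

# "meta" defers allocation entirely, which is handy for sharded init.
wq_meta = nn.Linear(dim, n_heads * head_dim, bias=False, device="meta")

print(wq_new.weight.dtype, wq_meta.weight.is_meta)   # torch.bfloat16 True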
@@ -368,6 +387,7 @@ def forward(
     ) -> torch.Tensor:
         # B S D
         bsz, seq_len, dim = x.shape
+
         xq = self.wq(x.view_as(x))
         xk = self.wk(x.view_as(x))
         xv = self.wv(x.view_as(x))
@@ -453,6 +473,8 @@ def __init__(
         multiple_of: int,
         ffn_dim_multiplier: Optional[float],
         mp_size: int = 1,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()

@@ -469,16 +491,22 @@ def __init__(
             dim,
             hidden_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.w3 = nn.Linear(
             dim,
             hidden_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.w2 = nn.Linear(
             hidden_dim,
             dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -535,15 +563,24 @@ def __init__(self, args: BaseTransformerArgs):
             n_heads=self.n_heads,
             n_kv_heads=self.n_kv_heads,
             rope_theta=args.rope_theta,
+            device=args.init_device,
+            dtype=args.init_dtype,
         )
         self.feed_forward = FeedForward(
             dim=args.dim,
             hidden_dim=4 * args.dim,
             multiple_of=args.multiple_of,
             ffn_dim_multiplier=args.ffn_dim_multiplier,
+            device=args.init_device,
+            dtype=args.init_dtype,
+        )
+        # Norms stay in full precision
+        self.attention_norm = RMSNorm(
+            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+        )
+        self.ffn_norm = RMSNorm(
+            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
         )
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

     def forward(
         self,
@@ -593,6 +630,8 @@ def __init__(self, args: BaseTransformerArgs):
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
             rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
+            device=args.init_device,
+            dtype=args.init_dtype,
         )
         self.eos_id = args.eos_id

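
Taken together, the patch threads `args.init_device` / `args.init_dtype` from the config through RoPE, attention, feed-forward and the norms, so the whole stack is constructed where and how the config says. An end-to-end usage sketch under stated assumptions: the concrete model class, the meta-device value and the `to_empty` materialization step are illustrative, not part of this diff:

import torch

args = BaseTransformerArgs(
    dim=1024,
    n_layers=12,
    init_device="meta",            # build parameters without real storage
    init_dtype=torch.bfloat16,     # target parameter dtype
)
model = BaseTransformer(args)      # assumed: a model class consuming these args
model = model.to_empty(device="cuda")  # allocate empty storage on the GPU
# ...then initialize weights or load a checkpoint into the empty tensors.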