import torch
import torch.nn as nn
from torch.nn import functional as F

from .layers import layer_norm, linear, mlp
from .rope import apply_rotary_emb, precompute_freqs_cis
from .weights import AttentionWeights
from .config import TextConfig


def text_encoder(input_ids: torch.Tensor, w: nn.Module):
    return F.embedding(input_ids, w.wte)


def attn(
    x: torch.Tensor,
    w: AttentionWeights,
    freqs_cis: torch.Tensor,
    layer_kv_cache: torch.Tensor,
    attn_mask: torch.Tensor,
    n_heads: int,
    pos: int,
):
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    # Fused QKV projection, then split heads: (bsz, n_heads, q_len, head_dim).
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]

    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
    k = apply_rotary_emb(k, freqs_cis, position_ids, n_heads)

    # k_, v_ hold only the new positions; they are returned so the caller can
    # write them into the KV cache.
    k_, v_ = k, v
    if layer_kv_cache is not None:
        # Prepend cached keys/values for the positions already processed.
        k = torch.cat([layer_kv_cache[0, :, :, :pos, :], k], dim=2)
        v = torch.cat([layer_kv_cache[1, :, :, :pos, :], v], dim=2)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).to(
        # This type conversion isn't needed when running in PyTorch directly, but the
        # ONNX export runs attention in float32 because the attention mask is cast to
        # float32.
        x.dtype
    )
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out, torch.stack([k_, v_])


def text_decoder(
    inputs_embeds: torch.Tensor,
    w: nn.Module,
    kv_cache: torch.Tensor,
    pos: int,
    config: TextConfig,
):
    hidden_BTC = inputs_embeds
    new_kv_cache = [torch.empty(0)] * len(w.blocks)

    attn_mask = w.attn_mask[
        :, :, pos : pos + hidden_BTC.size(1), : pos + hidden_BTC.size(1)
    ]

    for i, block in enumerate(w.blocks):
        # Parallel pre-norm block: attention and MLP share the same normalized
        # input and are both added back to the residual stream.
        l_in = layer_norm(hidden_BTC, block.ln)
        l_attn, new_kv_cache[i] = attn(
            l_in,
            block.attn,
            freqs_cis=w.freqs_cis,
            layer_kv_cache=kv_cache[i],
            attn_mask=attn_mask,
            n_heads=config.n_heads,
            pos=pos,
        )
        l_mlp = mlp(l_in, block.mlp)
        hidden_BTC = hidden_BTC + l_attn + l_mlp

    return hidden_BTC, torch.stack(new_kv_cache)


def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
    # Only the final position is projected to logits.
    hidden_BC = hidden_BTC[:, -1, :]
    hidden_BC = layer_norm(hidden_BC, w.post_ln)
    logits = linear(hidden_BC, w.lm_head)
    return logits


def prefill(
    inputs_embeds: torch.Tensor,
    kv_cache: torch.Tensor,
    pos: int,
    w: nn.Module,
    config: TextConfig,
):
    # Updates kv_cache in-place
    hidden, kv_cache[:, :, :, :, pos : pos + inputs_embeds.size(1), :] = text_decoder(
        inputs_embeds, w, kv_cache, pos, config
    )
    return hidden


def decode_one_token(
    token_emb: torch.Tensor,
    kv_cache: torch.Tensor,
    pos: int,
    w: nn.Module,
    config: TextConfig,
):
    hidden, kv_cache_update = text_decoder(token_emb[None], w, kv_cache, pos, config)
    logits = lm_head(hidden, w)
    return logits, hidden, kv_cache_update


def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
    text = nn.ModuleDict(
        {
            "blocks": nn.ModuleList(
                [
                    nn.ModuleDict(
                        {
                            "ln": nn.LayerNorm(config.dim, dtype=dtype),
                            "attn": nn.ModuleDict(
                                {
                                    "qkv": nn.Linear(
                                        config.dim, 3 * config.dim, dtype=dtype
                                    ),
                                    "proj": nn.Linear(
                                        config.dim, config.dim, dtype=dtype
                                    ),
                                }
                            ),
                            "mlp": nn.ModuleDict(
                                {
                                    "fc1": nn.Linear(
                                        config.dim, 4 * config.dim, dtype=dtype
                                    ),
                                    "fc2": nn.Linear(
                                        4 * config.dim, config.dim, dtype=dtype
                                    ),
                                }
                            ),
                        }
                    )
                    for _ in range(config.n_layers)
                ]
            ),
            "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
        }
    )
    text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
    text.register_buffer(
        "freqs_cis",
        precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
        persistent=False,
    )

    # Causal mask; if prefix_attn is set, the first prefix_attn positions attend
    # to each other bidirectionally.
    attn_mask = torch.tril(
        torch.ones(1, 1, config.max_context, config.max_context, dtype=torch.bool)
    )
    if config.prefix_attn != 0:
        attn_mask[..., : config.prefix_attn, : config.prefix_attn] = 1
    text.register_buffer("attn_mask", attn_mask, persistent=False)

    return text
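

# --- Example usage ----------------------------------------------------------
# A minimal greedy-decoding sketch built on the functions above. This is an
# illustrative addition, not part of the original module: it assumes `w` was
# produced by `build_text_model` and already has its weights loaded, that
# `input_ids` is a LongTensor of shape (batch=1, prompt_len), and that
# prompt_len + max_new_tokens <= config.max_context. A real decoding loop
# would also stop on an end-of-sequence token.
@torch.no_grad()
def generate_greedy(
    input_ids: torch.Tensor,
    w: nn.Module,
    config: TextConfig,
    max_new_tokens: int = 32,
) -> torch.Tensor:
    head_dim = config.dim // config.n_heads
    # KV cache layout: (n_layers, 2, batch, n_heads, max_context, head_dim),
    # matching the slices written by prefill() and read by attn().
    kv_cache = torch.zeros(
        config.n_layers,
        2,
        input_ids.size(0),
        config.n_heads,
        config.max_context,
        head_dim,
        dtype=w.wte.dtype,
    )

    # Prefill the cache with the prompt, then decode one token at a time.
    hidden = prefill(text_encoder(input_ids, w), kv_cache, 0, w, config)
    logits = lm_head(hidden, w)
    pos = input_ids.size(1)

    out_tokens = []
    for _ in range(max_new_tokens):
        next_token = logits.argmax(dim=-1)  # (1,) greedy pick
        out_tokens.append(next_token)
        token_emb = text_encoder(next_token[None], w)[0]  # (1, dim)
        logits, _, kv_update = decode_one_token(token_emb, kv_cache, pos, w, config)
        kv_cache[:, :, :, :, pos : pos + 1, :] = kv_update
        pos += 1

    return torch.stack(out_tokens, dim=1)  # (1, max_new_tokens)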