Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes surrounding torch dtypes and QOL updates for torch #53

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions entropix/torch_kvcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,11 @@ def update(
- keys: Updated or repeated keys tensor.
- values: Updated or repeated values tensor.
"""
# Ensure xk and xv have the correct device and dtype
xk = xk.to(self.k.dtype)
xv = xv.to(self.v.dtype)

# Update the k and v tensors in the specified layer and position
insert_len = xk.size(1) # Assuming xk shape is (bsz, insert_len, kv_heads, head_dim)
self.k[layer_idx, :, cur_pos:cur_pos+insert_len, :, :] = xk
self.v[layer_idx, :, cur_pos:cur_pos+insert_len, :, :] = xv
bsz, insert_len, _, _ = xk.shape # Assuming xk shape is (bsz, insert_len, kv_heads, head_dim)
self.k[layer_idx, :bsz, cur_pos:cur_pos+insert_len, :, :] = xk
self.v[layer_idx, :bsz, cur_pos:cur_pos+insert_len, :, :] = xv

if cur_pos == 0:
# If inserting at the beginning, repeat the new keys and values
Expand All @@ -77,10 +74,11 @@ def update(
# Otherwise, repeat the existing keys and values from the cache
keys = self.k[layer_idx].repeat_interleave(n_rep, dim=2)
values = self.v[layer_idx].repeat_interleave(n_rep, dim=2)

keys = keys[: bsz].to(xk.dtype)
values = values[: bsz].to(xv.dtype)
return keys, values, self

def clear(self):
"""Resets the k and v caches to zeros."""
self.k.zero_()
self.v.zero_()
self.v.zero_()
36 changes: 18 additions & 18 deletions entropix/torch_main.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
from typing import NamedTuple, Optional, Tuple

import torch
import torch.nn.functional as F

import math
import tyro

from pathlib import Path
from functools import partial

from entropix.config import LLAMA_1B_PARAMS
from entropix.tokenizer import Tokenizer
from entropix.torch_kvcache import KVCache
from entropix.torch_model import xfmr
from entropix.torch_weights import XfmrWeights, LayerWeights, load_weights
from entropix.torch_weights import load_weights
from entropix.torch_sampler import sample
from entropix.prompts import prompt, bp1
from entropix.prompts import create_prompts_from_csv, prompt, bp1

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Device selection, tree is like first apple silicion, then cuda, fallback is cpu.
Expand Down Expand Up @@ -93,12 +85,7 @@ def main():
with torch.inference_mode():
model_params = LLAMA_1B_PARAMS
xfmr_weights = load_weights()

tokenizer = Tokenizer('entropix/tokenizer.model')
raw_tokens1 = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all')
#this is not used in this script, but can be used to generate base_raw_tokens1
base_raw_tokens1 = tokenizer.encode(bp1, bos=True, eos=False, allowed_special='all')


def generate(xfmr_weights, model_params, tokens):
gen_tokens = None
Expand All @@ -107,7 +94,7 @@ def generate(xfmr_weights, model_params, tokens):
bsz, seqlen = tokens.shape
attn_mask = build_attn_mask(seqlen, cur_pos)
freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim).to(DEVICE)
kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim).to(device)
logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True).to(torch.int32)
gen_tokens = next_token
Expand All @@ -123,8 +110,21 @@ def generate(xfmr_weights, model_params, tokens):
if torch.isin(next_token, stop).any():
break

print(prompt)
generate(xfmr_weights, model_params, raw_tokens1)
csv_path = Path('entropix/data/prompts.csv')
prompts = create_prompts_from_csv(csv_path)
PROMPT_TEST = False

if PROMPT_TEST:
for test_prompt in prompts:
print(test_prompt)
tokens = tokenizer.encode(test_prompt, bos=False, eos=False, allowed_special='all')
generate(xfmr_weights, model_params, tokens)
else:
raw_tokens1 = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all')
#this is not used in this script, but can be used to generate base_raw_tokens1
base_raw_tokens1 = tokenizer.encode(bp1, bos=True, eos=False, allowed_special='all')
print(prompt)
generate(xfmr_weights, model_params, raw_tokens1)

if __name__ == '__main__':
tyro.cli(main)
2 changes: 1 addition & 1 deletion entropix/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_po
scores = scores + attn_mask
mask = torch.where(scores != 0.0, scores, DEFAULT_MASK_VALUE)
padded_logits = torch.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, DEFAULT_MASK_VALUE)
scores = F.softmax(padded_logits, dim=-1).to(torch.float32)
scores = F.softmax(padded_logits, dim=-1).to(x.dtype)
output = torch.matmul(scores, values)
output = output.transpose(1, 2).reshape(xq.shape[0], xq.shape[2], -1)
out = F.linear(output, layer_weights.wo)
Expand Down
2 changes: 1 addition & 1 deletion entropix/torch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def calculate_metrics(logits: torch.Tensor, attention_scores: torch.Tensor) -> D
entropy, varentropy = calculate_varentropy_logsoftmax(logits)
attention_probs = F.softmax(attention_scores, dim=-1)
attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
attn_varentropy = torch.var(attn_entropy, dim=-1)
attn_varentropy = torch.var(attn_entropy, dim=1)

# Add a small epsilon to avoid NaN when all values are the same
attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy)
Expand Down