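"""Decode (generate) text from a trained Transformer checkpoint.

Loads the TinyStories BPE tokenizer and a saved Transformer, then generates a
continuation of a prompt using temperature scaling and optional nucleus
(top-p) sampling.
"""
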
import argparse

import torch

from cs336_basics.tokenizer import BPETokenizer


def load_tokenizer():
    # Load the BPE tokenizer trained on TinyStories; paths are relative to the repo root.
    tokenizer = BPETokenizer.from_files(
        vocab_filepath="data/tinystories_valid_tokenizer/tinystories_vocab.json",
        merges_filepath="data/tinystories_valid_tokenizer/tinystories_merges.txt",
        special_tokens=["<|endoftext|>"],
    )
    return tokenizer


def load_model(checkpoint_path: str, device: str):
    from cs336_basics.transformer import Transformer

    # map_location keeps checkpoints saved on another device (e.g. CUDA) loadable here.
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = ckpt["model_state"]
    # These hyperparameters must match the configuration the checkpoint was trained with.
    model = Transformer(
        vocab_size=10000,
        context_length=256,
        num_layers=4,
        d_model=512,
        num_heads=16,
        d_ff=1344,
        rope_theta=10000.0,
        device=device,
    )
    model.load_state_dict(state_dict)
    model.to(device)
    return model


def sample_next_token(logits, temperature=1.0, top_p=1.0):
    from cs336_basics.softmax import softmax

    # Temperature <= 0 falls back to greedy decoding.
    if temperature <= 0:
        return int(torch.argmax(logits).item())
    logits = logits / temperature
    probs = softmax(logits, -1)
    # top_p >= 1.0 disables nucleus sampling: sample from the full distribution.
    if top_p is None or top_p >= 1.0:
        return int(torch.multinomial(probs, num_samples=1).item())
    # Nucleus (top-p) sampling: keep the smallest prefix of probability-sorted
    # tokens whose cumulative mass stays within top_p, then renormalize and sample.
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    mask = cumulative <= top_p
    if not torch.any(mask):
        # Always keep at least the most probable token.
        mask[0] = True
    cutoff = torch.nonzero(mask)[-1].item()
    mask[: cutoff + 1] = True
    truncated_probs = sorted_probs * mask
    truncated_probs /= truncated_probs.sum()
    sampled = torch.multinomial(truncated_probs, 1)
    next_id = sorted_idx[sampled]
    return int(next_id.item())
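
# Worked example (illustrative, not executed): for probs = [0.5, 0.3, 0.15, 0.05]
# and top_p = 0.8, the cumulative sums are [0.5, 0.8, 0.95, 1.0]; the mask keeps
# the first two tokens, which are renormalized to [0.625, 0.375] before sampling.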


@torch.no_grad()
def decode(
    model: torch.nn.Module,
    tokenizer: BPETokenizer,
    prompt_ids: torch.Tensor,
    max_tokens: int,
    device,
    temperature=1.0,
    top_p=1.0,
):
    model.eval()
    ids = prompt_ids.to(device)
    # Stop generation early once the end-of-text token is produced.
    eos_id = tokenizer.vocab_reverse[b"<|endoftext|>"]
    for _ in range(max_tokens):
        # The full sequence is re-run each step; the prompt plus generated tokens
        # must fit within the model's context_length.
        logits = model(ids.unsqueeze(0))
        last_logits = logits[0, -1]
        next_id = sample_next_token(last_logits, temperature=temperature, top_p=top_p)
        next_id_tensor = torch.tensor([next_id], dtype=torch.long, device=device)
        ids = torch.cat([ids, next_id_tensor], dim=0)
        if next_id == eos_id:
            break
    return ids
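
# Usage sketch (illustrative): passing temperature=0.0 routes sample_next_token
# through torch.argmax, so decode() performs greedy decoding:
#
#   ids = decode(model, tokenizer, prompt_ids, max_tokens=64,
#                device="cpu", temperature=0.0)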


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--max-tokens", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--device", type=str, default="mps")
    args = parser.parse_args()

    device = args.device
    tokenizer = load_tokenizer()
    # Encode the prompt into token ids and move generation to the requested device.
    prompt_ids = tokenizer.encode(args.prompt)
    prompt_ids = torch.tensor(prompt_ids, dtype=torch.long)
    model = load_model(args.checkpoint, device)
    full_ids = decode(
        model=model,
        tokenizer=tokenizer,
        prompt_ids=prompt_ids,
        max_tokens=args.max_tokens,
        device=device,
        temperature=args.temperature,
        top_p=args.top_p,
    )
    # Decode the full sequence (prompt plus generated tokens) back to text.
    text = tokenizer.decode(full_ids.tolist())
    print(text)
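
# Example invocation (checkpoint path is hypothetical):
#   python decoding.py --checkpoint checkpoints/tinystories.pt \
#       --prompt "Once upon a time" --max-tokens 128 \
#       --temperature 0.8 --top-p 0.9 --device mps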