-
Notifications
You must be signed in to change notification settings - Fork 251
/
Copy pathgenerate.py
396 lines (328 loc) · 15.5 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import time
from pathlib import Path
from typing import Optional, Tuple
import torch
import torch._dynamo.config
import torch._inductor.config
# Seed torch's global RNG at import time; main() reseeds with torch.manual_seed(1234) before
# generating, so sampling there is governed by that later seed.
torch.manual_seed(0)
def device_sync(device):
    """Wait for queued work on ``device`` to finish so wall-clock timings are meaningful.

    Args:
        device: a device string such as ``"cuda"``, ``"cuda:0"`` or ``"cpu"``, or a
            ``torch.device``. The value is matched on its string form, so both are accepted.

    CUDA devices are synchronized with ``torch.cuda.synchronize``; CPU is a no-op; any other
    device type only prints a notice and returns.
    """
    # str() so torch.device objects work too (``"cuda" in torch.device(...)`` raises TypeError).
    device_str = str(device)
    if "cuda" in device_str:
        torch.cuda.synchronize(device)
    elif "cpu" in device_str:
        pass
    else:
        print(f"device={device} is not yet supported")
# Global torch.compile / Inductor / Dynamo settings, applied once at import time.
# NOTE(review): descriptions below are from the PyTorch config option names; confirm against the
# installed torch version's torch/_inductor/config.py and torch/_dynamo/config.py.
torch._inductor.config.coordinate_descent_tuning = True  # extra autotuning of generated kernels
torch._inductor.config.triton.unique_kernel_names = True  # distinct kernel names (easier profiling)
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
torch._dynamo.config.capture_scalar_outputs = True  # presumably lets Dynamo trace scalar (.item()-style) outputs
# Support running without installing as a package: append the directory two levels above this
# file (resolved to an absolute path) to sys.path, presumably so that `from model import
# Transformer` below can be found. Appended, so it has the lowest lookup priority.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from sentencepiece import SentencePieceProcessor
from model import Transformer
def multinomial_sample_one_no_sync(probs_sort):
    """Draw one index per row from ``probs_sort`` without a CUDA synchronization.

    Dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax samples each
    index with probability proportional to its weight, all on-device (unlike
    ``torch.multinomial``, which the original comment says would synchronize).

    Returns:
        int32 indices with the last dimension kept (size 1).
    """
    noise = torch.empty_like(probs_sort).exponential_(1)
    scores = probs_sort / noise
    winners = scores.argmax(dim=-1, keepdim=True)
    return winners.to(dtype=torch.int)
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn raw logits into a probability distribution over the last dimension.

    Args:
        logits: unnormalized scores; softmax is taken over the last dimension.
        temperature: divisor applied to ``logits`` (clamped below at 1e-5 to avoid dividing by 0).
        top_k: if given, every logit strictly below the k-th largest (k capped at the vocabulary
            size) is masked to -inf before the softmax.

    Returns:
        Tensor of the same shape as ``logits`` whose last dimension sums to 1.
    """
    scaled = logits / max(temperature, 1e-5)
    if top_k is None:
        return torch.nn.functional.softmax(scaled, dim=-1)
    k = min(top_k, scaled.size(-1))
    # Smallest of the top-k values, kept as a trailing size-1 dim for broadcasting.
    kth_largest = torch.topk(scaled, k).values[..., -1:]
    masked = torch.where(scaled < kth_largest, -float("Inf"), scaled)
    return torch.nn.functional.softmax(masked, dim=-1)
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the logits at the last sequence position.

    Returns:
        ``(idx_next, probs)`` where ``idx_next`` holds int32 token ids (presumably ``[B, 1]`` for
        ``[B, S, V]`` logits) and ``probs`` is the temperature/top-k adjusted distribution they
        were drawn from.
    """
    last_position_probs = logits_to_probs(logits[:, -1], temperature, top_k)
    return multinomial_sample_one_no_sync(last_position_probs), last_position_probs
def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    """Run the whole prompt ``x`` through ``model`` and sample the first generated token.

    In generate() ``x`` is ``[batch_size, T]`` and ``input_pos`` is ``torch.arange(0, T)``
    (shape ``[T]``). Only the sampled token ids are returned; the probabilities are discarded.
    """
    next_token, _ = sample(model(x, input_pos), **sampling_kwargs)
    return next_token
def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run one decoding step and sample the following token.

    ``input_pos`` must contain exactly one position along its last dimension (in generate() it
    starts as the 1-element tensor ``[T]``).

    Returns:
        ``(next_token, probs)`` exactly as produced by ``sample``.
    """
    assert input_pos.shape[-1] == 1
    step_logits = model(x, input_pos)
    return sample(step_logits, **sampling_kwargs)
def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
    """Autoregressively decode ``num_new_tokens`` tokens, one ``decode_one_token`` call per step.

    Args:
        model: the Transformer; in generate() its caches are set up and the prompt prefilled
            before this is called.
        cur_token: token(s) fed to the first step; each step's sampled token feeds the next.
        input_pos: position tensor for the current step. It is incremented **in place** after
            every step, so the caller's tensor is mutated.
        num_new_tokens: number of decode steps to run (0 returns two empty lists).
        callback: called with each newly sampled token tensor right after it is produced.
        **sampling_kwargs: forwarded to ``decode_one_token`` (e.g. temperature, top_k).

    Returns:
        ``(new_tokens, new_probs)``: lists with one token tensor and one probability tensor per step.
    """
    new_tokens, new_probs = [], []
    # Restrict scaled-dot-product attention to the math backend for the whole loop: it is
    # actually better for Inductor to codegen attention here. Entered once instead of per token.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        for _ in range(num_new_tokens):
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            # One copy per output is enough: decode_one_token may be compiled with
            # mode="reduce-overhead" (CUDA graphs), whose outputs are likely overwritten by the
            # next replay. These copies are owned here, so they can be stored and fed back as-is.
            next_token, next_prob = next_token.clone(), next_prob.clone()
            input_pos += 1
            new_tokens.append(next_token)
            callback(next_token)
            new_probs.append(next_prob)
            cur_token = next_token
    return new_tokens, new_probs
def model_forward(model, x, input_pos):
    """Thin wrapper: call ``model(x, input_pos)`` and hand back its output unchanged."""
    output = model(x, input_pos)
    return output
@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    batch_size: int,
    *,
    interactive: bool,
    callback=lambda x: x,
    **sampling_kwargs,
) -> torch.Tensor:
    """Take a conditioning sequence (prompt) and continue it with newly generated tokens.

    Args:
        model: Transformer exposing ``setup_caches`` and ``config.block_size``.
        prompt: 1-D tensor of ``T`` token ids; it is repeated to ``[batch_size, T]``.
        max_new_tokens: upper bound on generated tokens. Fewer are produced when the cache
            cannot hold them: its length is ``min(T + max_new_tokens, block_size)``, or a fixed
            350 positions in interactive mode.
        batch_size: number of copies of the prompt decoded in parallel.
        interactive: use the fixed 350-position cache.
        callback: passed to ``decode_n_tokens``; called with every token after the first (the
            first token, sampled by ``prefill``, is not passed to it).
        **sampling_kwargs: forwarded to sampling (e.g. temperature, top_k).

    Returns:
        ``[batch_size, T + n]`` tensor: the prompt followed by the ``n`` generated tokens.

    Raises:
        ValueError: if the prompt leaves no room in the cache for even one new token.
    """
    device = prompt.device
    T = prompt.size(-1)
    if interactive:
        max_seq_length = 350
    else:
        max_seq_length = min(T + max_new_tokens, model.config.block_size)
    # Never decode past the end of the KV cache, and never more than requested.
    num_new_tokens = min(max_new_tokens, max_seq_length - T)
    if num_new_tokens < 1:
        raise ValueError(
            f"cannot generate any tokens: prompt length {T}, max_new_tokens {max_new_tokens}, "
            f"cache length {max_seq_length}"
        )

    # duplicate prompt for batch_size
    prompt = prompt.repeat(batch_size, 1)
    # prompt + the token sampled by prefill; the decode-loop tokens are concatenated at the end.
    seq = torch.empty(batch_size, T + 1, dtype=prompt.dtype, device=device)
    seq[:, :T] = prompt

    with torch.device(device):
        model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

    input_pos = torch.arange(0, T, device=device)
    next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
    seq[:, T] = next_token.squeeze()

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    generated_tokens, _ = decode_n_tokens(
        model,
        next_token.view(batch_size, -1),
        input_pos,
        num_new_tokens - 1,  # the first new token already came from prefill
        callback=callback,
        **sampling_kwargs,
    )
    return torch.cat((seq, *generated_tokens), dim=-1)
def encode_tokens(tokenizer, string, bos=True, device='cuda'):
    """Tokenize ``string`` into a 1-D int32 tensor on ``device``.

    When ``bos`` is true, ``tokenizer.bos_id()`` is prepended to the encoded ids.
    """
    body = tokenizer.encode(string)
    ids = [tokenizer.bos_id(), *body] if bos else body
    return torch.tensor(ids, dtype=torch.int, device=device)
def _load_model(checkpoint_path, device, precision):
    """Build the Transformer named after the checkpoint's directory and load its weights.

    The model skeleton is created on the meta device and the checkpoint tensors are assigned
    into it (``assign=True``). If loading fails, a freshly (randomly) initialized model of the
    same name is used instead, and a warning is printed so results are not silently produced
    with random weights.

    Args:
        checkpoint_path: ``Path`` to the ``.pth`` file; ``checkpoint_path.parent.name`` selects
            the model config via ``Transformer.from_name``.
        device: target device for the returned model.
        precision: target dtype for the returned model.

    Returns:
        The model on ``device``/``precision``, in eval mode.
    """
    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)
    try:
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
    except Exception as e:  # was a bare `except:`, which also swallowed KeyboardInterrupt
        print(
            f"WARNING: failed to load checkpoint {checkpoint_path} "
            f"({type(e).__name__}: {e}); falling back to randomly initialized weights"
        )
        model = Transformer.from_name(checkpoint_path.parent.name)
    model = model.to(device=device, dtype=precision)
    return model.eval()
# Instruction delimiters wrapped around each interactive prompt when the checkpoint path
# contains "chat" (see main()).
B_INST = "[INST]"
E_INST = "[/INST]"
def _int8wo_quant_convert_fn(module, config):
    """Replace ``w1``/``w2``/``w3`` of a ConditionalFeedForwardAOQuantizable with int8 copies.

    Each weight becomes a symmetric int8 affine-quantized tensor whose quantization block spans
    the full last dimension (one block per row). ``config`` is unused; it is accepted because
    ``_replace_with_custom_fn_if_matches_filter`` forwards it via ``extra_args``.
    """
    from torchao.quantization.quant_api import MappingType, to_affine_quantized_intx

    assert "ConditionalFeedForwardAOQuantizable" in str(type(module))
    for weight_attr in ("w1", "w2", "w3"):
        assert hasattr(module, weight_attr)
    for weight_attr in ("w1", "w2", "w3"):
        weight = getattr(module, weight_attr)
        block_size = (1,) * (weight.dim() - 1) + (weight.shape[-1],)
        new_weight = to_affine_quantized_intx(
            weight,
            MappingType.SYMMETRIC,
            block_size,
            torch.int8,
            eps=torch.finfo(torch.float32).eps,
            zero_point_dtype=torch.int64,
        )
        setattr(module, weight_attr, torch.nn.Parameter(new_weight, requires_grad=False))
    return module


def _apply_moe_quant(model, moe_quant: str) -> None:
    """Quantize the MoE expert feed-forward modules of ``model`` in place.

    ``moe_quant`` is matched by substring. The ``-base`` variants are tested *before* their
    prefixes; otherwise e.g. "int8wo-base" would be caught by the "int8wo" branch and the
    base branches could never run.

    Raises:
        ValueError: if ``moe_quant`` matches none of the supported options.
    """
    import torchao  # noqa: F401  (imported as in the original; not referenced directly)
    from torchao.quantization.quant_api import (
        quantize_,
        Int8WeightOnlyConfig,
        Int8DynamicActivationInt8WeightConfig,
        Int4WeightOnlyConfig,
        Float8WeightOnlyConfig,
        Float8DynamicActivationFloat8WeightConfig,
        PerRow,
        _replace_with_custom_fn_if_matches_filter,
    )
    from torchao.quantization.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter

    torch._dynamo.config.capture_dynamic_output_shape_ops = True
    config = None
    if "int8wo-base" in moe_quant:
        _replace_with_custom_fn_if_matches_filter(
            model,
            replacement_fn=_int8wo_quant_convert_fn,
            filter_fn=cond_ffn_filter,
            extra_args=(Int8WeightOnlyConfig(),),
        )
    elif "int8wo" in moe_quant:
        config = MoEQuantConfig(Int8WeightOnlyConfig())
    elif "int8dq-base" in moe_quant or "int4wo-base" in moe_quant:
        # No implementation exists for these variants (the original branches were `pass`).
        print(f"moe_quant={moe_quant!r}: no quantization is implemented for this option; "
              "the model is left unquantized")
    elif "int8dq" in moe_quant:
        config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
    elif "int4wo" in moe_quant:
        config = MoEQuantConfig(Int4WeightOnlyConfig())
    elif "fp8wo" in moe_quant:
        config = MoEQuantConfig(Float8WeightOnlyConfig())
    elif "fp8dq" in moe_quant:
        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
    else:
        raise ValueError(f"expected moe_quant to match one of the options but got {moe_quant}")

    if config is not None:
        quantize_(model, config, filter_fn=cond_ffn_filter)


def _make_interactive_callback(tokenizer):
    """Build a per-token callback that streams generated text to stdout.

    Text is printed in groups of four tokens. Once the tokenizer's EOS id is seen, the buffer is
    flushed and later tokens are ignored (generation itself keeps running; decode_n_tokens does
    not check for EOS).
    """
    buffer = []
    period_id = tokenizer.encode('.')[0]
    done_generating = False

    def callback(x):
        nonlocal done_generating
        if done_generating:
            return
        # Flatten first: tokens presumably arrive as [batch, 1] (argmax with keepdim=True), and
        # .tolist() on that yields a nested list that tokenizer.decode cannot mix with period_id.
        token_ids = x.reshape(-1).tolist()
        # Decode behind a '.' token and drop its character, presumably so SentencePiece keeps
        # the token's leading space.
        buffer.append(tokenizer.decode([period_id] + token_ids)[1:])
        if x.item() == tokenizer.eos_id():
            done_generating = True
        if len(buffer) == 4 or done_generating:
            print(''.join(buffer), end='', flush=True)
            buffer.clear()

    return callback


def main(
    prompt: str = "Hello, my name is",
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    batch_size: int = 1,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
    compile: bool = True,
    compile_prefill: bool = False,
    moe_quant: Optional[str] = None,
    profile: Optional[Path] = None,
    device='cuda',
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.

    Loads the checkpoint (and ``tokenizer.model`` next to it), optionally quantizes the MoE
    experts (``moe_quant``) and compiles the decode/prefill steps, then runs ``num_samples``
    timed generations. When ``compile`` is set, an extra untimed warm-up run comes first.
    It prints per-run and average tokens/sec, the estimated bandwidth
    (``model_size * tokens/sec``) and the peak reserved CUDA memory. In ``interactive`` mode
    each sample reads a prompt from stdin and streams the output.

    If ``profile`` is set, the last sample runs under ``torch.profiler`` and a Chrome trace is
    written to ``f"{profile}.json"``.

    Raises:
        AssertionError: if the checkpoint or tokenizer file does not exist.
        ValueError: if ``moe_quant`` is not a recognised option.
    """
    import contextlib
    import warnings

    # May be replaced by compiled versions below; generate()/decode_n_tokens() look them up
    # as module globals at call time.
    global decode_one_token, prefill

    assert checkpoint_path.is_file(), checkpoint_path
    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    assert tokenizer_path.is_file(), str(tokenizer_path)

    print(f"Using device={device}")
    precision = torch.bfloat16
    is_chat = "chat" in str(checkpoint_path)

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision)
    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

    torch.manual_seed(1234)
    # NOTE(review): measured before MoE quantization, so the bandwidth figure uses the
    # unquantized size; confirm this is intended.
    model_size = sum(
        p.numel() * p.dtype.itemsize
        for p in itertools.chain(model.parameters(), model.buffers())
    )

    if moe_quant:
        _apply_moe_quant(model, moe_quant)

    if compile:
        # moe quant + compile causes repeated warnings
        warnings.simplefilter("ignore", lineno=84)
        warnings.simplefilter("ignore", lineno=105)
        torch._inductor.config.assert_indirect_indexing = False
        if batch_size > 1 or (isinstance(moe_quant, str) and "base" not in moe_quant):
            # MoE code has a graph break on the multi-token path, so no fullgraph here.
            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
        else:
            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
        if compile_prefill:  # was `args.compile_prefill`: a NameError outside the CLI entry point
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

    aggregate_metrics = {
        'tokens_per_sec': [],
    }
    start = -1 if compile else 0  # i == -1 is the untimed compile warm-up run

    for i in range(start, num_samples):
        device_sync(device=device)  # MKG
        if i >= 0 and interactive:
            prompt = input("What is your prompt? ")
            if is_chat:
                prompt = f"{B_INST} {prompt.strip()} {E_INST}"
            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
        # Recomputed every iteration: in interactive mode the prompt changes per sample, and a
        # stale length made the tokens/sec figure wrong.
        prompt_length = encoded.size(0)

        if interactive and i >= 0:
            callback = _make_interactive_callback(tokenizer)
        else:
            callback = lambda x: x

        t0 = time.perf_counter()
        if i != num_samples - 1 or not profile:
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
            y = generate(
                model,
                encoded,
                max_new_tokens,
                batch_size,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
            )
        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
        if hasattr(prof, "export_chrome_trace"):
            prof.export_chrome_trace(f"{profile}.json")
        device_sync(device=device)  # MKG
        t = time.perf_counter() - t0

        if not interactive:
            print(tokenizer.decode(y[0].tolist()))
        else:
            print()
        tokens_generated = y.size(-1) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")

    tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
    print(f"Average tokens/sec: {tokpersec:.2f}")
    if batch_size > 1:
        print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Generate text with a Mixtral-style MoE Transformer and report throughput.'
    )
    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
    parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
    parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
    # Default matches main(); the previous default ("meta-Transformer/Transformer-2-7b-chat-hf")
    # was a mangled non-Mixtral path.
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), help='Model checkpoint path.')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
    parser.add_argument('--moe_quant', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq, int8wo-base, int8dq-base, int4wo-base')
    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
    parser.add_argument('--device', type=str, default="cuda", help='device to use')

    args = parser.parse_args()
    print(args)
    # Keyword arguments so the call cannot silently misalign if main()'s signature is reordered.
    main(
        prompt=args.prompt,
        interactive=args.interactive,
        num_samples=args.num_samples,
        max_new_tokens=args.max_new_tokens,
        batch_size=args.batch_size,
        top_k=args.top_k,
        temperature=args.temperature,
        checkpoint_path=args.checkpoint_path,
        compile=args.compile,
        compile_prefill=args.compile_prefill,
        moe_quant=args.moe_quant,
        profile=args.profile,
        device=args.device,
    )