test_profile.py — from a fork of turboderp/exllama (36 lines, 27 loc, 958 bytes).
# Profile ExLlama inference with cProfile: one 1024-token prompt pass followed
# by single-token forward passes against the growing cache, then print the
# cumulative-time profile to stdout.

from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
import torch
import cProfile, pstats, io
from pstats import SortKey

# Hard-coded test-machine paths to the quantized model artifacts.
tokenizer_model_path = "/mnt/str/models/llama-30b-4bit-128g/tokenizer.model"
model_config_path = "/mnt/str/models/llama-30b-4bit-128g/config.json"
model_path = "/mnt/str/models/llama-30b-4bit-128g/llama-30b-4bit-128g.safetensors"

tokenizer = ExLlamaTokenizer(tokenizer_model_path)
config = ExLlamaConfig(model_config_path)
config.model_path = model_path
model = ExLlama(config)
cache = ExLlamaCache(model)

# Random token IDs stand in for a real prompt; presumably 31999 is just under
# the LLaMA vocabulary size — confirm against the tokenizer.
ids = torch.randint(0, 31999, (1, 1024))

pr = cProfile.Profile()
pr.enable()

with torch.no_grad():
    for i in range(128):
        model.forward(ids, cache)
        # After the first (prompt) pass, feed a single random token per step,
        # mimicking autoregressive decoding.
        ids = torch.randint(0, 31999, (1, 1))

# NOTE(review): indentation was reconstructed from a whitespace-mangled copy.
# The cache reset is placed here as post-run cleanup; confirm the original
# nesting against the upstream file.
cache.current_seq_len = 0

pr.disable()

# Render the profile sorted by cumulative time and print it.
s = io.StringIO()
sortby = SortKey.CUMULATIVE
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())