-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_pt.py
109 lines (87 loc) · 2.23 KB
/
test_pt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import time
import sys
# Every benchmark below moves its tensors here: the GPU when CUDA is usable,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def pw(size, iters, warmup):
    """Benchmark a chain of element-wise (pointwise) additions.

    Builds a 1-D tensor of ``size`` random values on ``device``, repeatedly
    computes ``b = b + a``, then reduces the result to a Python scalar.

    Args:
        size: Number of elements in the tensor.
        iters: Number of timed additions.
        warmup: Number of untimed additions run before timing starts.

    Prints thousands of iterations per second and an effective bandwidth
    in GB/s.
    """
    a = torch.randn(size).to(device)

    def op(steps):
        b = a
        for _ in range(steps):
            b = b + a
        # .item() copies the scalar back to the host, which waits for all queued
        # device work to finish, so the timer covers the completed computation.
        b.sum().item()

    op(warmup)
    # perf_counter is monotonic and high-resolution; time.time() is neither
    # guaranteed monotonic nor fine-grained enough for short runs.
    t0 = time.perf_counter()
    op(iters)
    elapsed = time.perf_counter() - t0
    # Counts 4 bytes per element (assumes the default float32 dtype) once per
    # addition plus once for the final sum.  NOTE(review): each `b + a` reads two
    # tensors and writes one, so this is a lower bound on actual memory traffic;
    # left unchanged so reported numbers stay comparable.
    bs = size * 4 * iters + size * 4
    gbs = bs / (1e9 * elapsed)
    kitersec = iters / (elapsed * 1e3)
    print(f"{kitersec:.3f}K iter/s", f"{gbs:.3f} GB/s")
def mm(size, iters, warmup):
    """Benchmark repeated dense square matrix multiplication.

    Multiplies a random ``size`` x ``size`` matrix by the identity matrix
    repeatedly (c @ I == c mathematically, so values stay bounded across
    iterations), then reduces the result to a Python scalar.

    Args:
        size: Side length N of the square matrices.
        iters: Number of timed matmuls.
        warmup: Number of untimed matmuls run before timing starts.

    Prints achieved throughput in GFlop/s.
    """
    N = size
    a = torch.randn(N, N).to(device)
    b = torch.eye(N).to(device)

    def op(steps):
        c = a
        for _ in range(steps):
            c = c @ b
        # .item() pulls the scalar to the host, forcing queued device work to
        # complete so the timer measures finished computation.
        c.sum().item()

    op(warmup)
    # Monotonic, high-resolution clock intended for measuring durations.
    t0 = time.perf_counter()
    op(iters)
    elapsed = time.perf_counter() - t0
    # 2*N^3 flops per N x N matmul (N^3 multiply-adds) plus N^2 for the sum.
    flops = N**3 * 2 * iters + N * N
    gflops = flops / (1e9 * elapsed)
    print(f"{gflops:.3f} GFlop/s")
def mm_pw(size, iters, warmup):
    """Benchmark a matmul followed by a short chain of broadcast additions.

    Each iteration multiplies a (64, size) matrix by a (size, size) identity
    matrix, then adds a (1, size) row vector five times, broadcasting it over
    the 64 rows.

    Args:
        size: Width N of the activations and side length of the identity.
        iters: Number of timed iterations (one matmul + five adds each).
        warmup: Number of untimed iterations run before timing starts.

    Prints thousands of iterations per second.
    """
    N = size
    a = torch.randn(64, N).to(device)
    b = torch.eye(N).to(device)
    c = torch.randn(1, N).to(device)

    def op(steps):
        d = a
        for _ in range(steps):
            d = d @ b
            for _ in range(5):
                d = d + c
        # .item() forces a device->host copy, i.e. waits for queued work,
        # so the timer includes the completed computation.
        d.sum().item()

    op(warmup)
    # Monotonic, high-resolution clock intended for measuring durations.
    t0 = time.perf_counter()
    op(iters)
    elapsed = time.perf_counter() - t0
    kitersec = iters / (elapsed * 1e3)
    print(f"{kitersec:.3f}K iter/s")
def mha(size, iters, warmup):
    """Benchmark inference through ``torch.nn.MultiheadAttention``.

    Args:
        size: Embedding dimension of the attention layer. The layer uses 8 heads,
            so torch requires this to be divisible by 8.
        iters: Number of timed forward passes.
        warmup: Number of untimed forward passes run before timing starts
            (0 is allowed).

    Prints forward passes per second.
    """
    N = size
    mha_model = torch.nn.MultiheadAttention(N, 8).eval().to(device)
    # Inputs are (32, 32, N); with the default batch_first=False torch reads
    # them as (seq_len, batch, embed_dim).
    qk = torch.rand((32, 32, N)).to(device)
    v = torch.rand((32, 32, N)).to(device)

    def op(steps):
        out = None
        # eval() only changes dropout behaviour; it does not disable autograd.
        # Without no_grad every forward pass would also record a backward graph,
        # which an inference benchmark should not pay for.
        with torch.no_grad():
            for _ in range(steps):
                out = mha_model(qk, qk, v)[0]
        # With steps == 0 there is no output; previously this raised
        # UnboundLocalError. Otherwise .item() syncs so timing is complete.
        if out is not None:
            out.sum().item()

    op(warmup)
    # Monotonic, high-resolution clock intended for measuring durations.
    t0 = time.perf_counter()
    op(iters)
    elapsed = time.perf_counter() - t0
    itersec = iters / elapsed
    print(f"{itersec:.3f} iter/s")
def err(n):
    """Print the command-line usage for program ``n`` to stderr and exit with status 1.

    Args:
        n: Program name to show in the message (normally ``sys.argv[0]``).

    Raises:
        SystemExit: always, with code 1.
    """
    # Braces are doubled so they print literally; a single `{pw,mm}` inside an
    # f-string evaluates to a tuple of the two function objects.
    print(f"usage: python {n} {{pw,mm,mm_pw,mha}} size iters warmup", file=sys.stderr)
    # sys.exit, not the site-injected `exit`, which is absent under `python -S`.
    sys.exit(1)
def main(argv=None):
    """Parse ``argv`` and run the selected benchmark.

    Expected form: ``<prog> {pw,mm,mm_pw,mha} size iters warmup``. A missing
    argument, an unknown benchmark name, or a non-integer count prints the usage
    message via ``err``, which exits.

    Args:
        argv: Argument vector; defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    # Single source of truth for valid names and their implementations, so the
    # validation and the dispatch cannot drift apart.
    benchmarks = {"pw": pw, "mm": mm, "mm_pw": mm_pw, "mha": mha}
    if len(argv) < 5 or argv[1] not in benchmarks:
        err(argv[0])
    try:
        size, iters, warmup = (int(arg) for arg in argv[2:5])
    except ValueError:
        # Report bad numbers with the usage text rather than a raw traceback.
        err(argv[0])
    benchmarks[argv[1]](size, iters, warmup)


if __name__ == "__main__":
    main()