test_sparse_attention.py

import torch
from torch import nn
import torch.quantization as quant
import pytest
import itertools

from layer.sparse_linear import SparseLinear


def set_weights_to_zero(weights, percentage):
    # Calculate the number of weights to set to 0
    num_weights = weights.numel()
    num_zeros = int(num_weights * percentage / 100.0)
    # Create a mask tensor
    mask = torch.ones_like(weights)
    # Randomly select indices to set to 0
    zero_indices = torch.randperm(num_weights)[:num_zeros]
    # Set the selected indices in the mask to 0
    mask.view(-1)[zero_indices] = 0
    # Apply the mask to the weights
    return weights * mask
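

# A minimal sanity-check sketch for the helper above (not part of the original
# suite, and the helper name below is mine): roughly `percentage` percent of the
# entries should come back as exact zeros.
def _sanity_check_set_weights_to_zero():
    w = torch.randn(64, 64)
    masked = set_weights_to_zero(w, 25)
    # int(64 * 64 * 25 / 100) = 1024 of the 4096 entries are masked, i.e. 25%.
    zero_fraction = (masked == 0).float().mean().item()
    assert abs(zero_fraction - 0.25) < 1e-3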


test_data = [
    # (16, 16, 16),  # Too small for our assumptions of handling big cases.
    (32, 32, 32),
    (128, 128, 128),
    (1024, 1024, 1024),
    (32, 128, 32),
    (32, 1024, 32),
    (32, 128, 64),
    (32, 128, 1024),
    (1024, 128, 16),
]


@pytest.mark.parametrize("out_rows,inner_dim,out_cols", test_data)
def test_matmul_using_kernel(out_rows, inner_dim, out_cols):
    input = torch.randn(out_rows, inner_dim, dtype=torch.bfloat16)
    weights = torch.randn(inner_dim, out_cols, dtype=torch.bfloat16)
    ans = torch.matmul(input, weights)
    computer = SparseLinear.from_weights(weights.transpose(0, 1))
    output = computer(input.unsqueeze(0)).squeeze(0)
    # import pdb; pdb.set_trace()
    torch.testing.assert_close(ans, output)
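

# A note on the transposes above: the call pattern suggests SparseLinear.from_weights
# expects its weight matrix in nn.Linear's (out_features, in_features) layout, and
# that the module's forward takes a batched input (hence the unsqueeze/squeeze). That
# is an inference about the kernel's API, not something this file states. For
# reference, a dense layer built the same way reproduces torch.matmul (the helper
# below is hypothetical, for illustration only):
def _dense_reference(input, weights):
    # weights: (inner_dim, out_cols); nn.Linear stores (out_cols, inner_dim).
    inner_dim, out_cols = weights.shape
    dense = nn.Linear(inner_dim, out_cols, bias=False, dtype=weights.dtype)
    with torch.no_grad():
        dense.weight.copy_(weights.transpose(0, 1))
    return dense(input)  # matches torch.matmul(input, weights)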


test_data = [
    (1024, 1024, 1024, 1),
    (1024, 1024, 1024, 20),
    (1024, 1024, 1024, 50),
    (1024, 1024, 1024, 90),
]


@pytest.mark.parametrize("out_rows,inner_dim,out_cols,percentage", test_data)
def test_sparse_matmul_using_kernel(out_rows, inner_dim, out_cols, percentage):
    input = torch.randn(out_rows, inner_dim, dtype=torch.bfloat16)
    weights = torch.randn(inner_dim, out_cols, dtype=torch.bfloat16)
    weights = set_weights_to_zero(weights, percentage)
    ans = torch.matmul(input, weights)
    computer = SparseLinear.from_weights(weights.transpose(0, 1))
    output = computer(input.unsqueeze(0)).squeeze(0)
    # import pdb; pdb.set_trace()
    torch.testing.assert_close(ans, output)


test_data = [
    # (1, 32, 16, 16),
    # (1, 32, 2, 16),  # Need to fix the kernel to handle this case. It's caused by filling multiple rows in the same iteration which assumes that all rows will be filled.
    # (1, 32, 32, 16),
    # (1, 32, 16, 128),
    (1, 32, 128, 128),
    (1, 32, 256, 128),
    (1, 32, 512, 128),
    (1, 32, 1024, 128),
    # (16, 128, 16, 4096),
]


@pytest.mark.parametrize("batch_size,num_heads,seq_len,head_dim", test_data)
def test_attention_matrix_computation(batch_size, num_heads, seq_len, head_dim):
    query_states = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16)
    key_states = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16)
    ans = torch.matmul(query_states, key_states.transpose(2, 3))
    output = SparseLinear.batched_matmul(query_states, key_states.transpose(2, 3))
    torch.testing.assert_close(ans, output)
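

# The case above exercises the Q @ K^T step of attention. As a small reference
# sketch (the helper name is mine, not part of the suite), the same score matrix
# can be written as an einsum over the (batch, heads, seq, head_dim) layout:
def _attention_scores_reference(query_states, key_states):
    # (b, h, q, d) x (b, h, k, d) -> (b, h, q, k)
    return torch.einsum("bhqd,bhkd->bhqk", query_states, key_states)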


test_data = [
    (1, 32, 16, 128, 10),
    (1, 32, 16, 128, 20),
    (1, 32, 16, 128, 50),
    (1, 32, 16, 128, 90),
    (1, 32, 1024, 128, 10),
    (1, 32, 1024, 128, 20),
    (1, 32, 1024, 128, 50),
    (1, 32, 1024, 128, 90),
]


@pytest.mark.parametrize("batch_size,num_heads,seq_len,head_dim,percentage", test_data)
def test_sparse_attention_matrix_computation(batch_size, num_heads, seq_len, head_dim, percentage):
    query_states = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16)
    key_states = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16)
    # Sparsify the keys at the parametrized percentage
    key_states = set_weights_to_zero(key_states, percentage)
    ans = torch.matmul(query_states, key_states.transpose(2, 3))
    output = SparseLinear.batched_matmul(query_states, key_states.transpose(2, 3))
    torch.testing.assert_close(ans, output)
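

# The suite is meant to be driven by pytest, e.g. (assuming the `layer` package
# that provides SparseLinear is importable from the working directory):
#
#     pytest test_sparse_attention.py -v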