# custom_llama_attention.py
import math
from typing import List, Optional, Tuple, Union
import os.path
import torch
import torch.nn.functional as F
from torch import nn
import time
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.cache_utils import Cache
from layer.sparse_linear import SparseLinear
from layer.dense_linear import DenseLinear
# from layer.avx_sparse_linear import AvxSparseLinear
# Directory prefix that saved_file_prefix() puts in front of every kernel
# state file name. The commented-out value below is an alternative location
# kept for reference.
# SAVED_FILES_DIR = "saved_kv_state_files_per_layer_amx_new/per_head"
SAVED_FILES_DIR = "saved_kv_state_files_per_layer_amx_new/per_layer"
def sparsify(vals, prune_percentage):
    """Zero out the smallest-magnitude entries of a tensor.

    The threshold is the k-th smallest absolute value over the whole tensor,
    where ``k = int(vals.numel() * prune_percentage / 100)``. Entries whose
    magnitude is greater than or equal to the threshold are kept, so the
    threshold element itself survives.

    Args:
        vals: Tensor to prune. The computation is done in float32 and the
            result is cast back to ``vals.dtype``.
        prune_percentage: Percentage (0-100) used to derive ``k``.

    Returns:
        ``vals`` itself when nothing would be pruned (percentage of 0, an
        empty tensor, or ``k`` rounding down to 0); otherwise a new tensor of
        the same shape and dtype with the low-magnitude entries set to zero.
    """
    if prune_percentage == 0:
        return vals

    k = int(vals.numel() * (prune_percentage / 100))
    # torch.kthvalue requires 1 <= k <= numel; a tiny tensor or a small
    # percentage rounds k down to 0, in which case there is nothing to prune.
    if k < 1:
        return vals

    original_dtype = vals.dtype
    # Work in float32 so the threshold search does not depend on the
    # (possibly reduced-precision) storage dtype.
    magnitudes = vals.float()
    threshold = torch.kthvalue(magnitudes.abs().reshape(-1), k).values
    keep_mask = magnitudes.abs() >= threshold
    return (magnitudes * keep_mask).to(original_dtype)
def sparsify_per_head(vals, prune_percentage):
    """Apply ``sparsify`` independently to every ``vals[i, j]`` slice.

    The two leading dimensions are iterated; each slice is pruned on its own
    and written back into ``vals`` in place.

    Args:
        vals: Tensor with at least two dimensions; it is modified in place.
        prune_percentage: Percentage forwarded to ``sparsify`` for each slice.

    Returns:
        The same ``vals`` object after the in-place update.
    """
    # Lazily enumerate (first-dim, second-dim) index pairs in row-major order.
    index_pairs = (
        (outer, inner)
        for outer in range(vals.shape[0])
        for inner in range(vals.shape[1])
    )
    for outer, inner in index_pairs:
        pruned_slice = sparsify(vals[outer, inner], prune_percentage)
        vals[outer, inner] = pruned_slice
    return vals
def saved_file_prefix(layer_id, pruning, ctx_length, kernel):
    """Build the path prefix (no suffix/extension) for a saved kernel state.

    Args:
        layer_id: Layer index embedded in the file name.
        pruning: Pruning percentage embedded in the file name.
        ctx_length: Context length embedded in the file name.
        kernel: Kernel identifier; it is interpolated with normal f-string
            formatting, whatever object is passed.

    Returns:
        ``SAVED_FILES_DIR`` joined with the descriptive file stem by ``/``.
    """
    stem = f"kernel_{kernel}_ctx_{ctx_length}_layer_{layer_id}_pruning_{pruning}"
    return f"{SAVED_FILES_DIR}/{stem}"
class CustomLlamaAttention(LlamaAttention):
    """Llama attention whose Q.K^T and attn.V products can use a custom kernel.

    When ``use_custom_k`` / ``use_custom_v`` are set, the first forward pass
    prunes the key / value states with ``sparsify``, wraps them in a kernel
    object (``SparseLinear`` or ``DenseLinear``), saves that object's state to
    disk, and uses the kernel's ``matmul`` for every call from then on.
    Otherwise the regular ``torch.matmul`` path is used.

    The module keeps its own concatenated key/value cache
    (``cached_key_states`` / ``cached_value_states``) instead of updating the
    ``Cache`` object passed to ``forward``; call ``reset_cache()`` to clear it.
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: Optional[int] = None,
        use_custom_k: bool = False,
        use_custom_v: bool = False,
        k_pruning: int = 50,
        v_pruning: int = 50,
        kernel_type='sparse',
        ctx_length: int = 0,
    ):
        """Set up the attention layer and choose the kernel implementation.

        Args:
            config: Llama configuration forwarded to ``LlamaAttention``.
            layer_idx: Index of this layer, forwarded to ``LlamaAttention``.
            use_custom_k: Route the Q.K^T product through the custom kernel.
            use_custom_v: Route the attn.V product through the custom kernel.
            k_pruning: Pruning percentage applied to key states.
            v_pruning: Pruning percentage applied to value states.
            kernel_type: ``'sparse'`` or ``'dense'``.
            ctx_length: Accepted for interface compatibility; not used here
                (``load_computers`` takes the context length explicitly).

        Raises:
            ValueError: If ``kernel_type`` is neither ``'sparse'`` nor ``'dense'``.
        """
        super().__init__(config, layer_idx)
        self.key_computer = None
        self.value_computer = None
        self.use_custom_k = use_custom_k
        self.use_custom_v = use_custom_v
        self.k_pruning = k_pruning
        self.v_pruning = v_pruning
        # An empty tensor marks "nothing cached yet" (see update_cache).
        self.cached_key_states = torch.empty(0)
        self.cached_value_states = torch.empty(0)

        if kernel_type == 'sparse':
            self.kernel = SparseLinear
        elif kernel_type == 'dense':
            self.kernel = DenseLinear
        else:
            raise ValueError("Invalid kernel type")

        self.reset_timing_stats()
        print(f"Using custom k: {self.use_custom_k}, using custom v: {self.use_custom_v}, k pruning: {self.k_pruning}, v pruning: {self.v_pruning}, with kernel: {kernel_type}")

    def reset_timing_stats(self):
        """Reset accumulated timing statistics.

        ``forward`` does not currently update these counters, so
        ``call_count`` stays at 0 until instrumentation is added.
        """
        self.timing_stats = {
            'time_1': 0.0,
            'time_2': 0.0,
            'time_3': {'total': 0.0, '3_1': 0.0, '3_2': 0.0, '3_3': 0.0, '3_4': 0.0},
            'time_4': 0.0,
            'time_5': 0.0,
            'attn_weights': 0.0,
            'attn_output': 0.0,
            'total': 0.0,
            'call_count': 0,
        }

    def load_computers(self, ctx_length):
        """Load previously saved key/value kernels for ``ctx_length`` if present.

        A kernel is loaded only when the matching ``use_custom_*`` flag is set
        and its state file exists; otherwise the attribute is left untouched.

        NOTE(review): ``self.kernel`` is the kernel *class*, so the file name
        embeds that class's string form rather than the ``kernel_type`` name.
        ``forward`` builds its save paths the same way, so both sides agree.
        """
        if self.use_custom_k:
            key_path = f"{saved_file_prefix(self.layer_idx, self.k_pruning, ctx_length, self.kernel)}_k.pt"
            if os.path.exists(key_path):
                self.key_computer = self.kernel.load_from_file(key_path)
        if self.use_custom_v:
            value_path = f"{saved_file_prefix(self.layer_idx, self.v_pruning, ctx_length, self.kernel)}_v.pt"
            if os.path.exists(value_path):
                self.value_computer = self.kernel.load_from_file(value_path)

    def update_cache(self, key_states, value_states):
        """Append new key/value states to the internal cache along dim 2.

        Returns:
            The full cached ``(key_states, value_states)`` after the append.
        """
        if self.cached_key_states.numel() == 0:
            # First call: adopt the incoming tensors as the cache.
            self.cached_key_states = key_states
            self.cached_value_states = value_states
        else:
            self.cached_key_states = torch.cat((self.cached_key_states, key_states), dim=2)
            self.cached_value_states = torch.cat((self.cached_value_states, value_states), dim=2)
        return self.cached_key_states, self.cached_value_states

    def reset_cache(self):
        """Drop all cached key/value states."""
        self.cached_key_states = torch.empty(0)
        self.cached_value_states = torch.empty(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention, optionally through the custom key/value kernels.

        Args:
            hidden_states: Input of size ``(bsz, q_len, hidden)``.
            attention_mask: Optional additive mask; its last dimension is
                sliced to the key length before being added to the scores.
            position_ids: Passed to ``self.rotary_emb`` when
                ``position_embeddings`` is not supplied.
            past_key_value: Used only as a flag. When it is not ``None`` the
                module appends to its own internal cache; the passed object
                is returned unmodified.
            output_attentions: Return the attention probabilities when true.
            use_cache: Unused here.
            cache_position: Unused here.
            position_embeddings: Precomputed ``(cos, sin)`` pair.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: If the attention output does not have size
                ``(bsz, num_heads, q_len, head_dim)``.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Tensor-parallel pretraining layout: project with weight slices
            # and concatenate the partial results.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # Split the projection output into heads: (bsz, heads, q_len, head_dim).
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # The module-local cache is used instead of past_key_value.update().
            key_states, value_states = self.update_cache(key_states, value_states)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if self.use_custom_k:
            if self.key_computer is None:
                # First use: prune the keys, build the kernel and persist it.
                # The kernel is not rebuilt on later calls.
                key_states = sparsify(key_states, self.k_pruning)
                self.key_computer = self.kernel.from_batched_weights(key_states.contiguous())
                self.key_computer.save_state(f"{saved_file_prefix(self.layer_idx, self.k_pruning, q_len, self.kernel)}_k.pt")
            attn_weights = self.key_computer.matmul(query_states.contiguous()) / math.sqrt(self.head_dim)
        else:
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        if self.use_custom_v:
            if self.value_computer is None:
                value_states = sparsify(value_states, self.v_pruning)
                # NOTE(review): .contiguous() is applied before the transpose,
                # so the tensor handed to the kernel is a non-contiguous view.
                # Confirm from_batched_weights accepts that.
                self.value_computer = self.kernel.from_batched_weights(value_states.contiguous().transpose(2, 3))
                self.value_computer.save_state(f"{saved_file_prefix(self.layer_idx, self.v_pruning, q_len, self.kernel)}_v.pt")
            attn_output = self.value_computer.matmul(attn_weights.contiguous())
        else:
            attn_output = torch.matmul(attn_weights, value_states[:, :, :attn_weights.size(-1), :])

        # Cast back to the model's compute dtype (previously hard-coded to
        # bfloat16, which mismatched o_proj's weights for any other dtype).
        attn_output = attn_output.to(query_states.dtype)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: (bsz, q_len, heads * head_dim).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def print_timing_stats(self):
        """Print accumulated timing statistics, averaged per call.

        Prints nothing when ``call_count`` is 0.
        """
        call_count = self.timing_stats['call_count']
        if call_count == 0:
            return

        print(f"\nTiming statistics for attention layer {self.layer_idx} (averaged over {call_count} calls):")
        for key, value in self.timing_stats.items():
            if key == 'time_3':
                print("Time 3 breakdown:")
                for subkey, subvalue in value.items():
                    avg = subvalue / call_count
                    print(f"  - {subkey}: {avg:.6f}s")
            elif key != 'call_count':
                avg = value / call_count
                print(f"{key}: {avg:.6f}s")