
Commit 9a482bd

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a0d0d54 commit 9a482bd
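
The hunks below are formatting-only: long chained expressions get wrapped in parentheses, trailing commas are added, and over-long asserts and list comprehensions are re-split, which is consistent with an auto-formatter such as black. As a hedged illustration (the repository's actual hook list and configured line length are not shown on this page), the snippet below reproduces the same kind of rewrap with black's Python API:

# A minimal sketch of reproducing this kind of auto-fix locally, assuming the
# pre-commit hook behind this commit is the "black" formatter.
import black

src = (
    "kv_compressed = torch.randn(kv_compressed_shape,"
    " dtype=dtypes[dtype]).cuda().requires_grad_(True)\n"
)

# black wraps the over-long right-hand side in parentheses, the same shape of change
# as the hunks below (the exact split can vary with the configured line length).
print(black.format_str(src, mode=black.Mode()))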

6 files changed (+187, -127 lines)


tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 59 additions & 39 deletions
@@ -228,14 +228,21 @@ def run_dpa_with_cp(
         kv_up_proj = None
         kv_compressed, k_pos_emb = None, None
     else:
-        kv_compressed = torch.randn(kv_compressed_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
+        kv_compressed = (
+            torch.randn(kv_compressed_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
+        )
         k_pos_emb = torch.randn(k_pos_emb_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
         head_dim_k_no_pe = config.head_dim_qk - config.qk_pos_emb_head_dim
-        linear = torch.nn.Linear(
-            config.kv_lora_rank,
-            config.num_heads * (head_dim_k_no_pe + config.head_dim_v),
-            bias=False
-        ).cuda().to(dtypes[dtype])
+        linear = (
+            torch.nn.Linear(
+                config.kv_lora_rank,
+                config.num_heads * (head_dim_k_no_pe + config.head_dim_v),
+                bias=False,
+            )
+            .cuda()
+            .to(dtypes[dtype])
+        )
+
         def kv_up_proj(kv_compressed, k_pos_emb):
             kv = linear(kv_compressed).view(*kv_compressed.shape[:-1], config.num_heads, -1)
             k_no_pe, v = torch.split(kv, [head_dim_k_no_pe, config.head_dim_v], dim=-1)
@@ -249,6 +256,7 @@ def kv_up_proj(kv_compressed, k_pos_emb):
             k_pos_emb = k_pos_emb.expand(-1, config.num_heads, -1)
             k = torch.cat([k_no_pe, k_pos_emb], dim=-1)
             return k, v
+
         k, v = kv_up_proj(kv_compressed, k_pos_emb)
     dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
     dout_quantizer = Float8Quantizer(
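
For readers skimming the two hunks above: the MLA path keeps K/V as a compressed latent (kv_compressed) plus a positional embedding shared across heads (k_pos_emb), and kv_up_proj expands them into per-head K and V on the fly. The sketch below restates that logic as a standalone snippet; the bshd layout and the concrete sizes are illustrative assumptions, not values taken from this test.

import torch

# Illustrative sizes (assumptions); the test draws these from its ModelConfig instead.
batch, seqlen, num_heads = 2, 8, 12
kv_lora_rank = 512
head_dim_qk, head_dim_v, qk_pos_emb_head_dim = 128, 64, 32
head_dim_k_no_pe = head_dim_qk - qk_pos_emb_head_dim  # part of K without positional emb

kv_compressed = torch.randn(batch, seqlen, kv_lora_rank)
k_pos_emb = torch.randn(batch, seqlen, 1, qk_pos_emb_head_dim)  # shared across heads
linear = torch.nn.Linear(kv_lora_rank, num_heads * (head_dim_k_no_pe + head_dim_v), bias=False)


def kv_up_proj(kv_compressed, k_pos_emb):
    # Up-project the latent into per-head [K-without-PE | V], then split the two parts.
    kv = linear(kv_compressed).view(*kv_compressed.shape[:-1], num_heads, -1)
    k_no_pe, v = torch.split(kv, [head_dim_k_no_pe, head_dim_v], dim=-1)
    # Broadcast the shared positional part across heads and append it to K.
    k = torch.cat([k_no_pe, k_pos_emb.expand(-1, -1, num_heads, -1)], dim=-1)
    return k, v


k, v = kv_up_proj(kv_compressed, k_pos_emb)
print(k.shape, v.shape)  # torch.Size([2, 8, 12, 128]) torch.Size([2, 8, 12, 64])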
@@ -289,9 +297,7 @@ def kv_up_proj(kv_compressed, k_pos_emb):
     out.backward(dout)

     # run core_attn wit CP
-    q_, dout_, *rest = [
-        x.clone().detach() for x in [q, dout] + ([] if bias is None else [bias])
-    ]
+    q_, dout_, *rest = [x.clone().detach() for x in [q, dout] + ([] if bias is None else [bias])]
     if config.kv_lora_rank is None:
         k_ = k.clone().detach()
         v_ = v.clone().detach()
@@ -306,12 +312,16 @@ def kv_up_proj(kv_compressed, k_pos_emb):
     if qkv_format == "bshd" or qkv_format == "sbhd":
         seq_dim = qkv_format.index("s")
         q_, k_, v_, kv_compressed_, k_pos_emb_, dout_ = [
-            x.view(
-                *x.shape[:seq_dim],
-                2 * world_size,
-                x.shape[seq_dim] // (2 * world_size),
-                *x.shape[(seq_dim + 1) :],
-            ) if x is not None else None
+            (
+                x.view(
+                    *x.shape[:seq_dim],
+                    2 * world_size,
+                    x.shape[seq_dim] // (2 * world_size),
+                    *x.shape[(seq_dim + 1) :],
+                )
+                if x is not None
+                else None
+            )
             for x in [q_, k_, v_, kv_compressed_, k_pos_emb_, dout_]
         ]
         seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
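
The view(..., 2 * world_size, ...) plus index_select(seq_dim, seq_idx) pattern in this hunk is the usual load-balanced sharding for context parallelism: the sequence is cut into 2 * world_size chunks and rank i keeps chunk i and its mirror chunk 2 * world_size - 1 - i, so the causal-mask work is roughly even across ranks. A tiny standalone illustration (world_size and the sequence length are made-up values):

import torch

world_size, seqlen = 4, 16  # made-up values for illustration
x = torch.arange(seqlen)  # stand-in for a tensor's sequence dimension

chunks = x.view(2 * world_size, seqlen // (2 * world_size))  # 8 chunks of 2 tokens
for rank in range(world_size):
    seq_idx = torch.tensor([rank, 2 * world_size - rank - 1])
    # rank 0 -> chunks 0 and 7, rank 1 -> chunks 1 and 6, ...
    print(rank, chunks.index_select(0, seq_idx).tolist())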
@@ -392,35 +402,39 @@ def kv_up_proj(kv_compressed, k_pos_emb):

     for x in [out_, q_.grad] + (
         [k_.grad, v_.grad]
-        if config.kv_lora_rank is None else
-        [kv_compressed_.grad, k_pos_emb_.grad]
+        if config.kv_lora_rank is None
+        else [kv_compressed_.grad, k_pos_emb_.grad]
     ):
         assert torch.all(~torch.isnan(x))
         assert torch.all(~torch.isinf(x))

     # compare results with and without CP
     if qkv_format == "bshd" or qkv_format == "sbhd":
         dq, dk, dv, dkv_compressed, dk_pos_emb = [
-            x.grad if x is not None else None
-            for x in [q, k, v, kv_compressed, k_pos_emb]
+            x.grad if x is not None else None for x in [q, k, v, kv_compressed, k_pos_emb]
         ]
         dq, dk, dv, dkv_compressed, dk_pos_emb, out = [
-            x.view(
-                *x.shape[:seq_dim],
-                2 * world_size,
-                x.shape[seq_dim] // (2 * world_size),
-                *x.shape[(seq_dim + 1) :],
-            ).index_select(seq_dim, seq_idx)
-            if x is not None else None
+            (
+                x.view(
+                    *x.shape[:seq_dim],
+                    2 * world_size,
+                    x.shape[seq_dim] // (2 * world_size),
+                    *x.shape[(seq_dim + 1) :],
+                ).index_select(seq_dim, seq_idx)
+                if x is not None
+                else None
+            )
             for x in [dq, dk, dv, dkv_compressed, dk_pos_emb, out]
         ]
         dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_ = [
-            x.grad if x is not None else None
-            for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
+            x.grad if x is not None else None for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
         ]
         dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_, out_ = [
-            x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
-            if x is not None else None
+            (
+                x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
+                if x is not None
+                else None
+            )
             for x in [dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_, out_]
         ]
     elif qkv_format == "thd":
@@ -431,11 +445,11 @@ def kv_up_proj(kv_compressed, k_pos_emb):
         else:
             dk, dv = None, None
             dkv_compressed, dk_pos_emb = [
-                x.index_select(0, seq_idx_kv).contiguous() for x in [kv_compressed.grad, k_pos_emb.grad]
+                x.index_select(0, seq_idx_kv).contiguous()
+                for x in [kv_compressed.grad, k_pos_emb.grad]
             ]
         dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_ = [
-            x.grad if x is not None else None
-            for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
+            x.grad if x is not None else None for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
         ]
         cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
@@ -519,34 +533,40 @@ def _error(a, b, tensor_name):
         for tensor_name, a, b in zip(
             ["out", "dq", "dk", "dv", "dkv_compressed", "dk_pos_emb"],
             [out_, dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_],
-            [out, dq, dk, dv, dkv_compressed, dk_pos_emb]
+            [out, dq, dk, dv, dkv_compressed, dk_pos_emb],
         ):
             if a is None or b is None:
                 a_is_None = "is" if a is None else "is not"
                 b_is_None = "is" if b is None else "is not"
-                assert a is None and b is None, f"{tensor_name}_ {a_is_None} None and {tensor_name} {b_is_None} None!"
+                assert (
+                    a is None and b is None
+                ), f"{tensor_name}_ {a_is_None} None and {tensor_name} {b_is_None} None!"
                 continue
             _error(a[:, 0], b[:, 0], tensor_name)
             _error(a[:, 1], b[:, 1], tensor_name)
     elif qkv_format == "sbhd":
         for tensor_name, a, b in zip(
             ["out", "dq", "dk", "dv", "dkv_compressed", "dk_pos_emb"],
             [out_, dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_],
-            [out, dq, dk, dv, dkv_compressed, dk_pos_emb]
+            [out, dq, dk, dv, dkv_compressed, dk_pos_emb],
         ):
             if a is None or b is None:
-                assert a is None and b is None, f"{tensor_name} and {tensor_name}_ are not both None!"
+                assert (
+                    a is None and b is None
+                ), f"{tensor_name} and {tensor_name}_ are not both None!"
                 continue
             _error(a[0], b[0], tensor_name)
             _error(a[1], b[1], tensor_name)
     elif qkv_format == "thd":
         for tensor_name, a, b in zip(
             ["out", "dq", "dk", "dv", "dkv_compressed", "dk_pos_emb"],
             [out_, dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_],
-            [out, dq, dk, dv, dkv_compressed, dk_pos_emb]
+            [out, dq, dk, dv, dkv_compressed, dk_pos_emb],
         ):
             if a is None or b is None:
-                assert a is None and b is None, f"{tensor_name} and {tensor_name}_ are not both None!"
+                assert (
+                    a is None and b is None
+                ), f"{tensor_name} and {tensor_name}_ are not both None!"
                 continue
             _error(a, b, tensor_name)
     else:

tests/pytorch/attention/test_attention_with_cp.py

Lines changed: 25 additions & 6 deletions
@@ -128,19 +128,38 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     ),  # MLA
     "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64),  # MLA
     "cp_4_0": ModelConfig(
-        2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64,
-        kv_lora_rank=512, qk_pos_emb_head_dim=32
+        2,
+        4096,
+        12,
+        128,
+        attn_mask_type="causal",
+        head_dim_v=64,
+        kv_lora_rank=512,
+        qk_pos_emb_head_dim=32,
     ),  # MLA
     "cp_4_1": ModelConfig(
         2, 4096, 12, 128, head_dim_v=64, kv_lora_rank=512, qk_pos_emb_head_dim=32
     ),  # MLA
     "cp_4_2": ModelConfig(
-        2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias",
-        head_dim_v=64, kv_lora_rank=512, qk_pos_emb_head_dim=32
+        2,
+        4096,
+        12,
+        128,
+        attn_mask_type="causal",
+        attn_bias_type="post_scale_bias",
+        head_dim_v=64,
+        kv_lora_rank=512,
+        qk_pos_emb_head_dim=32,
     ),  # MLA
     "cp_4_3": ModelConfig(
-        2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64,
-        kv_lora_rank=512, qk_pos_emb_head_dim=32
+        2,
+        4096,
+        12,
+        128,
+        attn_bias_type="post_scale_bias",
+        head_dim_v=64,
+        kv_lora_rank=512,
+        qk_pos_emb_head_dim=32,
     ),  # MLA
 }

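The new cp_4_* entries exercise the MLA path from run_attention_with_cp.py. The snippet below only spells out how those keyword values feed the up-projection shown earlier; reading the positional arguments as (batch, seqlen, num_heads, head_dim_qk) is an assumption about ModelConfig's signature, not something stated on this page.

# How the cp_4_* values map onto the kv_up_proj shapes (assumed ModelConfig argument order).
num_heads, head_dim_qk = 12, 128
head_dim_v, kv_lora_rank, qk_pos_emb_head_dim = 64, 512, 32

head_dim_k_no_pe = head_dim_qk - qk_pos_emb_head_dim
print(head_dim_k_no_pe)  # 96: non-positional part of each K head
print(num_heads * (head_dim_k_no_pe + head_dim_v))  # 1920: output width of the up-projection Linear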