Commit cd706df

support mla cp exchanging latent
Parent: 6ba98d4

File tree: 8 files changed, +710 -206 lines
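
For MLA configs, "exchanging latent" means the context-parallel ranks exchange the compressed KV latent (kv_compressed, width kv_lora_rank) together with the small positional-embedding part of the key (k_pos_emb, width qk_pos_emb_head_dim) instead of the full per-head K/V tensors; the test passes a kv_up_proj_fn callback so full K/V can be rebuilt locally from the received latent. A minimal sketch of the per-token payload difference, using the sizes from the new cp_4_* test configs below (illustration only; the actual communication volume depends on the kernel):

num_heads, head_dim_qk, head_dim_v = 12, 128, 64
kv_lora_rank, qk_pos_emb_head_dim = 512, 32

full_kv_per_token = num_heads * (head_dim_qk + head_dim_v)   # 2304 elements per token
latent_per_token = kv_lora_rank + qk_pos_emb_head_dim        # 544 elements per token
print(full_kv_per_token / latent_per_token)                  # ~4.2x smaller exchange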

tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 172 additions & 43 deletions
@@ -118,6 +118,16 @@ def run_dpa_with_cp(
             config.num_gqa_groups,
             config.head_dim_v,
         )
+        kv_compressed_shape = (
+            config.batch_size,
+            config.max_seqlen_kv,
+            config.kv_lora_rank or 0,
+        )
+        k_pos_emb_shape = (
+            config.batch_size,
+            config.max_seqlen_kv,
+            config.qk_pos_emb_head_dim or 0,
+        )
         attn_output_shape = (
             config.batch_size,
             config.max_seqlen_q,
@@ -146,6 +156,16 @@ def run_dpa_with_cp(
             config.num_gqa_groups,
             config.head_dim_v,
         )
+        kv_compressed_shape = (
+            config.max_seqlen_kv,
+            config.batch_size,
+            config.kv_lora_rank or 0,
+        )
+        k_pos_emb_shape = (
+            config.max_seqlen_kv,
+            config.batch_size,
+            config.qk_pos_emb_head_dim or 0,
+        )
         attn_output_shape = (
             config.max_seqlen_q,
             config.batch_size,
@@ -162,15 +182,23 @@ def run_dpa_with_cp(
             config.head_dim_qk,
         )
         k_input_shape = (
-            config.batch_size * config.max_seqlen_q,
+            config.batch_size * config.max_seqlen_kv,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
         v_input_shape = (
-            config.batch_size * config.max_seqlen_q,
+            config.batch_size * config.max_seqlen_kv,
             config.num_gqa_groups,
             config.head_dim_v,
         )
+        kv_compressed_shape = (
+            config.batch_size * config.max_seqlen_kv,
+            config.kv_lora_rank or 0,
+        )
+        k_pos_emb_shape = (
+            config.batch_size * config.max_seqlen_kv,
+            config.qk_pos_emb_head_dim or 0,
+        )
         attn_output_shape = (
             config.batch_size * config.max_seqlen_q,
             config.num_heads * config.head_dim_v,
@@ -193,9 +221,35 @@ def run_dpa_with_cp(
     else:
         assert False, f"{qkv_format} is an unsupported qkv_format!"

-    q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
-    k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
-    v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
+    q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
+    if config.kv_lora_rank is None:
+        k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
+        v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
+        kv_up_proj = None
+        kv_compressed, k_pos_emb = None, None
+    else:
+        kv_compressed = torch.randn(kv_compressed_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
+        k_pos_emb = torch.randn(k_pos_emb_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
+        head_dim_k_no_pe = config.head_dim_qk - config.qk_pos_emb_head_dim
+        linear = torch.nn.Linear(
+            config.kv_lora_rank,
+            config.num_heads * (head_dim_k_no_pe + config.head_dim_v),
+            bias=False
+        ).cuda().to(dtypes[dtype])
+        def kv_up_proj(kv_compressed, k_pos_emb):
+            kv = linear(kv_compressed).view(*kv_compressed.shape[:-1], config.num_heads, -1)
+            k_no_pe, v = torch.split(kv, [head_dim_k_no_pe, config.head_dim_v], dim=-1)
+            k_pos_emb = torch.unsqueeze(k_pos_emb, -2)
+            if k_pos_emb.ndim == 5:
+                k_pos_emb = k_pos_emb.expand(-1, -1, -1, config.num_heads, -1)
+            elif k_pos_emb.ndim == 4:
+                k_pos_emb = k_pos_emb.expand(-1, -1, config.num_heads, -1)
+            else:
+                assert k_pos_emb.ndim == 3, f"{k_pos_emb.shape=} is not supported!"
+                k_pos_emb = k_pos_emb.expand(-1, config.num_heads, -1)
+            k = torch.cat([k_no_pe, k_pos_emb], dim=-1)
+            return k, v
+        k, v = kv_up_proj(kv_compressed, k_pos_emb)
     dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
     dout_quantizer = Float8Quantizer(
         fp8_dtype=tex.DType.kFloat8E5M2,
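
The kv_up_proj helper added above mirrors MLA's up-projection: one no-bias linear maps the latent to the non-RoPE key part plus the value, and the shared k_pos_emb is broadcast across heads and concatenated onto the key. A minimal standalone shape check, assuming a bshd layout and the cp_4_* sizes (random weights, toy batch and sequence lengths):

import torch

b, s, h = 2, 16, 12                              # batch, sequence, heads (toy sizes)
head_dim_qk, head_dim_v, lora, pe = 128, 64, 512, 32
no_pe = head_dim_qk - pe                         # 96: key dims produced by the up-projection

linear = torch.nn.Linear(lora, h * (no_pe + head_dim_v), bias=False)
kv_compressed = torch.randn(b, s, lora)          # latent exchanged across CP ranks
k_pos_emb = torch.randn(b, s, pe)                # RoPE part of the key, shared by all heads

kv = linear(kv_compressed).view(b, s, h, no_pe + head_dim_v)
k_no_pe, v = torch.split(kv, [no_pe, head_dim_v], dim=-1)
k = torch.cat([k_no_pe, k_pos_emb.unsqueeze(-2).expand(-1, -1, h, -1)], dim=-1)
assert k.shape == (b, s, h, head_dim_qk) and v.shape == (b, s, h, head_dim_v)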
@@ -211,9 +265,6 @@ def run_dpa_with_cp(
     bias = None

     # run core_attn without CP
-    for x in [q, k, v]:
-        x.requires_grad = True
-
     if dtype == "fp8":
         fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
     else:
@@ -238,38 +289,61 @@ def run_dpa_with_cp(
     out.backward(dout)

     # run core_attn wit CP
-    q_, k_, v_, dout_, *rest = [
-        x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
+    q_, dout_, *rest = [
+        x.clone().detach() for x in [q, dout] + ([] if bias is None else [bias])
     ]
+    if config.kv_lora_rank is None:
+        k_ = k.clone().detach()
+        v_ = v.clone().detach()
+        kv_compressed_ = None
+        k_pos_emb_ = None
+    else:
+        k_ = None
+        v_ = None
+        kv_compressed_ = kv_compressed.clone().detach()
+        k_pos_emb_ = k_pos_emb.clone().detach()
     bias_ = rest[0] if len(rest) else None
     if qkv_format == "bshd" or qkv_format == "sbhd":
         seq_dim = qkv_format.index("s")
-        q_, k_, v_, dout_ = [
+        q_, k_, v_, kv_compressed_, k_pos_emb_, dout_ = [
             x.view(
                 *x.shape[:seq_dim],
                 2 * world_size,
                 x.shape[seq_dim] // (2 * world_size),
                 *x.shape[(seq_dim + 1) :],
-            )
-            for x in [q_, k_, v_, dout_]
+            ) if x is not None else None
+            for x in [q_, k_, v_, kv_compressed_, k_pos_emb_, dout_]
         ]
         seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
-        q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
-        q_, k_, v_, dout_ = [
-            x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
+        q_, k_, v_, kv_compressed_, k_pos_emb_, dout_ = [
+            x.index_select(seq_dim, seq_idx) if x is not None else None
+            for x in [q_, k_, v_, kv_compressed_, k_pos_emb_, dout_]
+        ]
+        q_, k_, v_, kv_compressed_, k_pos_emb_, dout_ = [
+            x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) if x is not None else None
+            for x in [q_, k_, v_, kv_compressed_, k_pos_emb_, dout_]
         ]
     elif qkv_format == "thd":
         seq_idx_q = tex.thd_get_partitioned_indices(
-            cu_seqlens_q_padded, q_.shape[0], world_size, rank
+            cu_seqlens_q_padded, q_input_shape[0], world_size, rank
         )
         seq_idx_kv = tex.thd_get_partitioned_indices(
-            cu_seqlens_kv_padded, k_.shape[0], world_size, rank
+            cu_seqlens_kv_padded, k_input_shape[0], world_size, rank
         )
         q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
-        k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
+        if config.kv_lora_rank is None:
+            k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
+        else:
+            kv_compressed_, k_pos_emb_ = [
+                x.index_select(0, seq_idx_kv) for x in [kv_compressed_, k_pos_emb_]
+            ]
+
     else:
         assert False, f"{qkv_format} is an unsupported qkv_format!"
-    q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
+    q_, k_, v_, kv_compressed_, k_pos_emb_ = [
+        x.requires_grad_() if x is not None else None
+        for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
+    ]
     if bias_ is not None:
         bias_ = bias_.view(
             *bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1]
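
The seq_idx selection above is the usual load-balanced sharding for causal attention under context parallelism: the sequence is split into 2 * world_size chunks and rank r keeps chunks r and 2 * world_size - 1 - r, pairing an early chunk with a late one so causal-attention work is evened out across ranks. A toy illustration (world_size=4 is an arbitrary choice):

world_size = 4
for rank in range(world_size):
    chunks = [rank, 2 * world_size - 1 - rank]
    print(f"rank {rank} holds sequence chunks {chunks}")
# rank 0 holds sequence chunks [0, 7]
# rank 1 holds sequence chunks [1, 6]
# rank 2 holds sequence chunks [2, 5]
# rank 3 holds sequence chunks [3, 4]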
@@ -300,6 +374,9 @@ def run_dpa_with_cp(
             cu_seqlens_kv=cu_seqlens_kv,
             cu_seqlens_q_padded=cu_seqlens_q_padded,
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            kv_compressed=kv_compressed_,
+            k_pos_emb=k_pos_emb_,
+            kv_up_proj_fn=kv_up_proj,
         )
     if fp8_mha:
         dout_fp8_ = dout_quantizer(dout_)
@@ -313,30 +390,53 @@ def run_dpa_with_cp(
         out = out.dequantize()
        out_ = out_.dequantize()

-    for x in [out_, q_.grad, k_.grad, v_.grad]:
+    for x in [out_, q_.grad] + (
+        [k_.grad, v_.grad]
+        if config.kv_lora_rank is None else
+        [kv_compressed_.grad, k_pos_emb_.grad]
+    ):
         assert torch.all(~torch.isnan(x))
         assert torch.all(~torch.isinf(x))

     # compare results with and without CP
     if qkv_format == "bshd" or qkv_format == "sbhd":
-        dq, dk, dv, out = [
+        dq, dk, dv, dkv_compressed, dk_pos_emb = [
+            x.grad if x is not None else None
+            for x in [q, k, v, kv_compressed, k_pos_emb]
+        ]
+        dq, dk, dv, dkv_compressed, dk_pos_emb, out = [
             x.view(
                 *x.shape[:seq_dim],
                 2 * world_size,
                 x.shape[seq_dim] // (2 * world_size),
                 *x.shape[(seq_dim + 1) :],
-            )
-            for x in [q.grad, k.grad, v.grad, out]
+            ).index_select(seq_dim, seq_idx)
+            if x is not None else None
+            for x in [dq, dk, dv, dkv_compressed, dk_pos_emb, out]
+        ]
+        dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_ = [
+            x.grad if x is not None else None
+            for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
         ]
-        dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
-        dq_, dk_, dv_, out_ = [
+        dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_, out_ = [
             x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
-            for x in [q_.grad, k_.grad, v_.grad, out_]
+            if x is not None else None
+            for x in [dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_, out_]
         ]
     elif qkv_format == "thd":
         dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
-        dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
-        dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
+        if config.kv_lora_rank is None:
+            dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
+            dkv_compressed, dk_pos_emb = None, None
+        else:
+            dk, dv = None, None
+            dkv_compressed, dk_pos_emb = [
+                x.index_select(0, seq_idx_kv).contiguous() for x in [kv_compressed.grad, k_pos_emb.grad]
+            ]
+        dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_ = [
+            x.grad if x is not None else None
+            for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
+        ]
         cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
             cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
@@ -359,7 +459,9 @@ def run_dpa_with_cp(
         )
         cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
         num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1]
-        for x in [dk, dv, dk_, dv_]:
+        for x in [dk, dv, dk_, dv_, dkv_compressed, dk_pos_emb, dkv_compressed_, dk_pos_emb_]:
+            if x is None:
+                continue
             assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0
             for b in range(config.batch_size):
                 assert (
@@ -392,34 +494,61 @@ def run_dpa_with_cp(
     def _rmse(a, b):
         return torch.sqrt((a - b).square().mean()).item()

-    def _error(a, b):
+    def _error(a, b, tensor_name):
         if dtype != "fp8":
-            torch.testing.assert_close(a, b, **tols)
+            try:
+                torch.testing.assert_close(a, b, **tols)
+            except Exception as e:
+                logging.debug(f"{tensor_name} is not close.\n{e}")
+                raise e
         else:
             try:
                 torch.testing.assert_close(a, b, **tols)
             except Exception as e:
-                logging.debug(e)
+                logging.debug(f"{tensor_name} is not close.\n{e}")

            rmse = _rmse(a, b)
            rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
            assert (
                rmse < rmse_tol * rmse_range
-            ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-                rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
+            ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f}) for {}".format(
+                rmse, rmse_tol * rmse_range, rmse_tol, rmse_range, tensor_name
             )

     if qkv_format == "bshd":
-        for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
-            _error(a[:, 0], b[:, 0])
-            _error(a[:, 1], b[:, 1])
+        for tensor_name, a, b in zip(
+            ["out", "dq", "dk", "dv", "dkv_compressed", "dk_pos_emb"],
+            [out_, dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_],
+            [out, dq, dk, dv, dkv_compressed, dk_pos_emb]
+        ):
+            if a is None or b is None:
+                a_is_None = "is" if a is None else "is not"
+                b_is_None = "is" if b is None else "is not"
+                assert a is None and b is None, f"{tensor_name}_ {a_is_None} None and {tensor_name} {b_is_None} None!"
+                continue
+            _error(a[:, 0], b[:, 0], tensor_name)
+            _error(a[:, 1], b[:, 1], tensor_name)
     elif qkv_format == "sbhd":
-        for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
-            _error(a[0], b[0])
-            _error(a[1], b[1])
+        for tensor_name, a, b in zip(
+            ["out", "dq", "dk", "dv", "dkv_compressed", "dk_pos_emb"],
+            [out_, dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_],
+            [out, dq, dk, dv, dkv_compressed, dk_pos_emb]
+        ):
+            if a is None or b is None:
+                assert a is None and b is None, f"{tensor_name} and {tensor_name}_ are not both None!"
+                continue
+            _error(a[0], b[0], tensor_name)
+            _error(a[1], b[1], tensor_name)
     elif qkv_format == "thd":
-        for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
-            _error(a, b)
+        for tensor_name, a, b in zip(
+            ["out", "dq", "dk", "dv", "dkv_compressed", "dk_pos_emb"],
+            [out_, dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_],
+            [out, dq, dk, dv, dkv_compressed, dk_pos_emb]
+        ):
+            if a is None or b is None:
+                assert a is None and b is None, f"{tensor_name} and {tensor_name}_ are not both None!"
+                continue
+            _error(a, b, tensor_name)
     else:
         assert False, f"{qkv_format} is an unsupported qkv_format!"

tests/pytorch/attention/test_attention_with_cp.py

Lines changed: 17 additions & 0 deletions
@@ -127,6 +127,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
         2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
     ), # MLA
     "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
+    "cp_4_0": ModelConfig(
+        2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64,
+        kv_lora_rank=512, qk_pos_emb_head_dim=32
+    ), # MLA
+    "cp_4_1": ModelConfig(
+        2, 4096, 12, 128, head_dim_v=64, kv_lora_rank=512, qk_pos_emb_head_dim=32
+    ), # MLA
+    "cp_4_2": ModelConfig(
+        2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias",
+        head_dim_v=64, kv_lora_rank=512, qk_pos_emb_head_dim=32
+    ), # MLA
+    "cp_4_3": ModelConfig(
+        2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64,
+        kv_lora_rank=512, qk_pos_emb_head_dim=32
+    ), # MLA
 }

@@ -181,6 +196,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("Only fp8 works with fp8_mha=True!")
     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
+    if "a2a" in cp_comm_type and config.kv_lora_rank != 0:
+        pytest.skip("MLA CP exchanging latent does not support QKV A2A!")
     if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently does not support FP8 attention!")
     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}

tests/pytorch/utils.py

Lines changed: 4 additions & 0 deletions
@@ -147,6 +147,8 @@ def __init__(
         max_seqlen_kv: int = None,
         num_gqa_groups: int = None,
         head_dim_v: int = None,
+        kv_lora_rank: int = None,
+        qk_pos_emb_head_dim: int = None,
         dropout_p: float = 0.0,
         attn_mask_type: str = "no_mask",
         attn_bias_type: str = "no_bias",
@@ -169,6 +171,8 @@ def __init__(
             self.kv_channels = self.head_dim_qk
         else:
             self.kv_channels = (self.head_dim_qk, self.head_dim_v)
+        self.kv_lora_rank = kv_lora_rank
+        self.qk_pos_emb_head_dim = qk_pos_emb_head_dim
         self.hidden_size = self.num_heads * self.head_dim_qk
         self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
         self.dropout_p = dropout_p
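
With these two fields on ModelConfig, an MLA-latent test case reads like the cp_4_0 entry above. A minimal sketch (the import path is assumed from this commit's test layout; positional arguments follow the cp_* entries):

from tests.pytorch.utils import ModelConfig   # path assumed; matches this commit's test files

config = ModelConfig(
    2, 4096, 12, 128,                 # same positional arguments as the cp_* entries above
    attn_mask_type="causal",
    head_dim_v=64,
    kv_lora_rank=512,                 # width of the compressed KV latent that is exchanged
    qk_pos_emb_head_dim=32,           # width of the positional-embedding part of the key
)
assert config.kv_lora_rank == 512 and config.qk_pos_emb_head_dim == 32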
