@@ -118,6 +118,16 @@ def run_dpa_with_cp(
118
118
config .num_gqa_groups ,
119
119
config .head_dim_v ,
120
120
)
121
+ kv_compressed_shape = (
122
+ config .batch_size ,
123
+ config .max_seqlen_kv ,
124
+ config .kv_lora_rank or 0 ,
125
+ )
126
+ k_pos_emb_shape = (
127
+ config .batch_size ,
128
+ config .max_seqlen_kv ,
129
+ config .qk_pos_emb_head_dim or 0 ,
130
+ )
121
131
attn_output_shape = (
122
132
config .batch_size ,
123
133
config .max_seqlen_q ,
@@ -146,6 +156,16 @@ def run_dpa_with_cp(
146
156
config .num_gqa_groups ,
147
157
config .head_dim_v ,
148
158
)
159
+ kv_compressed_shape = (
160
+ config .max_seqlen_kv ,
161
+ config .batch_size ,
162
+ config .kv_lora_rank or 0 ,
163
+ )
164
+ k_pos_emb_shape = (
165
+ config .max_seqlen_kv ,
166
+ config .batch_size ,
167
+ config .qk_pos_emb_head_dim or 0 ,
168
+ )
149
169
attn_output_shape = (
150
170
config .max_seqlen_q ,
151
171
config .batch_size ,
@@ -162,15 +182,23 @@ def run_dpa_with_cp(
162
182
config .head_dim_qk ,
163
183
)
164
184
k_input_shape = (
165
- config .batch_size * config .max_seqlen_q ,
185
+ config .batch_size * config .max_seqlen_kv ,
166
186
config .num_gqa_groups ,
167
187
config .head_dim_qk ,
168
188
)
169
189
v_input_shape = (
170
- config .batch_size * config .max_seqlen_q ,
190
+ config .batch_size * config .max_seqlen_kv ,
171
191
config .num_gqa_groups ,
172
192
config .head_dim_v ,
173
193
)
194
+ kv_compressed_shape = (
195
+ config .batch_size * config .max_seqlen_kv ,
196
+ config .kv_lora_rank or 0 ,
197
+ )
198
+ k_pos_emb_shape = (
199
+ config .batch_size * config .max_seqlen_kv ,
200
+ config .qk_pos_emb_head_dim or 0 ,
201
+ )
174
202
attn_output_shape = (
175
203
config .batch_size * config .max_seqlen_q ,
176
204
config .num_heads * config .head_dim_v ,
@@ -193,9 +221,35 @@ def run_dpa_with_cp(
193
221
else :
194
222
assert False , f"{ qkv_format } is an unsupported qkv_format!"
195
223
196
- q = torch .randn (q_input_shape , dtype = dtypes [dtype ]).cuda ()
197
- k = torch .randn (k_input_shape , dtype = dtypes [dtype ]).cuda ()
198
- v = torch .randn (v_input_shape , dtype = dtypes [dtype ]).cuda ()
224
+ q = torch .randn (q_input_shape , dtype = dtypes [dtype ]).cuda ().requires_grad_ (True )
225
+ if config .kv_lora_rank is None :
226
+ k = torch .randn (k_input_shape , dtype = dtypes [dtype ]).cuda ().requires_grad_ (True )
227
+ v = torch .randn (v_input_shape , dtype = dtypes [dtype ]).cuda ().requires_grad_ (True )
228
+ kv_up_proj = None
229
+ kv_compressed , k_pos_emb = None , None
230
+ else :
231
+ kv_compressed = torch .randn (kv_compressed_shape , dtype = dtypes [dtype ]).cuda ().requires_grad_ (True )
232
+ k_pos_emb = torch .randn (k_pos_emb_shape , dtype = dtypes [dtype ]).cuda ().requires_grad_ (True )
233
+ head_dim_k_no_pe = config .head_dim_qk - config .qk_pos_emb_head_dim
234
+ linear = torch .nn .Linear (
235
+ config .kv_lora_rank ,
236
+ config .num_heads * (head_dim_k_no_pe + config .head_dim_v ),
237
+ bias = False
238
+ ).cuda ().to (dtypes [dtype ])
239
# Up-projection helper: rebuilds per-head K and V from the compressed KV
# tensor and the positional-embedding part of K. Returns the tuple (k, v).
# Reads `linear`, `head_dim_k_no_pe` and `config` from the enclosing scope.
+ def kv_up_proj (kv_compressed , k_pos_emb ):
240
# Apply `linear` and split its output features into (num_heads, per-head dim),
# keeping all leading dimensions of kv_compressed unchanged.
+ kv = linear (kv_compressed ).view (* kv_compressed .shape [:- 1 ], config .num_heads , - 1 )
241
# Per head, the first head_dim_k_no_pe features are K (without positional
# embedding) and the remaining head_dim_v features are V.
+ k_no_pe , v = torch .split (kv , [head_dim_k_no_pe , config .head_dim_v ], dim = - 1 )
242
# Insert a singleton head axis before the last dim so it can be expanded.
+ k_pos_emb = torch .unsqueeze (k_pos_emb , - 2 )
243
# Expand the singleton head axis to num_heads; the three branches differ only
# in the number of leading dims (rank 5, 4 or 3 after the unsqueeze).
+ if k_pos_emb .ndim == 5 :
244
+ k_pos_emb = k_pos_emb .expand (- 1 , - 1 , - 1 , config .num_heads , - 1 )
245
+ elif k_pos_emb .ndim == 4 :
246
+ k_pos_emb = k_pos_emb .expand (- 1 , - 1 , config .num_heads , - 1 )
247
+ else :
248
# Any rank other than 3 at this point is rejected by the assertion.
+ assert k_pos_emb .ndim == 3 , f"{ k_pos_emb .shape = } is not supported!"
249
+ k_pos_emb = k_pos_emb .expand (- 1 , config .num_heads , - 1 )
250
# Full K is the non-positional part followed by the positional part along
# the last (feature) dimension.
+ k = torch .cat ([k_no_pe , k_pos_emb ], dim = - 1 )
251
+ return k , v
252
+ k , v = kv_up_proj (kv_compressed , k_pos_emb )
199
253
dout = torch .randn (attn_output_shape , dtype = dtypes [dtype ]).cuda ()
200
254
dout_quantizer = Float8Quantizer (
201
255
fp8_dtype = tex .DType .kFloat8E5M2 ,
@@ -211,9 +265,6 @@ def run_dpa_with_cp(
211
265
bias = None
212
266
213
267
# run core_attn without CP
214
- for x in [q , k , v ]:
215
- x .requires_grad = True
216
-
217
268
if dtype == "fp8" :
218
269
fp8_context = fp8_autocast (enabled = True , fp8_recipe = fp8_recipe , fp8_group = cp_comm_group )
219
270
else :
@@ -238,38 +289,61 @@ def run_dpa_with_cp(
238
289
out .backward (dout )
239
290
240
291
# run core_attn with CP
241
- q_ , k_ , v_ , dout_ , * rest = [
242
- x .clone ().detach () for x in [q , k , v , dout ] + ([] if bias is None else [bias ])
292
+ q_ , dout_ , * rest = [
293
+ x .clone ().detach () for x in [q , dout ] + ([] if bias is None else [bias ])
243
294
]
295
+ if config .kv_lora_rank is None :
296
+ k_ = k .clone ().detach ()
297
+ v_ = v .clone ().detach ()
298
+ kv_compressed_ = None
299
+ k_pos_emb_ = None
300
+ else :
301
+ k_ = None
302
+ v_ = None
303
+ kv_compressed_ = kv_compressed .clone ().detach ()
304
+ k_pos_emb_ = k_pos_emb .clone ().detach ()
244
305
bias_ = rest [0 ] if len (rest ) else None
245
306
if qkv_format == "bshd" or qkv_format == "sbhd" :
246
307
seq_dim = qkv_format .index ("s" )
247
- q_ , k_ , v_ , dout_ = [
308
+ q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ , dout_ = [
248
309
x .view (
249
310
* x .shape [:seq_dim ],
250
311
2 * world_size ,
251
312
x .shape [seq_dim ] // (2 * world_size ),
252
313
* x .shape [(seq_dim + 1 ) :],
253
- )
254
- for x in [q_ , k_ , v_ , dout_ ]
314
+ ) if x is not None else None
315
+ for x in [q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ , dout_ ]
255
316
]
256
317
seq_idx = torch .tensor ([rank , 2 * world_size - rank - 1 ], device = q_ .device )
257
- q_ , k_ , v_ , dout_ = [x .index_select (seq_dim , seq_idx ) for x in [q_ , k_ , v_ , dout_ ]]
258
- q_ , k_ , v_ , dout_ = [
259
- x .view (* x .shape [:seq_dim ], - 1 , * x .shape [(seq_dim + 2 ) :]) for x in [q_ , k_ , v_ , dout_ ]
318
+ q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ , dout_ = [
319
+ x .index_select (seq_dim , seq_idx ) if x is not None else None
320
+ for x in [q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ , dout_ ]
321
+ ]
322
+ q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ , dout_ = [
323
+ x .view (* x .shape [:seq_dim ], - 1 , * x .shape [(seq_dim + 2 ) :]) if x is not None else None
324
+ for x in [q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ , dout_ ]
260
325
]
261
326
elif qkv_format == "thd" :
262
327
seq_idx_q = tex .thd_get_partitioned_indices (
263
- cu_seqlens_q_padded , q_ . shape [0 ], world_size , rank
328
+ cu_seqlens_q_padded , q_input_shape [0 ], world_size , rank
264
329
)
265
330
seq_idx_kv = tex .thd_get_partitioned_indices (
266
- cu_seqlens_kv_padded , k_ . shape [0 ], world_size , rank
331
+ cu_seqlens_kv_padded , k_input_shape [0 ], world_size , rank
267
332
)
268
333
q_ , dout_ = [x .index_select (0 , seq_idx_q ) for x in [q_ , dout_ ]]
269
- k_ , v_ = [x .index_select (0 , seq_idx_kv ) for x in [k_ , v_ ]]
334
+ if config .kv_lora_rank is None :
335
+ k_ , v_ = [x .index_select (0 , seq_idx_kv ) for x in [k_ , v_ ]]
336
+ else :
337
+ kv_compressed_ , k_pos_emb_ = [
338
+ x .index_select (0 , seq_idx_kv ) for x in [kv_compressed_ , k_pos_emb_ ]
339
+ ]
340
+
270
341
else :
271
342
assert False , f"{ qkv_format } is an unsupported qkv_format!"
272
- q_ , k_ , v_ = [x .requires_grad_ () for x in [q_ , k_ , v_ ]]
343
+ q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ = [
344
+ x .requires_grad_ () if x is not None else None
345
+ for x in [q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ ]
346
+ ]
273
347
if bias_ is not None :
274
348
bias_ = bias_ .view (
275
349
* bias_ .shape [:- 2 ], 2 * world_size , bias_ .shape [- 2 ] // (2 * world_size ), bias_ .shape [- 1 ]
@@ -300,6 +374,9 @@ def run_dpa_with_cp(
300
374
cu_seqlens_kv = cu_seqlens_kv ,
301
375
cu_seqlens_q_padded = cu_seqlens_q_padded ,
302
376
cu_seqlens_kv_padded = cu_seqlens_kv_padded ,
377
+ kv_compressed = kv_compressed_ ,
378
+ k_pos_emb = k_pos_emb_ ,
379
+ kv_up_proj_fn = kv_up_proj ,
303
380
)
304
381
if fp8_mha :
305
382
dout_fp8_ = dout_quantizer (dout_ )
@@ -313,30 +390,53 @@ def run_dpa_with_cp(
313
390
out = out .dequantize ()
314
391
out_ = out_ .dequantize ()
315
392
316
- for x in [out_ , q_ .grad , k_ .grad , v_ .grad ]:
393
+ for x in [out_ , q_ .grad ] + (
394
+ [k_ .grad , v_ .grad ]
395
+ if config .kv_lora_rank is None else
396
+ [kv_compressed_ .grad , k_pos_emb_ .grad ]
397
+ ):
317
398
assert torch .all (~ torch .isnan (x ))
318
399
assert torch .all (~ torch .isinf (x ))
319
400
320
401
# compare results with and without CP
321
402
if qkv_format == "bshd" or qkv_format == "sbhd" :
322
- dq , dk , dv , out = [
403
+ dq , dk , dv , dkv_compressed , dk_pos_emb = [
404
+ x .grad if x is not None else None
405
+ for x in [q , k , v , kv_compressed , k_pos_emb ]
406
+ ]
407
+ dq , dk , dv , dkv_compressed , dk_pos_emb , out = [
323
408
x .view (
324
409
* x .shape [:seq_dim ],
325
410
2 * world_size ,
326
411
x .shape [seq_dim ] // (2 * world_size ),
327
412
* x .shape [(seq_dim + 1 ) :],
328
- )
329
- for x in [q .grad , k .grad , v .grad , out ]
413
+ ).index_select (seq_dim , seq_idx )
414
+ if x is not None else None
415
+ for x in [dq , dk , dv , dkv_compressed , dk_pos_emb , out ]
416
+ ]
417
+ dq_ , dk_ , dv_ , dkv_compressed_ , dk_pos_emb_ = [
418
+ x .grad if x is not None else None
419
+ for x in [q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ ]
330
420
]
331
- dq , dk , dv , out = [x .index_select (seq_dim , seq_idx ) for x in [dq , dk , dv , out ]]
332
- dq_ , dk_ , dv_ , out_ = [
421
+ dq_ , dk_ , dv_ , dkv_compressed_ , dk_pos_emb_ , out_ = [
333
422
x .view (* x .shape [:seq_dim ], 2 , x .shape [seq_dim ] // 2 , * x .shape [(seq_dim + 1 ) :])
334
- for x in [q_ .grad , k_ .grad , v_ .grad , out_ ]
423
+ if x is not None else None
424
+ for x in [dq_ , dk_ , dv_ , dkv_compressed_ , dk_pos_emb_ , out_ ]
335
425
]
336
426
elif qkv_format == "thd" :
337
427
dq , out = [x .index_select (0 , seq_idx_q ).contiguous () for x in [q .grad , out ]]
338
- dk , dv = [x .index_select (0 , seq_idx_kv ).contiguous () for x in [k .grad , v .grad ]]
339
- dq_ , dk_ , dv_ , out_ = [q_ .grad , k_ .grad , v_ .grad , out_ ]
428
+ if config .kv_lora_rank is None :
429
+ dk , dv = [x .index_select (0 , seq_idx_kv ).contiguous () for x in [k .grad , v .grad ]]
430
+ dkv_compressed , dk_pos_emb = None , None
431
+ else :
432
+ dk , dv = None , None
433
+ dkv_compressed , dk_pos_emb = [
434
+ x .index_select (0 , seq_idx_kv ).contiguous () for x in [kv_compressed .grad , k_pos_emb .grad ]
435
+ ]
436
+ dq_ , dk_ , dv_ , dkv_compressed_ , dk_pos_emb_ = [
437
+ x .grad if x is not None else None
438
+ for x in [q_ , k_ , v_ , kv_compressed_ , k_pos_emb_ ]
439
+ ]
340
440
cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
341
441
cu_seqlens_q = get_cu_seqlens_on_cp_rank (
342
442
cu_seqlens_q , cu_seqlens_q_padded , world_size , rank , True , True
@@ -359,7 +459,9 @@ def run_dpa_with_cp(
359
459
)
360
460
cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
361
461
num_pads_kv = cu_pads_kv [1 :] - cu_pads_kv [:- 1 ]
362
- for x in [dk , dv , dk_ , dv_ ]:
462
+ for x in [dk , dv , dk_ , dv_ , dkv_compressed , dk_pos_emb , dkv_compressed_ , dk_pos_emb_ ]:
463
+ if x is None :
464
+ continue
363
465
assert torch .count_nonzero (x [cu_seqlens_kv_padded [- 1 ] :]).item () == 0
364
466
for b in range (config .batch_size ):
365
467
assert (
@@ -392,34 +494,61 @@ def run_dpa_with_cp(
392
494
# Root-mean-square error between `a` and `b`: square root of the mean of the
# squared element-wise difference, returned as a Python scalar via `.item()`.
def _rmse (a , b ):
393
495
return torch .sqrt ((a - b ).square ().mean ()).item ()
394
496
395
# Compare tensors `a` and `b` using `dtype`, `tols` and `rmse_tol` from the
# enclosing scope. The added version takes `tensor_name`, which is used only
# to label the debug log message and the RMSE assertion message.
- def _error (a , b ):
497
+ def _error (a , b , tensor_name ):
396
498
# Non-fp8 path: assert_close with `tols`; in the added version a failure is
# logged at debug level with the tensor name and then re-raised.
if dtype != "fp8" :
397
- torch .testing .assert_close (a , b , ** tols )
499
+ try :
500
+ torch .testing .assert_close (a , b , ** tols )
501
+ except Exception as e :
502
+ logging .debug (f"{ tensor_name } is not close.\n { e } " )
503
+ raise e
398
504
# fp8 path: an assert_close failure is only logged here, not re-raised.
else :
399
505
try :
400
506
torch .testing .assert_close (a , b , ** tols )
401
507
except Exception as e :
402
- logging .debug (e )
508
+ logging .debug (f" { tensor_name } is not close. \n { e } " )
403
509
404
510
# RMSE check: the RMSE must be below rmse_tol times the combined value range
# of `a` and `b` (largest max minus smallest min).
# NOTE(review): indentation is lost in this view, so it cannot be told from
# here whether this check sits inside the `except` handler or after it -
# confirm against the real file.
rmse = _rmse (a , b )
405
511
rmse_range = max (a .max ().item (), b .max ().item ()) - min (a .min ().item (), b .min ().item ())
406
512
assert (
407
513
rmse < rmse_tol * rmse_range
408
- ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})" .format (
409
- rmse , rmse_tol * rmse_range , rmse_tol , rmse_range
514
+ ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f}) for {} " .format (
515
+ rmse , rmse_tol * rmse_range , rmse_tol , rmse_range , tensor_name
410
516
)
411
517
412
518
if qkv_format == "bshd" :
413
- for a , b in zip ([out_ , dq_ , dk_ , dv_ ], [out , dq , dk , dv ]):
414
- _error (a [:, 0 ], b [:, 0 ])
415
- _error (a [:, 1 ], b [:, 1 ])
519
+ for tensor_name , a , b in zip (
520
+ ["out" , "dq" , "dk" , "dv" , "dkv_compressed" , "dk_pos_emb" ],
521
+ [out_ , dq_ , dk_ , dv_ , dkv_compressed_ , dk_pos_emb_ ],
522
+ [out , dq , dk , dv , dkv_compressed , dk_pos_emb ]
523
+ ):
524
+ if a is None or b is None :
525
+ a_is_None = "is" if a is None else "is not"
526
+ b_is_None = "is" if b is None else "is not"
527
+ assert a is None and b is None , f"{ tensor_name } _ { a_is_None } None and { tensor_name } { b_is_None } None!"
528
+ continue
529
+ _error (a [:, 0 ], b [:, 0 ], tensor_name )
530
+ _error (a [:, 1 ], b [:, 1 ], tensor_name )
416
531
elif qkv_format == "sbhd" :
417
- for a , b in zip ([out_ , dq_ , dk_ , dv_ ], [out , dq , dk , dv ]):
418
- _error (a [0 ], b [0 ])
419
- _error (a [1 ], b [1 ])
532
+ for tensor_name , a , b in zip (
533
+ ["out" , "dq" , "dk" , "dv" , "dkv_compressed" , "dk_pos_emb" ],
534
+ [out_ , dq_ , dk_ , dv_ , dkv_compressed_ , dk_pos_emb_ ],
535
+ [out , dq , dk , dv , dkv_compressed , dk_pos_emb ]
536
+ ):
537
+ if a is None or b is None :
538
+ assert a is None and b is None , f"{ tensor_name } and { tensor_name } _ are not both None!"
539
+ continue
540
+ _error (a [0 ], b [0 ], tensor_name )
541
+ _error (a [1 ], b [1 ], tensor_name )
420
542
elif qkv_format == "thd" :
421
- for a , b in zip ([out_ , dq_ , dk_ , dv_ ], [out , dq , dk , dv ]):
422
- _error (a , b )
543
+ for tensor_name , a , b in zip (
544
+ ["out" , "dq" , "dk" , "dv" , "dkv_compressed" , "dk_pos_emb" ],
545
+ [out_ , dq_ , dk_ , dv_ , dkv_compressed_ , dk_pos_emb_ ],
546
+ [out , dq , dk , dv , dkv_compressed , dk_pos_emb ]
547
+ ):
548
+ if a is None or b is None :
549
+ assert a is None and b is None , f"{ tensor_name } and { tensor_name } _ are not both None!"
550
+ continue
551
+ _error (a , b , tensor_name )
423
552
else :
424
553
assert False , f"{ qkv_format } is an unsupported qkv_format!"
425
554
0 commit comments