@@ -228,14 +228,21 @@ def run_dpa_with_cp(
         kv_up_proj = None
         kv_compressed, k_pos_emb = None, None
     else:
-        kv_compressed = torch.randn(kv_compressed_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
+        kv_compressed = (
+            torch.randn(kv_compressed_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
+        )
         k_pos_emb = torch.randn(k_pos_emb_shape, dtype=dtypes[dtype]).cuda().requires_grad_(True)
         head_dim_k_no_pe = config.head_dim_qk - config.qk_pos_emb_head_dim
-        linear = torch.nn.Linear(
-            config.kv_lora_rank,
-            config.num_heads * (head_dim_k_no_pe + config.head_dim_v),
-            bias=False
-        ).cuda().to(dtypes[dtype])
+        linear = (
+            torch.nn.Linear(
+                config.kv_lora_rank,
+                config.num_heads * (head_dim_k_no_pe + config.head_dim_v),
+                bias=False,
+            )
+            .cuda()
+            .to(dtypes[dtype])
+        )
+
         def kv_up_proj(kv_compressed, k_pos_emb):
             kv = linear(kv_compressed).view(*kv_compressed.shape[:-1], config.num_heads, -1)
             k_no_pe, v = torch.split(kv, [head_dim_k_no_pe, config.head_dim_v], dim=-1)
@@ -249,6 +256,7 @@ def kv_up_proj(kv_compressed, k_pos_emb):
                 k_pos_emb = k_pos_emb.expand(-1, config.num_heads, -1)
             k = torch.cat([k_no_pe, k_pos_emb], dim=-1)
             return k, v
+
         k, v = kv_up_proj(kv_compressed, k_pos_emb)
     dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
     dout_quantizer = Float8Quantizer(
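
As a side note on what the reformatted block above computes: `kv_up_proj` is a compressed-KV (MLA-style) up-projection. Below is a minimal standalone sketch with hypothetical sizes (the test takes them from `config`), showing how the `kv_lora_rank`-wide latent expands into per-head no-PE key and value slices, with the shared positional-embedding key broadcast across heads and concatenated onto k:

import torch

# Hypothetical sizes; not the test's actual configuration.
num_heads, kv_lora_rank = 16, 512
head_dim_qk, qk_pos_emb_head_dim, head_dim_v = 192, 64, 128
head_dim_k_no_pe = head_dim_qk - qk_pos_emb_head_dim  # key width without the PE part

linear = torch.nn.Linear(kv_lora_rank, num_heads * (head_dim_k_no_pe + head_dim_v), bias=False)
kv_compressed = torch.randn(2, 8, kv_lora_rank)        # [batch, seq, latent rank]
k_pos_emb = torch.randn(2, 8, 1, qk_pos_emb_head_dim)  # one PE key shared by all heads

kv = linear(kv_compressed).view(*kv_compressed.shape[:-1], num_heads, -1)
k_no_pe, v = torch.split(kv, [head_dim_k_no_pe, head_dim_v], dim=-1)
k = torch.cat([k_no_pe, k_pos_emb.expand(-1, -1, num_heads, -1)], dim=-1)
assert k.shape == (2, 8, num_heads, head_dim_qk)
assert v.shape == (2, 8, num_heads, head_dim_v)
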
@@ -289,9 +297,7 @@ def kv_up_proj(kv_compressed, k_pos_emb):
             out.backward(dout)
 
     # run core_attn with CP
-    q_, dout_, *rest = [
-        x.clone().detach() for x in [q, dout] + ([] if bias is None else [bias])
-    ]
+    q_, dout_, *rest = [x.clone().detach() for x in [q, dout] + ([] if bias is None else [bias])]
     if config.kv_lora_rank is None:
         k_ = k.clone().detach()
         v_ = v.clone().detach()
@@ -306,12 +312,16 @@ def kv_up_proj(kv_compressed, k_pos_emb):
     if qkv_format == "bshd" or qkv_format == "sbhd":
         seq_dim = qkv_format.index("s")
         q_, k_, v_, kv_compressed_, k_pos_emb_, dout_ = [
-            x.view(
-                *x.shape[:seq_dim],
-                2 * world_size,
-                x.shape[seq_dim] // (2 * world_size),
-                *x.shape[(seq_dim + 1) :],
-            ) if x is not None else None
+            (
+                x.view(
+                    *x.shape[:seq_dim],
+                    2 * world_size,
+                    x.shape[seq_dim] // (2 * world_size),
+                    *x.shape[(seq_dim + 1) :],
+                )
+                if x is not None
+                else None
+            )
             for x in [q_, k_, v_, kv_compressed_, k_pos_emb_, dout_]
         ]
         seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
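
For context on the reshape this hunk reflows: the test uses the load-balanced context-parallel sharding scheme, in which the sequence is cut into 2 * world_size chunks and rank r keeps chunks r and 2 * world_size - 1 - r, pairing an early (cheap under causal masking) chunk with a late (expensive) one. A small sketch on a toy sequence:

import torch

world_size, seqlen = 4, 16
x = torch.arange(seqlen)  # stand-in for a tensor whose seq dimension has length `seqlen`
chunks = x.view(2 * world_size, seqlen // (2 * world_size))
for rank in range(world_size):
    seq_idx = torch.tensor([rank, 2 * world_size - rank - 1])
    print(rank, chunks.index_select(0, seq_idx).flatten().tolist())
# rank 0 -> [0, 1, 14, 15]; rank 3 -> [6, 7, 8, 9]
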
@@ -392,35 +402,39 @@ def kv_up_proj(kv_compressed, k_pos_emb):
 
     for x in [out_, q_.grad] + (
         [k_.grad, v_.grad]
-        if config.kv_lora_rank is None else
-        [kv_compressed_.grad, k_pos_emb_.grad]
+        if config.kv_lora_rank is None
+        else [kv_compressed_.grad, k_pos_emb_.grad]
     ):
         assert torch.all(~torch.isnan(x))
         assert torch.all(~torch.isinf(x))
 
     # compare results with and without CP
     if qkv_format == "bshd" or qkv_format == "sbhd":
         dq, dk, dv, dkv_compressed, dk_pos_emb = [
-            x.grad if x is not None else None
-            for x in [q, k, v, kv_compressed, k_pos_emb]
+            x.grad if x is not None else None for x in [q, k, v, kv_compressed, k_pos_emb]
         ]
         dq, dk, dv, dkv_compressed, dk_pos_emb, out = [
-            x.view(
-                *x.shape[:seq_dim],
-                2 * world_size,
-                x.shape[seq_dim] // (2 * world_size),
-                *x.shape[(seq_dim + 1) :],
-            ).index_select(seq_dim, seq_idx)
-            if x is not None else None
+            (
+                x.view(
+                    *x.shape[:seq_dim],
+                    2 * world_size,
+                    x.shape[seq_dim] // (2 * world_size),
+                    *x.shape[(seq_dim + 1) :],
+                ).index_select(seq_dim, seq_idx)
+                if x is not None
+                else None
+            )
             for x in [dq, dk, dv, dkv_compressed, dk_pos_emb, out]
         ]
         dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_ = [
-            x.grad if x is not None else None
-            for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
+            x.grad if x is not None else None for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
         ]
         dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_, out_ = [
-            x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
-            if x is not None else None
+            (
+                x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
+                if x is not None
+                else None
+            )
             for x in [dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_, out_]
         ]
     elif qkv_format == "thd":
@@ -431,11 +445,11 @@ def kv_up_proj(kv_compressed, k_pos_emb):
         else:
             dk, dv = None, None
             dkv_compressed, dk_pos_emb = [
-                x.index_select(0, seq_idx_kv).contiguous() for x in [kv_compressed.grad, k_pos_emb.grad]
+                x.index_select(0, seq_idx_kv).contiguous()
+                for x in [kv_compressed.grad, k_pos_emb.grad]
             ]
         dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_ = [
-            x.grad if x is not None else None
-            for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
+            x.grad if x is not None else None for x in [q_, k_, v_, kv_compressed_, k_pos_emb_]
         ]
         cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
@@ -519,34 +533,40 @@ def _error(a, b, tensor_name):
         for tensor_name, a, b in zip(
             ["out", "dq", "dk", "dv", "dkv_compressed", "dk_pos_emb"],
             [out_, dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_],
-            [out, dq, dk, dv, dkv_compressed, dk_pos_emb]
+            [out, dq, dk, dv, dkv_compressed, dk_pos_emb],
         ):
             if a is None or b is None:
                 a_is_None = "is" if a is None else "is not"
                 b_is_None = "is" if b is None else "is not"
-                assert a is None and b is None, f"{tensor_name}_ {a_is_None} None and {tensor_name} {b_is_None} None!"
+                assert (
+                    a is None and b is None
+                ), f"{tensor_name}_ {a_is_None} None and {tensor_name} {b_is_None} None!"
                 continue
             _error(a[:, 0], b[:, 0], tensor_name)
             _error(a[:, 1], b[:, 1], tensor_name)
     elif qkv_format == "sbhd":
         for tensor_name, a, b in zip(
             ["out", "dq", "dk", "dv", "dkv_compressed", "dk_pos_emb"],
             [out_, dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_],
-            [out, dq, dk, dv, dkv_compressed, dk_pos_emb]
+            [out, dq, dk, dv, dkv_compressed, dk_pos_emb],
        ):
             if a is None or b is None:
-                assert a is None and b is None, f"{tensor_name} and {tensor_name}_ are not both None!"
+                assert (
+                    a is None and b is None
+                ), f"{tensor_name} and {tensor_name}_ are not both None!"
                 continue
             _error(a[0], b[0], tensor_name)
             _error(a[1], b[1], tensor_name)
     elif qkv_format == "thd":
         for tensor_name, a, b in zip(
             ["out", "dq", "dk", "dv", "dkv_compressed", "dk_pos_emb"],
             [out_, dq_, dk_, dv_, dkv_compressed_, dk_pos_emb_],
-            [out, dq, dk, dv, dkv_compressed, dk_pos_emb]
+            [out, dq, dk, dv, dkv_compressed, dk_pos_emb],
         ):
             if a is None or b is None:
-                assert a is None and b is None, f"{tensor_name} and {tensor_name}_ are not both None!"
+                assert (
+                    a is None and b is None
+                ), f"{tensor_name} and {tensor_name}_ are not both None!"
                 continue
             _error(a, b, tensor_name)
     else:
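
A note on why the chunk indexing in these comparison loops differs by layout (my reading of the reshapes earlier in the test, not anything this diff changes): after the `view`s above, the two CP chunks of a "bshd" tensor sit just behind the batch dimension, for "sbhd" they lead, and "thd" tensors stay flat. A rough illustration with hypothetical sizes:

import torch

b, s, h, d = 2, 16, 4, 8  # s = per-rank sequence length (hypothetical)
bshd = torch.randn(b, s, h, d).view(b, 2, s // 2, h, d)  # chunk axis after batch
sbhd = torch.randn(s, b, h, d).view(2, s // 2, b, h, d)  # chunk axis first
assert bshd[:, 0].shape == (b, s // 2, h, d)  # compared via _error(a[:, 0], ...)
assert sbhd[0].shape == (s // 2, b, h, d)     # compared via _error(a[0], ...)
# "thd" packs tokens flat, so _error(a, b, ...) compares whole tensors.
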