@@ -275,19 +275,19 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
275275      copy (params.gmem_tiled_copy_q , tQgQ (_,_,_,k_tile), tQrQ);
276276      copy (gmem_tiled_copy_k, tKgK (_,_,_,k_tile), tKrK);
277277      if  constexpr  (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
278-         auto  tCrQ_  = make_fragment_like<half_t >(tCrQ);
279-         convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_ );
280-         auto  tCrK_  = make_fragment_like<half_t >(tCrK);
281-         convert_FP8_to_FP16<ElementK>(tCrK, tCrK_ );
282-         cute::gemm (tiled_mma, accum, tCrQ_, tCrK_ , frag_src);
278+         auto  tCrQ_fp16  = make_fragment_like<half_t >(tCrQ);
279+         convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_fp16 );
280+         auto  tCrK_fp16  = make_fragment_like<half_t >(tCrK);
281+         convert_FP8_to_FP16<ElementK>(tCrK, tCrK_fp16 );
282+         cute::gemm (tiled_mma, accum, tCrQ_fp16, tCrK_fp16 , frag_src);
283283      } else  if  constexpr  (is_fp8_v<ElementQ> && !is_fp8_v<ElementK>) {
284-         auto  tCrQ_  = make_fragment_like<half_t >(tCrQ);
285-         convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_ );
286-         cute::gemm (tiled_mma, accum, tCrQ_  , tCrK, frag_src);
284+         auto  tCrQ_fp16  = make_fragment_like<half_t >(tCrQ);
285+         convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_fp16 );
286+         cute::gemm (tiled_mma, accum, tCrQ_fp16  , tCrK, frag_src);
287287      } else  if  constexpr  (!is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
288-         auto  tCrK_  = make_fragment_like<half_t >(tCrK);
289-         convert_FP8_to_FP16<ElementK>(tCrK, tCrK_ );
290-         cute::gemm (tiled_mma, accum, tCrQ , tCrK_ , frag_src);
288+         auto  tCrK_fp16  = make_fragment_like<half_t >(tCrK);
289+         convert_FP8_to_FP16<ElementK>(tCrK, tCrK_fp16 );
290+         cute::gemm (tiled_mma, accum, tCrQ , tCrK_fp16 , frag_src);
291291      } else  {
292292        cute::gemm (tiled_mma, accum, tCrQ , tCrK, frag_src);
293293      }
@@ -343,9 +343,9 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
343343    for (int  i = 0 ; i< tile_count; i++) {
344344      copy (gmem_tiled_copy_v, tVgV (_,_,_,i), tVrV);
345345      if  constexpr  (is_fp8_v<ElementV>) {
346-         auto  tCrV_  = make_fragment_like<half_t >(tCrV);
347-         convert_FP8_to_FP16<ElementV>(tCrV, tCrV_ );
348-         cute::gemm (tiled_mma, accum (_,_,_,i), tPr, tCrV_ , frag_src (_,_,_,i));
346+         auto  tCrV_fp16  = make_fragment_like<half_t >(tCrV);
347+         convert_FP8_to_FP16<ElementV>(tCrV, tCrV_fp16 );
348+         cute::gemm (tiled_mma, accum (_,_,_,i), tPr, tCrV_fp16 , frag_src (_,_,_,i));
349349      } else  {
350350        cute::gemm (tiled_mma, accum (_,_,_,i), tPr, tCrV, frag_src (_,_,_,i));
351351      }    
0 commit comments