@@ -331,6 +331,12 @@ struct MHAToFlashAttention
331331 ValueRange{curSumSlice, rescaledPrevSumSlice},
332332 ValueRange{reducedShapeOut})
333333 .getResult (0 );
334+ Value newSumSliceRecip =
335+ rewriter
336+ .create <linalg::ReciprocalOp>(loc, reducedShapeOut.getType (),
337+ ValueRange{newSumSlice},
338+ ValueRange{reducedShapeOut})
339+ .getResult (0 );
334340 SmallVector<int64_t > VShape{cfg.RowBlockSize , headDim};
335341 Value VShapeOut = rewriter.create <tensor::EmptyOp>(loc, VShape, dtype);
336342 Value matmulVOutFilled =
@@ -341,38 +347,40 @@ struct MHAToFlashAttention
341347 ValueRange{PSlice, collapsedVSlice},
342348 ValueRange{matmulVOutFilled})
343349 .getResult (0 );
344- Value expMaxDiffBroadcasted =
350+ Value newSumSliceRecipBroadcasted =
345351 rewriter
346- .create <linalg::BroadcastOp>(loc, expMaxDiff , VShapeOut,
352+ .create <linalg::BroadcastOp>(loc, newSumSliceRecip , VShapeOut,
347353 SmallVector<int64_t >{1 })
348354 .getResults ()[0 ];
349- Value expMaxDiffBroadcastedEps =
355+ Value rescaledPrevSumSliceBroadcasted =
350356 rewriter
351- .create <linalg::GenericOp>(
352- loc, VShapeOut.getType (), ValueRange{expMaxDiffBroadcasted},
353- ValueRange{VShapeOut}, indexingMaps,
354- SmallVector<utils::IteratorType>(2 ,
355- utils::IteratorType::parallel),
356- [&](OpBuilder &nestedBuilder, Location nestedLoc,
357- ValueRange args) {
358- Value eps = nestedBuilder.create <arith::ConstantOp>(
359- loc, nestedBuilder.getFloatAttr (dtype, 1e-9 ));
360- Value added =
361- nestedBuilder.create <arith::AddFOp>(loc, args[0 ], eps);
362- nestedBuilder.create <linalg::YieldOp>(nestedLoc, added);
363- })
357+ .create <linalg::BroadcastOp>(loc, rescaledPrevSumSlice, VShapeOut,
358+ SmallVector<int64_t >{1 })
359+ .getResults ()[0 ];
360+ Value rescaledMatmulV =
361+ rewriter
362+ .create <linalg::MulOp>(
363+ loc, matmulVOutFilled.getType (),
364+ ValueRange{matmulV, newSumSliceRecipBroadcasted},
365+ ValueRange{matmulVOutFilled})
366+ .getResult (0 );
367+ Value sumSliceQuotient =
368+ rewriter
369+ .create <linalg::MulOp>(loc, matmulVOutFilled.getType (),
370+ ValueRange{rescaledPrevSumSliceBroadcasted,
371+ newSumSliceRecipBroadcasted},
372+ ValueRange{matmulVOutFilled})
364373 .getResult (0 );
365374 Value rescaledOSlice =
366375 rewriter
367- .create <linalg::DivOp>(
368- loc, VShapeOut.getType (),
369- ValueRange{prevOSlice, expMaxDiffBroadcastedEps},
370- ValueRange{VShapeOut})
376+ .create <linalg::MulOp>(loc, matmulVOutFilled.getType (),
377+ ValueRange{prevOSlice, sumSliceQuotient},
378+ ValueRange{matmulVOutFilled})
371379 .getResult (0 );
372380 Value newOSlice =
373381 rewriter
374382 .create <linalg::AddOp>(loc, VShapeOut.getType (),
375- ValueRange{rescaledOSlice, matmulV },
383+ ValueRange{rescaledOSlice, rescaledMatmulV },
376384 ValueRange{VShapeOut})
377385 .getResult (0 );
378386 // yield all the results of the innermost loop.
@@ -381,25 +389,7 @@ struct MHAToFlashAttention
381389 // yield rowBlockLoop results
382390 rewriter.setInsertionPointToEnd (rowBlockLoop.getBody ());
383391 auto innermostLoopResults = columnBlockLoop->getResults ();
384- Value OSliceFinal = innermostLoopResults[0 ],
385- sumSliceFinal = innermostLoopResults[2 ];
386- Value sliceShapeOut =
387- rewriter.create <tensor::EmptyOp>(loc, reducedShape, dtype);
388- Value broadcastedSliceShapeOut =
389- rewriter.create <tensor::EmptyOp>(loc, VShape, dtype);
390- Value sumSliceFinalBroadcasted =
391- rewriter
392- .create <linalg::BroadcastOp>(loc, sumSliceFinal,
393- broadcastedSliceShapeOut,
394- SmallVector<int64_t >{1 })
395- .getResults ()[0 ];
396- Value rescaledOSliceFinal =
397- rewriter
398- .create <linalg::DivOp>(
399- loc, broadcastedSliceShapeOut.getType (),
400- ValueRange{OSliceFinal, sumSliceFinalBroadcasted},
401- ValueRange{broadcastedSliceShapeOut})
402- .getResult (0 );
392+ Value OSliceFinal = innermostLoopResults[0 ];
403393 SmallVector<OpFoldResult> outputOffsets;
404394 outputOffsets.push_back (getAsOpFoldResult (ivs[0 ]));
405395 outputOffsets.push_back (getAsOpFoldResult (ivs[1 ]));
@@ -409,8 +399,8 @@ struct MHAToFlashAttention
409399 outputSizes[2 ] = rewriter.getIndexAttr (cfg.RowBlockSize );
410400 outputSizes[3 ] = rewriter.getIndexAttr (headDim);
411401 Value insertedRescaledOSlice = rewriter.create <tensor::InsertSliceOp>(
412- loc, rescaledOSliceFinal , rowBlockLoop.getRegionIterArgs ()[0 ],
413- outputOffsets, outputSizes, strides);
402+ loc, OSliceFinal , rowBlockLoop.getRegionIterArgs ()[0 ], outputOffsets ,
403+ outputSizes, strides);
414404 rewriter.create <scf::YieldOp>(loc, ValueRange{insertedRescaledOSlice});
415405 // Add the scf.yield operations for all the outer loops.
416406 for (auto [outerLoop, innerLoop] :
0 commit comments