@@ -162,9 +162,12 @@ struct MHAToFlashAttention
162162 rewriter.create <linalg::FillOp>(loc, minusInf, maxSlice).getResult (0 );
163163 Value sumSliceFilled =
164164 rewriter.create <linalg::FillOp>(loc, zero, sumSlice).getResult (0 );
165+ Value collapsedOSliceFilled =
166+ rewriter.create <linalg::FillOp>(loc, zero, collapsedOSlice)
167+ .getResult (0 );
165168 // create the innermost for loop for columnBlock
166169 SmallVector<Value> innermostDestinationTensors{
167- collapsedOSlice , maxSliceFilled, sumSliceFilled};
170+ collapsedOSliceFilled , maxSliceFilled, sumSliceFilled};
168171 auto columnBlockLoop = rewriter.create <scf::ForOp>(
169172 loc,
170173 getValueOrCreateConstantIndexOp (
@@ -241,9 +244,9 @@ struct MHAToFlashAttention
241244 ValueRange args) {
242245 Value constant = nestedBuilder.create <arith::ConstantOp>(
243246 loc, nestedBuilder.getFloatAttr (dtype, rsqrtHead));
244- Value added = nestedBuilder.create <arith::MulFOp>(
247+ Value scaled = nestedBuilder.create <arith::MulFOp>(
245248 loc, args[0 ], constant);
246- nestedBuilder.create <linalg::YieldOp>(nestedLoc, added );
249+ nestedBuilder.create <linalg::YieldOp>(nestedLoc, scaled );
247250 })
248251 .getResult (0 );
249252 Value add = rewriter
@@ -338,22 +341,32 @@ struct MHAToFlashAttention
338341 ValueRange{PSlice, collapsedVSlice},
339342 ValueRange{matmulVOutFilled})
340343 .getResult (0 );
341- Value expMaxDiffRecip =
342- rewriter
343- .create <linalg::ReciprocalOp>(loc, reducedShapeOut.getType (),
344- ValueRange{expMaxDiff},
345- ValueRange{reducedShapeOut})
346- .getResult (0 );
347- Value expMaxDiffRecipBroadcasted =
344+ Value expMaxDiffBroadcasted =
348345 rewriter
349- .create <linalg::BroadcastOp>(loc, expMaxDiffRecip , VShapeOut,
346+ .create <linalg::BroadcastOp>(loc, expMaxDiff , VShapeOut,
350347 SmallVector<int64_t >{1 })
351348 .getResults ()[0 ];
349+ Value expMaxDiffBroadcastedEps =
350+ rewriter
351+ .create <linalg::GenericOp>(
352+ loc, VShapeOut.getType (), ValueRange{expMaxDiffBroadcasted},
353+ ValueRange{VShapeOut}, indexingMaps,
354+ SmallVector<utils::IteratorType>(2 ,
355+ utils::IteratorType::parallel),
356+ [&](OpBuilder &nestedBuilder, Location nestedLoc,
357+ ValueRange args) {
358+ Value eps = nestedBuilder.create <arith::ConstantOp>(
359+ loc, nestedBuilder.getFloatAttr (dtype, 1e-9 ));
360+ Value added =
361+ nestedBuilder.create <arith::AddFOp>(loc, args[0 ], eps);
362+ nestedBuilder.create <linalg::YieldOp>(nestedLoc, added);
363+ })
364+ .getResult (0 );
352365 Value rescaledOSlice =
353366 rewriter
354- .create <linalg::MulOp >(
367+ .create <linalg::DivOp >(
355368 loc, VShapeOut.getType (),
356- ValueRange{prevOSlice, expMaxDiffRecipBroadcasted },
369+ ValueRange{prevOSlice, expMaxDiffBroadcastedEps },
357370 ValueRange{VShapeOut})
358371 .getResult (0 );
359372 Value newOSlice =
@@ -372,25 +385,19 @@ struct MHAToFlashAttention
372385 sumSliceFinal = innermostLoopResults[2 ];
373386 Value sliceShapeOut =
374387 rewriter.create <tensor::EmptyOp>(loc, reducedShape, dtype);
375- Value sumSliceFinalRecip =
376- rewriter
377- .create <linalg::ReciprocalOp>(loc, sliceShapeOut.getType (),
378- ValueRange{sumSliceFinal},
379- ValueRange{sliceShapeOut})
380- .getResult (0 );
381388 Value broadcastedSliceShapeOut =
382389 rewriter.create <tensor::EmptyOp>(loc, VShape, dtype);
383- Value sumSliceFinalRecipBroadcasted =
390+ Value sumSliceFinalBroadcasted =
384391 rewriter
385- .create <linalg::BroadcastOp>(loc, sumSliceFinalRecip ,
392+ .create <linalg::BroadcastOp>(loc, sumSliceFinal ,
386393 broadcastedSliceShapeOut,
387394 SmallVector<int64_t >{1 })
388395 .getResults ()[0 ];
389396 Value rescaledOSliceFinal =
390397 rewriter
391- .create <linalg::MulOp >(
398+ .create <linalg::DivOp >(
392399 loc, broadcastedSliceShapeOut.getType (),
393- ValueRange{sumSliceFinalRecipBroadcasted, OSliceFinal },
400+ ValueRange{OSliceFinal, sumSliceFinalBroadcasted },
394401 ValueRange{broadcastedSliceShapeOut})
395402 .getResult (0 );
396403 SmallVector<OpFoldResult> outputOffsets;
0 commit comments