@@ -272,7 +272,7 @@ static Operation *findParentFillOp(Value val) {
272272 !isa<linalg::FillOp>(currentOp)) {
273273 currentOp = currentOp->getResult (0 ).getDefiningOp ();
274274 }
275- if (isa<linalg::FillOp>(currentOp)) {
275+ if (currentOp && isa<linalg::FillOp>(currentOp)) {
276276 return currentOp;
277277 }
278278
@@ -322,11 +322,10 @@ static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos,
322322 return linalgOp.getShape (linalgOp.getDpsInputOperand (operandIdx))[dimPos];
323323}
324324
325- static LogicalResult setStaticSizeForExtractSliceOp (RewriterBase &rewriter,
326- Operation *op,
327- bool isExtract,
328- SmallVector<int64_t > size,
329- int shrinDimNum = 0 ) {
325+ static void setStaticSizeForExtractSliceOp (RewriterBase &rewriter,
326+ Operation *op, bool isExtract,
327+ SmallVector<int64_t > size,
328+ int shrinDimNum = 0 ) {
330329 OpBuilder::InsertionGuard guard (rewriter);
331330 rewriter.setInsertionPoint (op);
332331 if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
@@ -348,15 +347,12 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
348347 extractSlice, extractSlice.getSource (), mixedOffsets, mixedSizes,
349348 mixedStrides);
350349 }
351- } else {
352- return failure ();
353350 }
354- return mlir::success ();
355351}
356352
357- static LogicalResult setStaticSizeForInsertSliceOp (RewriterBase &rewriter,
358- Operation *op, Value source,
359- SmallVector<int64_t > size) {
353+ static void setStaticSizeForInsertSliceOp (RewriterBase &rewriter, Operation *op ,
354+ Value source,
355+ SmallVector<int64_t > size) {
360356 OpBuilder::InsertionGuard guard (rewriter);
361357 rewriter.setInsertionPoint (op);
362358 if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
@@ -369,10 +365,7 @@ static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
369365 rewriter.replaceOpWithNewOp <tensor::InsertSliceOp>(
370366 insertSlice, source, insertSlice.getDest (), mixedOffsets, mixedSizes,
371367 mixedStrides);
372- } else {
373- return failure ();
374368 }
375- return success ();
376369}
377370
378371using InnermostFullResultCallBackFn = std::function<FailureOr<linalg::LinalgOp>(
@@ -691,7 +684,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
691684 linalg::LinalgOp originOp,
692685 linalg::LinalgOp currentOp,
693686 innerBodyGenerationOption &option) const {
694-
695687 mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc ()};
696688 auto operandDimTypes = getOprandDimType (originOp);
697689 auto cfg = MatmulConfigAnalysis (originOp.getOperation ()).getConfig ();
@@ -744,6 +736,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
744736 CInnermostDims =
745737 SmallVector<int64_t >{cfg.innerMostMBlock , cfg.innerMostNBlock };
746738 }
739+
747740 if (NDimNum > 1 ) {
748741 firstN = true ;
749742 firstK = true ;
@@ -780,21 +773,17 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
780773
781774 // update the extractSlice to static size, replace it with
782775 // useBlockedLayout when
783- if (failed (setStaticSizeForExtractSliceOp (
784- rewriter, currentOp.getDpsInits ()[0 ].getDefiningOp (), true ,
785- CInnermostDims, MDimNum > 1 ? 2 : 0 )) ||
786- failed (setStaticSizeForExtractSliceOp (
787- rewriter, currentOp.getDpsInputs ()[1 ].getDefiningOp (), true ,
788- BInnermostDims, NDimNum > 1 )) ||
789- failed (setStaticSizeForExtractSliceOp (
790- rewriter, currentOp.getDpsInputs ()[0 ].getDefiningOp (), true ,
791- AInnermostDims, MDimNum > 1 )) ||
792- (currentOp.getDpsInits ().size () > 1 &&
793- failed (setStaticSizeForExtractSliceOp (
794- rewriter, currentOp.getDpsInits ()[1 ].getDefiningOp (), true ,
795- CInnermostDims, MDimNum > 1 ? 2 : 0 )))) {
796- return failure ();
776+ setStaticSizeForExtractSliceOp (rewriter,
777+ currentOp.getDpsInputs ()[1 ].getDefiningOp (),
778+ true , BInnermostDims, NDimNum > 1 );
779+ setStaticSizeForExtractSliceOp (rewriter,
780+ currentOp.getDpsInputs ()[0 ].getDefiningOp (),
781+ true , AInnermostDims, MDimNum > 1 );
782+ for (auto init : currentOp.getDpsInits ()) {
783+ setStaticSizeForExtractSliceOp (rewriter, init.getDefiningOp (), true ,
784+ CInnermostDims, MDimNum > 1 ? 2 : 0 );
797785 }
786+
798787 // View the tensor to brgemm required format
799788 Value dataOprand = tensorViewRankedTensor (
800789 rewriter,
@@ -841,10 +830,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
841830
842831 // Insert the result back to the original tensor
843832 for (Operation *user : currentOp->getResult (0 ).getUsers ()) {
844- if (failed (setStaticSizeForInsertSliceOp (rewriter, user, result,
845- CInnermostDims))) {
846- return failure ();
847- }
833+ setStaticSizeForInsertSliceOp (rewriter, user, result, CInnermostDims);
848834 }
849835
850836 if (option.needLowPrecisionCast ) {
@@ -869,10 +855,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
869855 auto ifOp = eb.getLastOperaion ();
870856 // set static size for the insertSliceOp of copyOp
871857 for (Operation *user : currentOp->getResult (1 ).getUsers ()) {
872- if (failed (setStaticSizeForInsertSliceOp (
873- rewriter, user, ifOp->getResult (0 ), CInnermostDims))) {
874- return failure ();
875- }
858+ setStaticSizeForInsertSliceOp (rewriter, user, ifOp->getResult (0 ),
859+ CInnermostDims);
876860 }
877861 rewriter.replaceOp (currentOp, {matmul->getResult (0 ), ifOp->getResult (0 )});
878862 } else {
@@ -885,7 +869,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
885869 if (cfg.KThreads <= 1 ) {
886870 // if use k slicing, the fill op is still need to be kept for the reduce
887871 // init
888- rewriter.replaceOp (fillOp, fillOp.getDpsInits ()[0 ]);
872+ rewriter.replaceUsesWithIf (fillOp.getResult (0 ), fillOp.getDpsInits ()[0 ],
873+ [&](OpOperand &operand) {
874+ return isa<LoopLikeOpInterface>(
875+ operand.getOwner ());
876+ });
889877 }
890878
891879 rewriter.setInsertionPointAfter (currentOp);
@@ -954,8 +942,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
954942 }
955943
956944 // Step 2. Outer loop generation
957- auto outerLoopResult = outerLoopGeneration (rewriter, linalgOp, cfg,
958- isa<linalg::FillOp>(fillOp));
945+ auto outerLoopResult = outerLoopGeneration (
946+ rewriter, linalgOp, cfg, fillOp && isa<linalg::FillOp>(fillOp));
959947 if (failed (outerLoopResult)) {
960948 return failure ();
961949 }
0 commit comments