@@ -464,30 +464,36 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
464
464
currentOp.getNumLoops (), getAsIndexOpFoldResult (b.getContext (), 0 ));
465
465
SmallVector<unsigned > reductionDims;
466
466
currentOp.getReductionDims (reductionDims);
467
+ bool tileOnReduction = false ;
467
468
for (auto [d, tile] : llvm::zip (currentDim, currentTileSize)) {
469
+ if (llvm::find (reductionDims, d) != reductionDims.end ()) {
470
+ tileOnReduction = true ;
471
+ }
468
472
if (llvm::find (reductionDims, d) != reductionDims.end () &&
469
- !dyn_cast<PartialReductionOpInterface>(currentOp.getOperation ()))
473
+ !dyn_cast<PartialReductionOpInterface>(currentOp.getOperation ())) {
470
474
tileSizes[d] = getAsIndexOpFoldResult (b.getContext (), 0 );
471
- else
475
+ tileOnReduction = false ;
476
+ } else
472
477
tileSizes[d] = getAsIndexOpFoldResult (b.getContext (), tile);
473
478
}
474
479
SmallVector<Range> loopRanges =
475
480
cast<TilingInterface>(currentOp.getOperation ()).getIterationDomain (b);
476
481
OpBuilder::InsertionGuard guard (b);
477
482
b.setInsertionPoint (currentOp);
478
- if (auto partialInterface =
479
- dyn_cast<PartialReductionOpInterface>(currentOp.getOperation ())) {
483
+ if (tileOnReduction) {
484
+ auto partialInterface =
485
+ dyn_cast<PartialReductionOpInterface>(currentOp.getOperation ());
480
486
for (auto [idx, tile] : llvm::enumerate (tileSizes)) {
481
- if (isConstantIntValue (tile, 0 )) {
487
+ if (isConstantIntValue (tile, 0 ) &&
488
+ llvm::find (reductionDims, d) != reductionDims.end ()) {
482
489
tileSizes[idx] = loopRanges[idx].size ;
483
490
}
484
491
}
485
-
486
492
SmallVector<OpFoldResult> newParallelDims;
487
493
for (auto i = 0UL ; i < reductionDims.size (); i++) {
488
494
newParallelDims.push_back (getAsIndexOpFoldResult (b.getContext (), i));
489
495
}
490
- auto tilingResult = linalgX::tileAllUsingForall (
496
+ auto tilingResult = linalgX::tileReductionUsingForall (
491
497
b, cast<PartialReductionOpInterface>(currentOp.getOperation ()), {},
492
498
tileSizes, newParallelDims, std::nullopt);
493
499
if (failed (tilingResult) &&
@@ -503,8 +509,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
503
509
}
504
510
}
505
511
}
506
- } else if ( auto tilingInterface =
507
- cast<TilingInterface>(currentOp.getOperation ())) {
512
+ } else {
513
+ auto tilingInterface = cast<TilingInterface>(currentOp.getOperation ());
508
514
auto tilingResult = linalg::tileToForallOpUsingTileSizes (
509
515
b, tilingInterface, tileSizes, std::nullopt);
510
516
if (failed (tilingResult))
@@ -597,11 +603,15 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
597
603
? (cfg.NBlock - 1 ) / cfg.innerMostNBlock + 1
598
604
: cfg.NBlock ;
599
605
// Outer
600
- option.nestedTileSizes .emplace_back (SmallVector<size_t >{
601
- MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});
602
- option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForallOp);
603
- option.loopDim .emplace_back (
604
- SmallVector<size_t >{MDimPos[0 ], NDimPos[0 ], KDimPos[0 ]});
606
+ for (auto [tile, dim] :
607
+ llvm::zip (SmallVector<size_t >{KParallelBlockSize, MParallelBlockSize,
608
+ NParallelBlockSize},
609
+ SmallVector<size_t >{KDimPos[0 ], MDimPos[0 ], NDimPos[0 ]})) {
610
+ option.nestedTileSizes .emplace_back (SmallVector<size_t >{tile});
611
+ option.loopType .emplace_back (
612
+ OuterLoopGenerationOption::LoopType::ForallOp);
613
+ option.loopDim .emplace_back (SmallVector<size_t >{dim});
614
+ }
605
615
// Middle
606
616
for (auto [tile, dim] :
607
617
llvm::zip (SmallVector<size_t >{MOuterBlockSize, NOuterBlockSize,
0 commit comments