@@ -151,10 +151,10 @@ static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) {
151151// Check if the linalgOp need to be legalized to f32 accumulation type
152152static bool needToLegalizeDtype (linalg::LinalgOp linalgOp) {
153153 mlir::Type dataType =
154- dyn_cast<mlir::RankedTensorType >(linalgOp.getDpsInputs ()[0 ].getType ())
154+ dyn_cast<mlir::ShapedType >(linalgOp.getDpsInputs ()[0 ].getType ())
155155 .getElementType ();
156156 mlir::Type resultType =
157- dyn_cast<mlir::RankedTensorType >(linalgOp.getDpsInits ()[0 ].getType ())
157+ dyn_cast<mlir::ShapedType >(linalgOp.getDpsInits ()[0 ].getType ())
158158 .getElementType ();
159159 return (dataType.isBF16 () || dataType.isF16 ()) && dataType == resultType;
160160}
@@ -372,7 +372,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
372372 linalg::LinalgOp currentOp = linalgOp;
373373
374374 bool hasFullResult = !option.isPartialResult ;
375- for (auto [i, loopType] : llvm::enumerate (loopType)) {
375+ for (auto && [i, loopType] : llvm::enumerate (loopType)) {
376376 ArrayRef<size_t > currentDim = loopDim[i];
377377 ArrayRef<size_t > currentTileSize = nestedTileSizes[i];
378378 if (loopType == OuterLoopGenerationOption::LoopType::ForOp) {
@@ -420,7 +420,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
420420 cast<TilingInterface>(currentOp.getOperation ()).getIterationDomain (b);
421421 currentOp.getReductionDims (reductionDims);
422422 bool tileOnReduction = false ;
423- for (auto [d, tile] : llvm::zip (currentDim, currentTileSize)) {
423+ for (auto && [d, tile] : llvm::zip (currentDim, currentTileSize)) {
424424 if (llvm::find (reductionDims, d) != reductionDims.end () && tile != 0 &&
425425 (!getConstantIntValue (loopRanges[d].size ) ||
426426 tile != static_cast <size_t >(
@@ -438,22 +438,23 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
438438 OpBuilder::InsertionGuard guard (b);
439439 b.setInsertionPoint (currentOp);
440440 if (tileOnReduction) {
441- for (auto [idx, tile] : llvm::enumerate (tileSizes)) {
441+ for (auto && [idx, tile] : llvm::enumerate (tileSizes)) {
442442 if (isConstantIntValue (tile, 0 ) &&
443443 llvm::find (reductionDims, idx) != reductionDims.end ()) {
444444 tileSizes[idx] = loopRanges[idx].size ;
445445 }
446446 }
447447 SmallVector<OpFoldResult> newParallelDims;
448- for (size_t i = 0UL ; i < reductionDims.size (); i++) {
449- newParallelDims.push_back (getAsIndexOpFoldResult (b.getContext (), i));
448+ for (auto iter : llvm::enumerate (reductionDims)) {
449+ newParallelDims.push_back (
450+ getAsIndexOpFoldResult (b.getContext (), iter.index ()));
450451 }
451452 FailureOr<linalg::ForallReductionTilingResult> tilingResult =
452453 linalgX::tileReductionUsingForall (
453454 b, cast<PartialReductionOpInterface>(currentOp.getOperation ()),
454455 {}, tileSizes, newParallelDims, std::nullopt );
455456 if (failed (tilingResult) &&
456- tilingResult->parallelTiledOps . size () == 1UL )
457+ llvm::hasSingleElement ( tilingResult->parallelTiledOps ) )
457458 return failure ();
458459 currentOp =
459460 dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOps .back ());
@@ -585,7 +586,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
585586 : cfg.NBlock ;
586587
587588 // Outer loop tile size
588- for (auto [tile, dim] :
589+ for (auto && [tile, dim] :
589590 llvm::zip (SmallVector<size_t >{KParallelBlockSize, MParallelBlockSize,
590591 NParallelBlockSize},
591592 SmallVector<size_t >{KDimPos[0 ], MDimPos[0 ], NDimPos[0 ]})) {
@@ -596,27 +597,27 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
596597 }
597598
598599 // Middle loop tile size
599- for (auto [tile, dim] :
600+ for (auto && [tile, dim] :
600601 llvm::zip (SmallVector<size_t >{MOuterBlockSize, NOuterBlockSize,
601602 KOuterBlockSize},
602603 SmallVector<size_t >{MDimPos[0 ], NDimPos[0 ], KDimPos[0 ]})) {
603604 option.nestedTileSizes .emplace_back (SmallVector<size_t >{tile});
604605 option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
605606 option.loopDim .emplace_back (SmallVector<size_t >{dim});
606607 }
607- if (KDimPos. size () == 1 ) {
608+ if (llvm::hasSingleElement (KDimPos) ) {
608609 option.nestedTileSizes .emplace_back (SmallVector<size_t >{cfg.KBlock });
609610 option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
610611 option.loopDim .emplace_back (SmallVector<size_t >{KDimPos.back ()});
611612 }
612613 // Inner loop tile size
613- if (MDimPos. size () == 1 ) {
614+ if (llvm::hasSingleElement (MDimPos) ) {
614615 option.nestedTileSizes .emplace_back (
615616 SmallVector<size_t >{cfg.innerMostMBlock });
616617 option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
617618 option.loopDim .emplace_back (SmallVector<size_t >{MDimPos.back ()});
618619 }
619- if (NDimPos. size () == 1 ) {
620+ if (llvm::hasSingleElement (NDimPos) ) {
620621 option.nestedTileSizes .emplace_back (
621622 SmallVector<size_t >{cfg.innerMostNBlock });
622623 option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
@@ -656,7 +657,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
656657 const linalg::ForallReductionTilingResult &result)
657658 -> FailureOr<linalg::LinalgOp> {
658659 ArrayRef<Value> initValue = result.initialValues ;
659- if (initValue. size () == 1 &&
660+ if (llvm::hasSingleElement (initValue) &&
660661 isa<linalg::FillOp>(initValue[0 ].getDefiningOp ())) {
661662 rewriter.replaceOp (initValue[0 ].getDefiningOp (),
662663 dyn_cast<DestinationStyleOpInterface>(
@@ -706,7 +707,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
706707 SmallVector<int64_t > AInnermostDims, BInnermostDims, CInnermostDims;
707708 bool firstM = true , firstK = true , firstN = true ;
708709 if (MDimNum > 1 ) {
709- for (auto [idx, iter] : llvm::enumerate ((*operandDimTypes)[0 ])) {
710+ for (auto && [idx, iter] : llvm::enumerate ((*operandDimTypes)[0 ])) {
710711 if (iter == DimType::M && firstM) {
711712 AInnermostDims.push_back (1 );
712713 firstM = false ;
@@ -721,7 +722,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
721722 }
722723 firstM = true ;
723724 firstN = true ;
724- for (auto [idx, iter] : llvm::enumerate ((*operandDimTypes)[2 ])) {
725+ for (auto && [idx, iter] : llvm::enumerate ((*operandDimTypes)[2 ])) {
725726 if (iter == DimType::M && firstM) {
726727 CInnermostDims.push_back (1 );
727728 firstM = false ;
@@ -745,7 +746,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
745746 if (NDimNum > 1 ) {
746747 firstN = true ;
747748 firstK = true ;
748- for (auto [idx, iter] : llvm::enumerate ((*operandDimTypes)[1 ])) {
749+ for (auto && [idx, iter] : llvm::enumerate ((*operandDimTypes)[1 ])) {
749750 if (iter == DimType::N && firstN) {
750751 BInnermostDims.push_back (1 );
751752 firstN = false ;
@@ -768,13 +769,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
768769 OpBuilder::InsertionGuard guard (rewriter);
769770 rewriter.setInsertionPoint (currentOp);
770771 mlir::Type dataType =
771- dyn_cast<mlir::RankedTensorType >(currentOp.getDpsInputs ()[0 ].getType ())
772+ dyn_cast<mlir::ShapedType >(currentOp.getDpsInputs ()[0 ].getType ())
772773 .getElementType ();
773774 mlir::Type weightType =
774- dyn_cast<mlir::RankedTensorType >(currentOp.getDpsInputs ()[1 ].getType ())
775+ dyn_cast<mlir::ShapedType >(currentOp.getDpsInputs ()[1 ].getType ())
775776 .getElementType ();
776777 mlir::Type resultType =
777- dyn_cast<mlir::RankedTensorType >(currentOp.getDpsInits ()[0 ].getType ())
778+ dyn_cast<mlir::ShapedType >(currentOp.getDpsInits ()[0 ].getType ())
778779 .getElementType ();
779780
780781 // update the extractSlice to static size, replace it with
@@ -821,9 +822,8 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
821822 currentOp.getDpsInits ()[0 ]);
822823 // Create the brgemm op and replace the origin linalg op
823824 linalg::LinalgOp matmul;
824- if (dyn_cast<mlir::RankedTensorType>(weightOprand.getType ())
825- .getShape ()
826- .size () == 3 ) {
825+ if (dyn_cast<mlir::ShapedType>(weightOprand.getType ()).getShape ().size () ==
826+ 3 ) {
827827 matmul = rewriter.create <linalg::BatchReduceMatmulOp>(
828828 loc, resultOprand.getType (), ValueRange{dataOprand, weightOprand},
829829 resultOprand);
@@ -843,7 +843,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
843843 // fuse the low precision cast to the innermost body
844844 rewriter.setInsertionPointAfter (currentOp);
845845 Value cond;
846- for (LoopLikeOpInterface loop : option.KLoopHandles ) {
846+ for (LoopLikeOpInterface & loop : option.KLoopHandles ) {
847847 Value induceVar = turnOpFoldResultIntoValue (
848848 rewriter, loc, *loop.getSingleInductionVar ());
849849 Value upBound = turnOpFoldResultIntoValue (rewriter, loc,
@@ -903,7 +903,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
903903 Value cond;
904904 arith::ConstantIndexOp zeroConst =
905905 rewriter.create <arith::ConstantIndexOp>(loc, 0 );
906- for (LoopLikeOpInterface loop : option.KLoopHandles ) {
906+ for (LoopLikeOpInterface & loop : option.KLoopHandles ) {
907907 Value induceVar = loop.getLoopRegions ().front ()->front ().getArgument (0 );
908908 Value currentCond = rewriter.create <arith::CmpIOp>(
909909 loc, arith::CmpIPredicate::eq, induceVar, zeroConst);