@@ -275,19 +275,22 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
275275 SmallVector<OpFoldResult> mixedOffsets = extractSlice.getMixedOffsets ();
276276 SmallVector<OpFoldResult> mixedSizes = extractSlice.getMixedSizes ();
277277 SmallVector<OpFoldResult> mixedStrides = extractSlice.getMixedStrides ();
278+ auto targetTensor = mlir::RankedTensorType::get (
279+ SmallVector<int64_t >(size.begin () + shrinDimNum, size.end ()),
280+ extractSlice.getResult ().getType ().getElementType ());
278281 for (auto &&[i, s] : llvm::enumerate (size))
279282 mixedSizes[i] = getAsIndexOpFoldResult (rewriter.getContext (), s);
280- if (shrinDimNum > 0 )
281- rewriter. replaceOpWithNewOp <tensor::ExtractSliceOp>(
282- extractSlice,
283- mlir::RankedTensorType::get (
284- SmallVector< int64_t >(size. begin () + shrinDimNum, size. end ()),
285- extractSlice. getResult (). getType (). getElementType ()),
286- extractSlice. getSource (), mixedOffsets, mixedSizes, mixedStrides );
287- else
288- rewriter. replaceOpWithNewOp <tensor::ExtractSliceOp>(
289- extractSlice, extractSlice. getSource (), mixedOffsets, mixedSizes,
290- mixedStrides);
283+ Operation *newExtractSliceOp = rewriter. create <tensor::ExtractSliceOp>(
284+ extractSlice-> getLoc (), extractSlice. getSource (), mixedOffsets,
285+ mixedSizes, mixedStrides);
286+ if (shrinDimNum > 0 ) {
287+ rewriter. setInsertionPointAfter (newExtractSliceOp);
288+ Value viewResult = tensorViewRankedTensor (
289+ rewriter, targetTensor, newExtractSliceOp-> getResult ( 0 ) );
290+ rewriter. replaceOp (extractSlice, viewResult);
291+ } else {
292+ rewriter. replaceOp (extractSlice, newExtractSliceOp);
293+ }
291294 }
292295}
293296
@@ -304,9 +307,12 @@ static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op,
304307 SmallVector<OpFoldResult> mixedStrides = insertSlice.getMixedStrides ();
305308 for (auto &&[i, s] : llvm::enumerate (size))
306309 mixedSizes[i] = getAsIndexOpFoldResult (rewriter.getContext (), s);
310+ auto targetTensor = mlir::RankedTensorType::get (
311+ size, insertSlice.getDest ().getType ().getElementType ());
312+ Value viewResult = tensorViewRankedTensor (rewriter, targetTensor, source);
307313 rewriter.replaceOpWithNewOp <tensor::InsertSliceOp>(
308- insertSlice, source , insertSlice.getDest (), mixedOffsets, mixedSizes ,
309- mixedStrides);
314+ insertSlice, viewResult , insertSlice.getDest (), mixedOffsets,
315+ mixedSizes, mixedStrides);
310316 }
311317}
312318
0 commit comments