Commit 5b049a1

[TOSA] Handle strided AtenSlice lowering (#4372)
Lower step-greater-than-one slices by reshaping to an NxKxC view, gathering rows, then restoring the original shape.
1 parent bbd46d0 commit 5b049a1
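To make the decomposition concrete, here is a small standalone sketch (illustrative only; the shape, dim, and slice parameters are hypothetical, not taken from the commit) that computes the N/K/C view parameters and the surviving element count W the same way the lowering below derives them:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical example: slice dim 1 of a 2x7x3 tensor with
  // start=1, end=7, step=2 (i.e. t[:, 1:7:2, :] in Torch).
  std::vector<int64_t> shape = {2, 7, 3};
  int64_t dim = 1, start = 1, end = 7, step = 2;

  // N collapses every dimension before `dim`, C every dimension after it,
  // and K is the sliced extent itself -- giving the [N, K, C] view.
  int64_t N = 1, C = 1;
  for (int64_t i = 0; i < dim; ++i)
    N *= shape[i];
  for (int64_t i = dim + 1; i < (int64_t)shape.size(); ++i)
    C *= shape[i];
  int64_t K = shape[dim];

  // W = ceil((end - start) / step): how many rows survive the stride.
  int64_t W = end > start ? (end - start + step - 1) / step : 0;

  // Prints: N=2 K=7 C=3 W=3 (the gather indices would be 1, 3, 5).
  std::printf("N=%lld K=%lld C=%lld W=%lld\n", (long long)N, (long long)K,
              (long long)C, (long long)W);
  return 0;
}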

File tree

2 files changed: +120 -41 lines

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 120 additions & 39 deletions
@@ -26,6 +26,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <cmath>
 #include <numeric>
@@ -4168,69 +4169,149 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
     AtenSliceTensorOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {

-  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
+  Value self = adaptor.getSelf();
+  auto selfType = dyn_cast<TensorType>(self.getType());
   if (!selfType || !selfType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         op, "Only tensor types with static shape are supported");

-  // Only statically deducible values are currently supported
   int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
     return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
-
   dim = toPositiveDim(dim, selfType.getRank());
-
   if (!isValidDim(dim, selfType.getRank()))
-    return rewriter.notifyMatchFailure(op, "dim must less than tensor rank");
+    return rewriter.notifyMatchFailure(op, "dim out of range");
+
+  SmallVector<int64_t> inputShape =
+      llvm::to_vector(makeShapeTorchCompatible(selfType.getShape()));
+  const int64_t K = inputShape[dim];

   int64_t start;
   if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
     return rewriter.notifyMatchFailure(op, "start must be a Scalar constant");
-
-  if (start < 0) {
-    start = toPositiveDim(start, selfType.getShape()[dim]);
-    if (!isValidDim(start, selfType.getShape()[dim]))
-      return rewriter.notifyMatchFailure(op, "start is not a valid index");
-  }
-  start = std::min(selfType.getShape()[dim], start);
+  // Torch accepts negative `start`/`end`; translate them to positive indices
+  // in the canonical [0, K] range before clamping.
+  if (start < 0)
+    start = toPositiveDim(start, K);
+  start = std::clamp<int64_t>(start, /*Min=*/0, /*Max=*/K);

   int64_t end;
-  if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
-    if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp()))
-      end = selfType.getShape()[dim];
-    else
-      return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
-  }
-  // support for end < 0
-  end = toPositiveDim(end, selfType.getShape()[dim]);
-  // support for end out of upper bound
-  end = (end > selfType.getShape()[dim] ? selfType.getShape()[dim] : end);
-
-  // FIXME: add support for start < 0 and end < start
-  if (end < start)
-    return rewriter.notifyMatchFailure(op,
-                                       "Currently unsupported: end < start");
+  if (matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
+    if (end == std::numeric_limits<int64_t>::max())
+      end = K;
+  } else if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp())) {
+    end = K;
+  } else {
+    return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
+  }
+  if (end < 0)
+    end = toPositiveDim(end, K);
+  end = std::clamp<int64_t>(end, /*Min=*/0, /*Max=*/K);

   int64_t step;
   if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
     return rewriter.notifyMatchFailure(op, "step must be a Scalar constant");
+  if (step <= 0)
+    return rewriter.notifyMatchFailure(op, "step <= 0 unsupported");

-  if (step != 1)
-    return rewriter.notifyMatchFailure(
-        op, "step value other than 1 is currently unsupported");
+  auto loc = op->getLoc();
+  auto elemTy = selfType.getElementType();

-  SmallVector<int64_t> startSlice(selfType.getRank(), 0);
-  SmallVector<int64_t> sizeSlice =
-      llvm::to_vector(makeShapeTorchCompatible(selfType.getShape()));
+  auto convertedResultTy = dyn_cast_or_null<RankedTensorType>(
+      getTypeConverter()->convertType(op.getType()));
+  if (!convertedResultTy || !convertedResultTy.hasStaticShape())
+    return rewriter.notifyMatchFailure(op,
+                                       "result type must be statically shaped");
+
+  // When the stride is 1 the original tosa.slice lowering is still optimal.
+  if (step == 1) {
+    SmallVector<int64_t> startSlice(selfType.getRank(), 0);
+    SmallVector<int64_t> sizeSlice = inputShape;
+    startSlice[dim] = start;
+    sizeSlice[dim] = std::max<int64_t>(end - start, 0);
+
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(
+        op, convertedResultTy, self,
+        tosa::getTosaConstShape(rewriter, loc, startSlice),
+        tosa::getTosaConstShape(rewriter, loc, sizeSlice));
+    return success();
+  }
+
+  int64_t N = 1, C = 1;
+  for (int64_t i = 0; i < dim; ++i)
+    N *= inputShape[i];
+  for (int64_t i = dim + 1; i < (int64_t)inputShape.size(); ++i)
+    C *= inputShape[i];
+
+  // Stride > 1: rewrite Torch slicing into TOSA as follows:
+  //   1) reshape the tensor to [N, K, C] so the sliced dimension is isolated,
+  //   2) materialize the index vector {start + i*step},
+  //   3) tile indices across the batch dimension and gather the desired rows,
+  //   4) reshape the gathered result back to the original rank.
+  // Number of elements that survive after applying the stride.
+  int64_t W = (end > start) ? ((end - start + step - 1) / step) : 0;
+
+  SmallVector<int64_t> nkcShape = {N, K, C};
+  auto nkcTy = RankedTensorType::get(makeShapeLLVMCompatible(nkcShape), elemTy);
+  // Reshape the input tensor into [N, K, C] so that the sliced dimension
+  // becomes the middle axis (K) and all prefix/suffix dimensions are grouped
+  // into batch (N) and channel (C) components. When the original tensor is
+  // already three-dimensional with this layout, reuse it directly.
+  Value reshaped = (inputShape.size() == 3 && inputShape[0] == N &&
+                    inputShape[1] == K && inputShape[2] == C)
+                       ? self
+                       : tosa::ReshapeOp::create(
+                             rewriter, loc, nkcTy, self,
+                             tosa::getTosaConstShape(rewriter, loc, nkcShape))
+                             .getResult();
+
+  // Build the 1-D index vector [start, start + step, ...] that encodes the
+  // positions we want to gather from the K dimension.
+  SmallVector<int32_t> idxVals;
+  idxVals.reserve(W);
+  for (int64_t i = 0; i < W; ++i)
+    idxVals.push_back(static_cast<int32_t>(start + i * step));
+
+  auto idx1DTy = RankedTensorType::get({W}, rewriter.getI32Type());
+  auto idxAttr = DenseIntElementsAttr::get(idx1DTy, idxVals);
+  Value idx1D =
+      tosa::ConstOp::create(rewriter, loc, idx1DTy, idxAttr).getResult();
+
+  // Gather expects a 2-D index tensor, so reshape to [1, W] prior to tiling.
+  auto idx1xWTy = RankedTensorType::get({1, W}, rewriter.getI32Type());
+  Value idx1xW =
+      tosa::ReshapeOp::create(
+          rewriter, loc, idx1xWTy, idx1D,
+          tosa::getTosaConstShape(rewriter, loc, SmallVector<int64_t>{1, W}))
+          .getResult();

-  startSlice[dim] = start;
-  sizeSlice[dim] = end - start;
+  // Tile the single row of indices across the batch dimension so every
+  // [batch, channel] slice uses the same sequence.
+  auto tileMul =
+      tosa::getTosaConstShape(rewriter, loc, SmallVector<int64_t>{N, 1});
+  auto idxNWTy = RankedTensorType::get({N, W}, rewriter.getI32Type());
+  Value idxNW =
+      tosa::TileOp::create(rewriter, loc, idxNWTy, idx1xW, tileMul).getResult();
+
+  // Duplicate the 1-D index vector across the batch dimension so that we can
+  // use a single tosa.gather to materialize the strided slice.
+  auto gatherTy = RankedTensorType::get({N, W, C}, elemTy);
+  Value gathered =
+      tosa::GatherOp::create(rewriter, loc, gatherTy, reshaped, idxNW)
+          .getResult();

-  rewriter.replaceOpWithNewOp<tosa::SliceOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
-      tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice),
-      tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice));
+  SmallVector<int64_t> outShape = inputShape;
+  outShape[dim] = W;
+  assert(llvm::equal(convertedResultTy.getShape(), outShape) &&
+         "type converter mismatch for slice result");

+  // Restore the original rank with the newly strided dimension size.
+  Value result =
+      tosa::ReshapeOp::create(rewriter, loc, convertedResultTy, gathered,
+                              tosa::getTosaConstShape(rewriter, loc, outShape))
+          .getResult();
+
+  rewriter.replaceOp(op, result);
   return success();
 }

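For intuition about why the gather-based lowering is equivalent to a strided slice, the following self-contained sketch (a plain-C++ model of steps 1-4 above with the same hypothetical sizes as the earlier example; it does not call MLIR or TOSA) replays the gather on a row-major buffer and checks it against direct strided indexing:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical parameters: a [N, K, C] = [2, 7, 3] view, sliced along K
  // with start=1, end=7, step=2.
  const int64_t N = 2, K = 7, C = 3, start = 1, end = 7, step = 2;
  const int64_t W = (end - start + step - 1) / step;

  // Row-major [N, K, C] input filled with distinct values.
  std::vector<int64_t> in(N * K * C);
  for (size_t i = 0; i < in.size(); ++i)
    in[i] = (int64_t)i;

  // Model of tosa.gather: every batch n picks rows {start + w*step} of
  // length C, mirroring the tiled [N, W] index tensor.
  std::vector<int64_t> gathered;
  for (int64_t n = 0; n < N; ++n)
    for (int64_t w = 0; w < W; ++w)
      for (int64_t c = 0; c < C; ++c)
        gathered.push_back(in[(n * K + (start + w * step)) * C + c]);

  // Reference: a direct strided slice over the K axis of the same view.
  std::vector<int64_t> direct;
  for (int64_t n = 0; n < N; ++n)
    for (int64_t k = start; k < end; k += step)
      for (int64_t c = 0; c < C; ++c)
        direct.push_back(in[(n * K + k) * C + c]);

  assert(gathered == direct && "gather path must equal the strided slice");
  return 0;
}

The final tosa.reshape in the pattern then reinterprets the gathered [N, W, C] buffer at the original rank, with dimension `dim` resized to W.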
projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 2 deletions
@@ -3916,8 +3916,6 @@
     "SliceCopyStartGreaterThanDimSize_Module_basic",
     "SliceEndSleStartModule_basic",
     "SliceOutOfLowerBoundEndIndexModule_basic",
-    "SliceOutOfLowerBoundStartIndexModule_basic",
-    "SliceSizeTwoStepModule_basic",
     "SortIntListReverse_basic",
     "SortIntList_basic",
     "SortTensorDescending_basic",
