|
26 | 26 | #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" |
27 | 27 | #include "torch-mlir/Dialect/Torch/Utils/Utils.h" |
28 | 28 | #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" |
| 29 | +#include "llvm/ADT/STLExtras.h" |
29 | 30 | #include "llvm/ADT/TypeSwitch.h" |
30 | 31 | #include <cmath> |
31 | 32 | #include <numeric> |
@@ -4168,69 +4169,149 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite( |
4168 | 4169 | AtenSliceTensorOp op, OpAdaptor adaptor, |
4169 | 4170 | ConversionPatternRewriter &rewriter) const { |
4170 | 4171 |
|
4171 | | - auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType()); |
| 4172 | + Value self = adaptor.getSelf(); |
| 4173 | + auto selfType = dyn_cast<TensorType>(self.getType()); |
4172 | 4174 | if (!selfType || !selfType.hasStaticShape()) |
4173 | 4175 | return rewriter.notifyMatchFailure( |
4174 | 4176 | op, "Only tensor types with static shape are supported"); |
4175 | 4177 |
|
4176 | | - // Only statically deducible values are currently supported |
4177 | 4178 | int64_t dim; |
4178 | 4179 | if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) |
4179 | 4180 | return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant"); |
4180 | | - |
4181 | 4181 | dim = toPositiveDim(dim, selfType.getRank()); |
4182 | | - |
4183 | 4182 | if (!isValidDim(dim, selfType.getRank())) |
4184 | | - return rewriter.notifyMatchFailure(op, "dim must less than tensor rank"); |
| 4183 | + return rewriter.notifyMatchFailure(op, "dim out of range"); |
| 4184 | + |
| 4185 | + SmallVector<int64_t> inputShape = |
| 4186 | + llvm::to_vector(makeShapeTorchCompatible(selfType.getShape())); |
| 4187 | + const int64_t K = inputShape[dim]; |
4185 | 4188 |
|
4186 | 4189 | int64_t start; |
4187 | 4190 | if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) |
4188 | 4191 | return rewriter.notifyMatchFailure(op, "start must be a Scalar constant"); |
4189 | | - |
4190 | | - if (start < 0) { |
4191 | | - start = toPositiveDim(start, selfType.getShape()[dim]); |
4192 | | - if (!isValidDim(start, selfType.getShape()[dim])) |
4193 | | - return rewriter.notifyMatchFailure(op, "start is not a valid index"); |
4194 | | - } |
4195 | | - start = std::min(selfType.getShape()[dim], start); |
| 4192 | + // Torch accepts negative `start`/`end`; translate them to positive indices in |
| 4193 | + // the canonical [0, K] range before clamping. |
| 4194 | + if (start < 0) |
| 4195 | + start = toPositiveDim(start, K); |
| 4196 | + start = std::clamp<int64_t>(start, /*Min=*/0, /*Max=*/K); |
4196 | 4197 |
|
4197 | 4198 | int64_t end; |
4198 | | - if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { |
4199 | | - if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp())) |
4200 | | - end = selfType.getShape()[dim]; |
4201 | | - else |
4202 | | - return rewriter.notifyMatchFailure(op, "end must be a Scalar constant"); |
4203 | | - } |
4204 | | - // support for end < 0 |
4205 | | - end = toPositiveDim(end, selfType.getShape()[dim]); |
4206 | | - // support for end out of upper bound |
4207 | | - end = (end > selfType.getShape()[dim] ? selfType.getShape()[dim] : end); |
4208 | | - |
4209 | | - // FIXME: add support for start < 0 and end < start |
4210 | | - if (end < start) |
4211 | | - return rewriter.notifyMatchFailure(op, |
4212 | | - "Currently unsupported: end < start"); |
| 4199 | + if (matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { |
| 4200 | + if (end == std::numeric_limits<int64_t>::max()) |
| 4201 | + end = K; |
| 4202 | + } else if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp())) { |
| 4203 | + end = K; |
| 4204 | + } else { |
| 4205 | + return rewriter.notifyMatchFailure(op, "end must be a Scalar constant"); |
| 4206 | + } |
| 4207 | + if (end < 0) |
| 4208 | + end = toPositiveDim(end, K); |
| 4209 | + end = std::clamp<int64_t>(end, /*Min=*/0, /*Max=*/K); |
4213 | 4210 |
|
4214 | 4211 | int64_t step; |
4215 | 4212 | if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) |
4216 | 4213 | return rewriter.notifyMatchFailure(op, "step must be a Scalar constant"); |
| 4214 | + if (step <= 0) |
| 4215 | + return rewriter.notifyMatchFailure(op, "step <= 0 unsupported"); |
4217 | 4216 |
|
4218 | | - if (step != 1) |
4219 | | - return rewriter.notifyMatchFailure( |
4220 | | - op, "step value other than 1 is currently unsupported"); |
| 4217 | + auto loc = op->getLoc(); |
| 4218 | + auto elemTy = selfType.getElementType(); |
4221 | 4219 |
|
4222 | | - SmallVector<int64_t> startSlice(selfType.getRank(), 0); |
4223 | | - SmallVector<int64_t> sizeSlice = |
4224 | | - llvm::to_vector(makeShapeTorchCompatible(selfType.getShape())); |
| 4220 | + auto convertedResultTy = dyn_cast_or_null<RankedTensorType>( |
| 4221 | + getTypeConverter()->convertType(op.getType())); |
| 4222 | + if (!convertedResultTy || !convertedResultTy.hasStaticShape()) |
| 4223 | + return rewriter.notifyMatchFailure(op, |
| 4224 | + "result type must be statically shaped"); |
| 4225 | + |
| 4226 | + // When the stride is 1 the original tosa.slice lowering is still optimal. |
| 4227 | + if (step == 1) { |
| 4228 | + SmallVector<int64_t> startSlice(selfType.getRank(), 0); |
| 4229 | + SmallVector<int64_t> sizeSlice = inputShape; |
| 4230 | + startSlice[dim] = start; |
| 4231 | + sizeSlice[dim] = std::max<int64_t>(end - start, 0); |
| 4232 | + |
| 4233 | + rewriter.replaceOpWithNewOp<tosa::SliceOp>( |
| 4234 | + op, convertedResultTy, self, |
| 4235 | + tosa::getTosaConstShape(rewriter, loc, startSlice), |
| 4236 | + tosa::getTosaConstShape(rewriter, loc, sizeSlice)); |
| 4237 | + return success(); |
| 4238 | + } |
| 4239 | + |
| 4240 | + int64_t N = 1, C = 1; |
| 4241 | + for (int64_t i = 0; i < dim; ++i) |
| 4242 | + N *= inputShape[i]; |
| 4243 | + for (int64_t i = dim + 1; i < (int64_t)inputShape.size(); ++i) |
| 4244 | + C *= inputShape[i]; |
| 4245 | + |
| 4246 | + // Stride > 1: rewrite Torch slicing into TOSA as follows: |
| 4247 | + // 1) reshape the tensor to [N, K, C] so the sliced dimension is isolated, |
| 4248 | + // 2) materialize the index vector {start + i*step}, |
| 4249 | + // 3) tile indices across the batch dimension and gather the desired rows, |
| 4250 | + // 4) reshape the gathered result back to the original rank. |
| 4251 | + // Number of elements that survive after applying the stride. |
| 4252 | + int64_t W = (end > start) ? ((end - start + step - 1) / step) : 0; |
| 4253 | + |
| 4254 | + SmallVector<int64_t> nkcShape = {N, K, C}; |
| 4255 | + auto nkcTy = RankedTensorType::get(makeShapeLLVMCompatible(nkcShape), elemTy); |
| 4256 | + // Reshape the input tensor into [N, K, C] so that the sliced dimension |
| 4257 | + // becomes the middle axis (K) and all prefix/suffix dimensions are grouped |
| 4258 | + // into batch (N) and channel (C) components. When the original tensor is |
| 4259 | + // already three-dimensional with this layout, reuse it directly. |
| 4260 | + Value reshaped = (inputShape.size() == 3 && inputShape[0] == N && |
| 4261 | + inputShape[1] == K && inputShape[2] == C) |
| 4262 | + ? self |
| 4263 | + : tosa::ReshapeOp::create( |
| 4264 | + rewriter, loc, nkcTy, self, |
| 4265 | + tosa::getTosaConstShape(rewriter, loc, nkcShape)) |
| 4266 | + .getResult(); |
| 4267 | + |
| 4268 | + // Build the 1-D index vector [start, start + step, ...] that encodes the |
| 4269 | + // positions we want to gather from the K dimension. |
| 4270 | + SmallVector<int32_t> idxVals; |
| 4271 | + idxVals.reserve(W); |
| 4272 | + for (int64_t i = 0; i < W; ++i) |
| 4273 | + idxVals.push_back(static_cast<int32_t>(start + i * step)); |
| 4274 | + |
| 4275 | + auto idx1DTy = RankedTensorType::get({W}, rewriter.getI32Type()); |
| 4276 | + auto idxAttr = DenseIntElementsAttr::get(idx1DTy, idxVals); |
| 4277 | + Value idx1D = |
| 4278 | + tosa::ConstOp::create(rewriter, loc, idx1DTy, idxAttr).getResult(); |
| 4279 | + |
| 4280 | + // Gather expects a 2-D index tensor, so reshape to [1, W] prior to tiling. |
| 4281 | + auto idx1xWTy = RankedTensorType::get({1, W}, rewriter.getI32Type()); |
| 4282 | + Value idx1xW = |
| 4283 | + tosa::ReshapeOp::create( |
| 4284 | + rewriter, loc, idx1xWTy, idx1D, |
| 4285 | + tosa::getTosaConstShape(rewriter, loc, SmallVector<int64_t>{1, W})) |
| 4286 | + .getResult(); |
4225 | 4287 |
|
4226 | | - startSlice[dim] = start; |
4227 | | - sizeSlice[dim] = end - start; |
| 4288 | + // Tile the single row of indices across the batch dimension so every |
| 4289 | + // [batch, channel] slice uses the same sequence. |
| 4290 | + auto tileMul = |
| 4291 | + tosa::getTosaConstShape(rewriter, loc, SmallVector<int64_t>{N, 1}); |
| 4292 | + auto idxNWTy = RankedTensorType::get({N, W}, rewriter.getI32Type()); |
| 4293 | + Value idxNW = |
| 4294 | + tosa::TileOp::create(rewriter, loc, idxNWTy, idx1xW, tileMul).getResult(); |
| 4295 | + |
| 4296 | + // Duplicate the 1-D index vector across the batch dimension so that we can |
| 4297 | + // use a single tosa.gather to materialize the strided slice. |
| 4298 | + auto gatherTy = RankedTensorType::get({N, W, C}, elemTy); |
| 4299 | + Value gathered = |
| 4300 | + tosa::GatherOp::create(rewriter, loc, gatherTy, reshaped, idxNW) |
| 4301 | + .getResult(); |
4228 | 4302 |
|
4229 | | - rewriter.replaceOpWithNewOp<tosa::SliceOp>( |
4230 | | - op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), |
4231 | | - tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice), |
4232 | | - tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice)); |
| 4303 | + SmallVector<int64_t> outShape = inputShape; |
| 4304 | + outShape[dim] = W; |
| 4305 | + assert(llvm::equal(convertedResultTy.getShape(), outShape) && |
| 4306 | + "type converter mismatch for slice result"); |
4233 | 4307 |
|
| 4308 | + // Restore the original rank with the newly strided dimension size. |
| 4309 | + Value result = |
| 4310 | + tosa::ReshapeOp::create(rewriter, loc, convertedResultTy, gathered, |
| 4311 | + tosa::getTosaConstShape(rewriter, loc, outShape)) |
| 4312 | + .getResult(); |
| 4313 | + |
| 4314 | + rewriter.replaceOp(op, result); |
4234 | 4315 | return success(); |
4235 | 4316 | } |
4236 | 4317 |
|
|
0 commit comments