Skip to content

Commit

Permalink
[Stablehlo] Add AtenIndexTensor StableHlo support (llvm#2107)
Browse files Browse the repository at this point in the history
* Add AtenIndexTensor StableHlo support

* clean up

* Empty commit, trigger test

* try to debug hanging test

* fix segfault

* fix bad include

---------

Co-authored-by: zhekun.zhang <[email protected]>
  • Loading branch information
zhekunz2 and zhekunz2 authored May 24, 2023
1 parent a426363 commit eb8f56a
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 1 deletion.
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexSelectNegativeDimModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
Expand Down
157 changes: 156 additions & 1 deletion lib/Conversion/TorchToStablehlo/GatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -375,6 +376,159 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
return success();
}

// AtenIndexTensorOp
// Convert AtenIndexTensorOp to stablehlo::GatherOp.
//
// Only statically shaped index tensors and results are handled, and every
// entry of the index list must be a tensor (no `None` entries yet).
//
// Step 1: broadcast all index tensors to one common shape.
// Step 2: when there is more than one index tensor, reshape each broadcasted
//         index to carry an extra trailing dimension of size 1 and
//         concatenate them along that dimension.
// Step 3: create stablehlo::GatherOp from the input tensor and the combined
//         indices.
//
// Example:
//   Input:      [[1, 2, 3],
//                [4, 5, 6],
//                [7, 8, 9]]
//   Indices[0]: [[0, 0, 0],
//                [2, 2, 0]]
//   Indices[1]: [[2],
//                [1]]
// Step 1:
//   Indices[0]: [[0, 0, 0],
//                [2, 2, 0]]
//   Indices[1]: [[2, 2, 2],
//                [1, 1, 1]]
// Step 2:
//   Indices:    [[[0, 2], [0, 2], [0, 2]],
//                [[2, 1], [2, 1], [0, 1]]]
// Step 3:
//   Output:     [[3, 3, 3],
//                [8, 8, 2]]
template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
    AtenIndexTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  Value input = adaptor.getSelf();
  auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
  // Check input is a ranked tensor type.
  if (!inputTensorType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types input are currently supported");
  const int64_t inputRank = inputTensorType.getRank();

  Value indexList = op.getIndices();
  SmallVector<Value> indicesTorchType;
  if (!getListConstructElements(indexList, indicesTorchType))
    return op.emitError(
        "unimplemented: the tensor list is not from list construct");

  // An empty list would make the `indexTensors[0]` access below invalid.
  if (indicesTorchType.empty())
    return rewriter.notifyMatchFailure(op,
                                       "expected at least one index tensor");

  // Every index tensor addresses one leading input dimension, so there cannot
  // be more of them than the input has dimensions.
  const int64_t numIndices = static_cast<int64_t>(indicesTorchType.size());
  if (numIndices > inputRank)
    return rewriter.notifyMatchFailure(
        op, "more index tensors than input dimensions");

  // TODO: add support for none index input.
  // Reject `None` entries before type conversion so they are never handed to
  // the type converter.
  for (Value torchIndex : indicesTorchType) {
    if (torchIndex.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Only list ranked tensor types index are supported");
  }

  auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
                                             indicesTorchType);

  // Step 1: broadcast indices tensors.
  // Find the largest rank among the index tensors and make sure all of them
  // are statically shaped.
  int64_t maxRank = 0;
  for (Value indexTensor : indexTensors) {
    auto indexTensorType = indexTensor.getType().dyn_cast<RankedTensorType>();
    if (!indexTensorType)
      return rewriter.notifyMatchFailure(
          op, "Only list ranked tensor types index are supported");
    for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
      if (size == kUnknownSize)
        return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
    }
    maxRank = std::max(maxRank, indexTensorType.getRank());
  }

  RankedTensorType resultType =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
  SmallVector<int64_t> refinedResultShape =
      makeShapeTorchCompatible(resultType.getShape());
  for (int64_t size : refinedResultShape) {
    if (size == kUnknownSize)
      return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
  }
  // The leading `maxRank` result dimensions are read as the broadcasted index
  // shape; bail out instead of reading past the end of the result shape.
  if (maxRank > static_cast<int64_t>(refinedResultShape.size()))
    return rewriter.notifyMatchFailure(
        op, "index tensor rank exceeds result rank");

  // indicesShape: common broadcast shape of the index tensors.
  // expandShape:  broadcast shape plus a trailing 1 (per-index reshape target).
  // concatShape:  broadcast shape plus a trailing `numIndices` (concat result).
  SmallVector<int64_t> indicesShape(refinedResultShape.begin(),
                                    refinedResultShape.begin() + maxRank);
  SmallVector<int64_t> expandShape(indicesShape);
  SmallVector<int64_t> concatShape(indicesShape);
  if (numIndices > 1) {
    expandShape.push_back(1);
    concatShape.push_back(numIndices);
  }

  Type indexElemTy =
      indexTensors[0].getType().cast<RankedTensorType>().getElementType();
  RankedTensorType bcastIndexType =
      RankedTensorType::get(indicesShape, indexElemTy);
  RankedTensorType reshapeType =
      RankedTensorType::get(expandShape, indexElemTy);
  SmallVector<Value> broadcastedIndices;
  for (Value indexTensor : indexTensors) {
    Value bcastVal =
        hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
    // The trailing unit dimension is only needed when several index tensors
    // have to be concatenated.
    if (numIndices > 1)
      bcastVal =
          rewriter.create<stablehlo::ReshapeOp>(loc, reshapeType, bcastVal);
    broadcastedIndices.push_back(bcastVal);
  }

  // Step 2: concat index tensors along the trailing dimension.
  Value finalIndexTensor = broadcastedIndices[0];
  if (broadcastedIndices.size() > 1) {
    RankedTensorType concatTy = RankedTensorType::get(concatShape, indexElemTy);
    finalIndexTensor = rewriter.create<stablehlo::ConcatenateOp>(
        loc, concatTy, ValueRange(broadcastedIndices), concatShape.size() - 1);
  }

  // Step 3: create stablehlo::GatherOp.
  RankedTensorType finalIndexTy =
      finalIndexTensor.getType().cast<RankedTensorType>();
  int64_t indicesRank = finalIndexTy.getRank();
  // With several index tensors the last dimension of the combined indices is
  // the index vector. With a single one, `indexVecDim` is set to the rank of
  // the indices tensor, i.e. one past its last dimension.
  int64_t indexVecDim = numIndices > 1 ? indicesRank - 1 : indicesRank;

  // Indexed (leading) input dimensions are collapsed away and addressed by
  // the index vector components in order.
  SmallVector<int64_t> collapsedDims;
  SmallVector<int64_t> startIndexMap;
  for (int64_t i = 0; i < numIndices; ++i) {
    collapsedDims.push_back(i);
    startIndexMap.push_back(i);
  }
  // Non-indexed (trailing) input dimensions are placed after the
  // `indexVecDim` leading result dimensions.
  SmallVector<int64_t> offsetDims;
  for (int64_t i = numIndices; i < inputRank; ++i)
    offsetDims.push_back(i - numIndices + indexVecDim);

  auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
      rewriter.getContext(),
      /*offsetDims=*/offsetDims,
      /*collapsedSliceDims=*/collapsedDims,
      /*startIndexMap=*/startIndexMap,
      /*indexVecDim=*/indexVecDim);

  // Slice a single element along each indexed dimension and the full extent
  // along every remaining dimension.
  SmallVector<int64_t> sliceSizes;
  auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
  for (int64_t i = 0; i < inputRank; ++i) {
    if (i < numIndices) {
      sliceSizes.push_back(1);
      continue;
    }
    // A dynamic extent cannot be written into the static slice-size attribute.
    if (inputShape[i] == kUnknownSize)
      return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
    sliceSizes.push_back(inputShape[i]);
  }

  rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
      op, resultType, input, finalIndexTensor, dimsAttr,
      rewriter.getI64TensorAttr(sliceSizes));
  return success();
}

void mlir::torch::torch_to_stablehlo::
populateGatherScatterOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
Expand All @@ -388,5 +542,6 @@ void mlir::torch::torch_to_stablehlo::
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
#undef INSERT_ATENOP_PATTERN
}

0 comments on commit eb8f56a

Please sign in to comment.