Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoBump] Merge with 061bbc5e (Dec 19) (140) #530

Merged
merged 9 commits into from
Feb 13, 2025
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Torch-MLIR is primarily a project that is integrated into compilers to bridge th

* [IREE](https://github.com/iree-org/iree.git)
* [Blade](https://github.com/alibaba/BladeDISC)
* [MPACT](https://github.com/MPACT-ORG/mpact-compiler)

While most of the project is exercised via testing paths, there are some ways that an end user can directly use the APIs without further integration:

Expand Down
127 changes: 75 additions & 52 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3688,6 +3688,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
patterns.onOp(
"NonMaxSuppression", 10,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
SmallVector<Value> operands;
int64_t centerPointBox;
Expand All @@ -3702,96 +3703,132 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, "unimplemented: expected center_point_box "
"attribute value to be 0");

// TODO: Add support for optional arguments to be absent.
if (operands.size() < 4)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected at least 4 arguments");

// TODO: Support multiple batches and classes
// Squeeze the boxes and scores tensor.
// In Onnx, the shape of boxes is [BxNx4] while the
// torchvision expects it to be of shape [Nx4]. Similarly, for
// the scores tensor shape in Onnx is [BxCxN] while the
// torchvision expects it to be of shape [N].
Value boxes = operands[0], scores = operands[1];
FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, boxes);
FailureOr<Value> squeezedBoxes =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
if (failed(squeezedBoxes))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze boxes tensor");

FailureOr<Value> squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, scores);
FailureOr<Value> squeezedScores =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
squeezedScores.value());
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");

boxes = squeezedBoxes.value();
scores = squeezedScores.value();

// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
if (operands.size() == 5) {
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
operands[4]);
loc, rewriter.getType<Torch::FloatType>(), operands[4]);
Value minScores = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
loc,
Torch::ValueTensorType::get(binder.op->getContext(),
SmallVector<int64_t>{},
rewriter.getF32Type()),
scores);
minScores = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);
loc, rewriter.getType<Torch::FloatType>(), minScores);

Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
binder.getLoc(), minScores, scoreThreshold);
loc, minScores, scoreThreshold);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), scoresCond,
loc, scoresCond,
rewriter.getStringAttr(
"unimplemented: score_threshold should be <= min(scores)"));
}

// TODO: Support default iou_threshold
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
// Get max_output_boxes_per_class and iou_threshold
Value cst0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value maxOutputBoxesPerClass = cst0;
Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0));
if (operands.size() > 3 &&
!isa<Torch::NoneType>(operands[3].getType())) {
iouThreshold = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::FloatType>(), operands[3]);
}
if (operands.size() > 2 &&
!isa<Torch::NoneType>(operands[2].getType())) {
maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), operands[2]);
}

auto nmsTy = Torch::ValueTensorType::get(
binder.op->getContext(), SmallVector<int64_t>{-1},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
loc, nmsTy, boxes, scores, iouThreshold);

// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
loc, numOutputBoxes, maxOutputBoxesPerClass);

auto nmsResultTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
binder.getLoc(), nmsTy, boxes, scores, iouThreshold);
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
loc, TypeRange({nmsResultTy}), boxesCond);
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getThenRegion(),
ifSlice.getThenRegion().begin());

Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getElseRegion(),
ifSlice.getElseRegion().begin());

Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
loc, nmsResultTy, result);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
result = ifSlice.getResult(0);

// The result generated by torchvision.nms op is of shape [n], while the
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
// and make it of shape [n, 1] and then concatenate it with a zero
// tensor of shape [n, 2] to make it of shape [n, 3].
Value dim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
FailureOr<Value> unsqueezedResult =
Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
if (failed(unsqueezedResult))
return rewriter.notifyMatchFailure(
binder.op, "failed to unsqueeze result tensor");
result = unsqueezedResult.value();

Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), result,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
SmallVector<Value> zerosShapeValues{numOutputBoxes};
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
loc, rewriter.getI64IntegerAttr(2)));
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
zerosShapeValues);

std::optional<ArrayRef<int64_t>> resultShape =
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
if (!resultShape.has_value())
Expand All @@ -3800,33 +3837,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
auto zerosTy = Torch::ValueTensorType::get(
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value zeros = rewriter.create<Torch::AtenZerosOp>(
binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
cstNone);
loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);

Type listElemType =
cast<Torch::BaseTensorType>(resultType)
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), listType, SmallVector<Value>{zeros, result});

// TODO: Support max_output_boxes_per_class input
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), boxesCond,
rewriter.getStringAttr(
"unimplemented: number of output boxes per class should be "
"<= max_output_boxes_per_class"));

loc, listType, SmallVector<Value>{zeros, result});
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
tensorList, dim);
tensorList, cst1);
return success();
});
}
10 changes: 8 additions & 2 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,15 +727,21 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
// Check the matrixs shapes are valid for mulplication.
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);

Type accumulatorDType = getDefaultAccType(rewriter, resultElementType);
Value initTensor0 = createZeroInitTensor(
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2},
resultElementType);
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, accumulatorDType);

Value bmm =
rewriter
.create<linalg::BatchMatmulOp>(loc, initTensor0.getType(),
ValueRange{lhs, rhs}, initTensor0)
.getResult(0);

if (accumulatorDType != resultElementType) {
bmm = torch_to_linalg::convertTensorToElementType(rewriter, loc, bmm,
resultElementType);
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, bmm);
return success();
}
Expand Down
Loading
Loading