From e06efc51363227f425ef40e8b5bf0f76e0cfb592 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 21 Nov 2023 21:02:55 -0800 Subject: [PATCH] Initial TorchOnnxToTorch conversion pipeline. (#2585) Adds a pipeline to convert custom ops and metadata represented as `torch.operator` custom ops to corresponding `torch` ops where possible. This is part of a multi-part approach for building ONNX import in as a regular feature of torch-mlir. It is focused on the conversions vs the infra. We will end up maintaining a [pure-python importer](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py) to go with this in torch-mlir, and we will also maintain test case generation utilities derived from it. I have left substantial documentation in the README of the conversion directory, including the recommended approach that we will take to keep building this out. (note that this organizes the code to coincide with the refactoring in #2442 versus the current flat arrangement) --- include/torch-mlir/Conversion/CMakeLists.txt | 2 + .../TorchOnnxToTorch/CMakeLists.txt | 4 + .../Conversion/TorchOnnxToTorch/Passes.h | 27 +++ .../Conversion/TorchOnnxToTorch/Passes.td | 26 +++ .../Conversion/TorchOnnxToTorch/Patterns.h | 169 ++++++++++++++++++ .../Conversion/TorchOnnxToTorch/README.md | 133 ++++++++++++++ lib/CMakeLists.txt | 13 +- lib/Conversion/CMakeLists.txt | 1 + .../TorchOnnxToTorch/CMakeLists.txt | 19 ++ .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 146 +++++++++++++++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 29 +++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 29 +++ lib/Conversion/TorchOnnxToTorch/PassDetail.h | 24 +++ lib/Conversion/TorchOnnxToTorch/Passes.cpp | 19 ++ lib/Conversion/TorchOnnxToTorch/Patterns.cpp | 57 ++++++ .../TorchOnnxToTorch/TorchOnnxToTorch.cpp | 87 +++++++++ lib/InitAll.cpp | 3 +- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 97 ++++++++++ .../unsupported_simple_ops.mlir | 18 ++ 19 files changed, 897 insertions(+), 6 deletions(-) create mode 100644 include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt create mode 100644 include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h create mode 100644 include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td create mode 100644 include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h create mode 100644 include/torch-mlir/Conversion/TorchOnnxToTorch/README.md create mode 100644 lib/Conversion/TorchOnnxToTorch/CMakeLists.txt create mode 100644 lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp create mode 100644 lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp create mode 100644 lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp create mode 100644 lib/Conversion/TorchOnnxToTorch/PassDetail.h create mode 100644 lib/Conversion/TorchOnnxToTorch/Passes.cpp create mode 100644 lib/Conversion/TorchOnnxToTorch/Patterns.cpp create mode 100644 lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp create mode 100644 test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir create mode 100644 test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index d6552314999b..c2e757f7a0ff 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(TorchOnnxToTorch) + set(LLVM_TARGET_DEFINITIONS Passes.td) if(TORCH_MLIR_ENABLE_STABLEHLO) mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt b/include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt new file mode 100644 index 000000000000..a58ce5bf9b7d --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionTorchOnnxToTorchPassIncGen) +add_mlir_doc(Passes TorchMLIRConversionTorchOnnxToTorchPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h new file mode 100644 index 000000000000..6eea35c9d255 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h @@ -0,0 +1,27 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H +#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir::torch::onnx_c { + +std::unique_ptr> createTorchOnnxToTorchPass(); + +/// Registers all torch-mlir conversion passes. +void registerTorchOnnxToTorchPasses(); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td new file mode 100644 index 000000000000..b92649d025a6 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td @@ -0,0 +1,26 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES +#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTorchOnnxToTorch : Pass<"convert-torch-onnx-to-torch", "func::FuncOp"> { + let summary = "Converts ONNX custom ops in the torch dialect to native torch ops"; + let description = [{ + Converts equivalent ONNX custom ops to built-in equivalents. + + See the README for a detailed description of how this operates. + }]; + + let constructor = "mlir::torch::onnx_c::createTorchOnnxToTorchPass()"; +} + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h new file mode 100644 index 000000000000..4c8d73a48116 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -0,0 +1,169 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H +#define TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir::torch::onnx_c { + +/// Used during ONNX pattern matching to bind common patterns of operands, +/// result types and attributes to local variables in a way that is easy +/// to fail the pattern if constraints are violated. Most methods return +/// a ParseResult, which allows for chaining like: +/// +/// if (binder.tensorOperand(foo) || binder.tensorResultType(t)) +/// return failure(); +struct OpBinder { + OpBinder(Operation *op) : op(op) {} + + Location getLoc() { return op->getLoc(); } + + // Operand matches of different arities. + ParseResult tensorOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + if (!toValidTensorType(value0.getType())) + return failure(); + return success(); + } + + ParseResult tensorOperands(Value &value0, Value &value1) { + if (op->getNumOperands() != 2) + return failure(); + value0 = op->getOperand(0); + value1 = op->getOperand(1); + if (!toValidTensorType(value0.getType()) || + !toValidTensorType(value1.getType())) + return failure(); + return success(); + } + + // Result type matchers of different arities. + ParseResult tensorResultType(Torch::ValueTensorType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto t = toValidTensorType(op->getResult(0).getType()); + if (!t) + return failure(); + type0 = t; + return success(); + } + + // Attribute accessors. + ParseResult s64BoolAttr(bool &value, StringRef nameSuffix, + bool defaultValue = false) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + value = defaultValue; + return success(); + } + if (auto integerAttr = dyn_cast(attr)) { + IntegerType t = cast(integerAttr.getType()); + if (!t.isSigned() || t.getWidth() != 64) + return failure(); + value = static_cast(integerAttr.getSInt()); + return success(); + } + return failure(); + } + + ParseResult s64IntegerAttr(int64_t &value, StringRef nameSuffix, + int64_t defaultValue = 0) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + value = defaultValue; + return success(); + } + if (auto integerAttr = dyn_cast(attr)) { + IntegerType t = cast(integerAttr.getType()); + if (!t.isSigned() || t.getWidth() != 64) + return failure(); + value = integerAttr.getSInt(); + return success(); + } + return failure(); + } + + Torch::ValueTensorType toValidTensorType(Type t) { + auto tt = dyn_cast(t); + if (tt && tt.hasSizes()) + return tt; + return {}; + } + + Operation *op; +}; + +/// We use a single pattern per ONNX domain to handle all named custom +/// ops. +/// This allows us to avoid the n^2 problem on pattern application by +/// implementing a secondary index based on the name and sinceVersion +/// attributes. +/// It also lets us add some ergonomics for trivial cases. +class OnnxCustomOpConversionPattern + : public OpConversionPattern { +public: + using HandlerFn = LogicalResult (*)(OpBinder binder, + ConversionPatternRewriter &rewriter); + struct HandlerReg { + HandlerReg(HandlerFn callback, int64_t sinceVersion) + : callback(callback), sinceVersion(sinceVersion) {} + HandlerFn callback; + int64_t sinceVersion; + }; + + OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix, + int64_t domainVersion) + : OpConversionPattern(context), domainPrefix(std::move(domainPrefix)), + domainVersion(domainVersion) {} + + LogicalResult + matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + + /// Adds all fully qualified operator names to the given set. + /// This is typically used for implementing a dynamic legality + /// check for torch.operator names. + void populateLegalizedNames(DenseSet &legalizedNames); + + /// Register a conversion for a specific ONNX operator. For the + /// default domain, this is the canonical ONNX operator name (i.e. + /// "Acos"). + /// Multiple conversions can be registered for the same op, most + /// commonly differing by their `sinceVersion`. + void onOp(StringRef name, int64_t sinceVersion, HandlerFn callback); + +private: + std::string domainPrefix; + int64_t domainVersion; + DenseMap> namedHandlers; +}; + +// Patterns are split into chunks to speed compile time and reduce some +// contention on the same source files. +void populateDefaultDomainAtoF(OnnxCustomOpConversionPattern &patterns); +void populateDefaultDomainGtoP(OnnxCustomOpConversionPattern &patterns); +void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md b/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md new file mode 100644 index 000000000000..6de1cc923411 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md @@ -0,0 +1,133 @@ +# TorchOnnx To Torch Conversions + +We enable the direct representation of many ONNX features directly in +the `torch` dialect as `torch.operator` custom ops with names like +`onnx.{OperatorName}`. The majority of ONNX operators are represented +with a systematic transformation. See +[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py) +for the reference importer which complies with the rules below +(this is planned to be upstreamed to torch-mlir proper in the near +future). + +## Adding new ONNX operators + +With the exception of certain special or complicated ONNX operators, most +are relatively straight-forward to map, following this general procedure: + +* Plan the ops you wish to support by consulting the + [ONNX operator database](https://onnx.ai/onnx/operators/). + * This database has detailed diffs wrt different support versions but + at the level of detail we operate, most version diffs are inconsequential + and just require a bit more pattern support. + * This typically applies to generalization of broadcasting semantics, + expanded type support, and other things of the like. +* *Prerequisite*: Add support for the op to torch-mlir if it does not + already exist. +* Open the corresponding implementation file `DefaultDomainXtoY.cpp` + corresponding with the alphabetic sort of the op and add a conversion. +* Generate successful test cases: + * Either run the Turbine importer to produce MLIR output for all + ops/models in the ONNX test suite or use a dump that someone has + generated: + * [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing) + * There are often many variants of tests for checking conformance of + different historic ONNX encodings, but these are often not load bearing + at the MLIR level. + * Pick a handful of test cases and add them to + `test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to an + alphabetic breakdown. At this time, ignore tests that are not exercising + useful differences in the pattern implementations. +* Generate failure test cases: + * Some ops have forms that do not (easily) map to torch-mlir. If you leave + an op under-implemented, add a failing test case to + `test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir`. +* Optional but recommended: Use your test case files to fuzz against the + torch-mlir backend of your choice by running a backend conversion pipeline + and fixing any crashes/issues. +* Send a patch with your changes. + +## ONNX proto to `torch` dialect mapping + +### Type Conversion + +* Tensors: ONNX tensor types are converted to `torch.vtensor` + with static and dynamic dimensions. We require that shape + inference has run to produce ranked tensors. +* Tensor element types are directly converted to corresponding + MLIR types as used by the rest of torch-mlir. +* String, sequence and sparse tensor types are presently not mapped. + +### Attributes + +A subset of attributes types are converted directly to an attribute +dict on the op with a name like `torch.onnx.{AttributeName}`. The +following attribute type mappings are made: + +* `FLOAT`: `FloatAttr` +* `INT`: Signed `IntegerAttr` of width 64 +* `STRING`: `StringAttr` +* `TENSOR`: Converted to one of: + * `DenseResourceElementsAttr` for inlined `raw_data` + * `DenseElementsAttr` for splats + * `DenseElementsAttr` for inlined typed proto initialization +* `FLOATS`: `ArrayAttr` of `FloatAttr` +* `INTS`: `ArrayAttr` of signed `IntegerAttr` of width 64 +* `STRINGS`: `ArrayAttr` of `StringAttr` +* `TENSORS`: `ArrayAttr` of corresponding `TENSOR` conversion + +The following attribute types have no present, systematic conversion. +Their presence on an op indicates that the op is a special form, which +must be handled specially: + +* `GRAPH` +* `SPARSE_TENSOR` (TBD: it is possible to handle this systematically if + useful). +* `TYPE_PROTO` (TBD: it may be possible to handle this systematically if + useful). +* Plural equivalents of the above. + +### Default operation conversion + +Operations are converted to a `torch.operator` with name `onnx.{OperatorName}`. +The constraint that the ONNX graph is topologically sorted and free of +cycles matches the SSA form. Operands and results are mapped directly. + +This conversion only applies to the default (empty) domain. + +### Quantization information + +Quantization parameters are carried out of line in the ONNX protobuf +and will be repatriated upon import to torch. The exact mechanism is +not yet implemented. + +### Version and metadata + +The `IsolatedFromAbove` parent of the ops can contain the following +metadata: + +* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to + `ModelProto.ir_version`. +* `torch.onnx_meta.producer_name`: `StringAttr` corresponding to + `ModelProto.producer_name`. +* `torch.onnx_meta.producer_version`: `StringAttr` corresponding to + `ModelProto.producer_version`. +* `torch.onnx_meta.opset_version`: 64bit `IntegerAttr` corresponding + to `ModelProto.opset_import.version` for the domain "" (empty). + Will be ommitted if the default opset is not included. +* `torch.onnx_meta.opset_versions`: DictAttr of 64bit `IntegerAttr` + for each non default domain. + +Generally, the importer handles variations in `ir_version` whereas +the transformations here handle opset version differences. Version +independent transformations are encouraged where possible if there +are only minor variations of an op. Major variations should use +`since_version` sensitive patterns. + +### Special op forms + +Certain ONNX operators map to different structural components of +torch-mlir's representation: + +* `ConstantOfShape`: Mapped to `torch.vtensor.literal` with + a corresponding `value` attribute. + diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 8956066b8769..d9030c23a66f 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -14,16 +14,19 @@ set(LinkedLibs MLIRTosaDialect MLIRSupport - TorchMLIRTorchPasses - TorchMLIRTorchConversionDialect - + # Dialects. + TorchMLIRTMTensorDialect TorchMLIRTorchDialect - TorchMLIRTorchConversionPasses + TorchMLIRTorchConversionDialect + # Dialect passes. TorchMLIRTMTensorPasses - TorchMLIRTMTensorDialect + TorchMLIRTorchConversionPasses + TorchMLIRTorchPasses + # Conversion passes. TorchMLIRConversionPasses + TorchMLIRTorchOnnxToTorch ) if(TORCH_MLIR_ENABLE_REFBACKEND) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index f26b4d6e895e..afbe775d3a20 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(TorchOnnxToTorch) add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToArith) diff --git a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt new file mode 100644 index 000000000000..807db64eac64 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch + DefaultDomainAtoF.cpp + DefaultDomainGtoP.cpp + DefaultDomainQtoZ.cpp + Passes.cpp + Patterns.cpp + TorchOnnxToTorch.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch + + DEPENDS + TorchMLIRConversionTorchOnnxToTorchPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + TorchMLIRTorchDialect +) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp new file mode 100644 index 000000000000..5bcf17a1fd92 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -0,0 +1,146 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +// Simple rewrites for the default domain. +// See: https://onnx.ai/onnx/operators/ +// For operators that are effectively version invariant, we register with +// sinceVersion==1. We interpret this to include the following spec +// diffs that are irrelevant to this level of lowering: +// * Supported element types. +// * Limited broadcasting to full broadcasting support. +// +// There are a lot of spec revisions that basically generalized elementwise +// to be more normal and a direct translation vs a special case. This +// results in a lot of ONNX test cases that all reduce to the exact same +// thing here, so we simplify. +void mlir::torch::onnx_c::populateDefaultDomainAtoF( + OnnxCustomOpConversionPattern &patterns) { + patterns.onOp("Abs", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + // TODO: Acos unimplemented in torch-mlir + // TODO: Acosh unimplemented in torch-mlir + // Add became forward compatible with Torch in version 7. + patterns.onOp("Add", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + Value const1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs, const1); + return success(); + }); + // TODO: AffineGrid + patterns.onOp("And", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp( + "ArgMax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + bool keepDims; + int64_t axis; + bool selectLastIndex; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.s64BoolAttr(keepDims, "keepdims", true) || + binder.s64IntegerAttr(axis, "axis", 0) || + binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) + return failure(); + + if (selectLastIndex) { + // TODO: Figure out how to support this case. Need to add a reverse + // or something. + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: select_last_index=true"); + } + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(operand.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value constKeepDims = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(keepDims)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constAxis, constKeepDims); + return success(); + }); + patterns.onOp( + "ArgMin", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + bool keepDims; + int64_t axis; + bool selectLastIndex; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.s64BoolAttr(keepDims, "keepdims", true) || + binder.s64IntegerAttr(axis, "axis", 0) || + binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) + return failure(); + + if (selectLastIndex) { + // TODO: Figure out how to support this case. Need to add a reverse + // or something. + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: select_last_index=true"); + } + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(operand.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value constKeepDims = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(keepDims)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constAxis, constKeepDims); + return success(); + }); + // TODO: Asin unimplemented in torch-mlir + // TODO: Asinh unimplemented in torch-mlir + // TODO: Atan unimplemented in torch-mlir + // TODO: Atanh unimplemented in torch-mlir +} diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp new file mode 100644 index 000000000000..af4f06fdef77 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -0,0 +1,29 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +// Simple rewrites for the default domain. +// See: https://onnx.ai/onnx/operators/ +// For operators that are effectively version invariant, we register with +// sinceVersion==1. We interpret this to include the following spec +// diffs that are irrelevant to this level of lowering: +// * Supported element types. +// * Limited broadcasting to full broadcasting support. +// +// There are a lot of spec revisions that basically generalized elementwise +// to be more normal and a direct translation vs a special case. This +// results in a lot of ONNX test cases that all reduce to the exact same +// thing here, so we simplify. +void mlir::torch::onnx_c::populateDefaultDomainGtoP( + OnnxCustomOpConversionPattern &patterns) {} diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp new file mode 100644 index 000000000000..23af89f329ab --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -0,0 +1,29 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +// Simple rewrites for the default domain. +// See: https://onnx.ai/onnx/operators/ +// For operators that are effectively version invariant, we register with +// sinceVersion==1. We interpret this to include the following spec +// diffs that are irrelevant to this level of lowering: +// * Supported element types. +// * Limited broadcasting to full broadcasting support. +// +// There are a lot of spec revisions that basically generalized elementwise +// to be more normal and a direct translation vs a special case. This +// results in a lot of ONNX test cases that all reduce to the exact same +// thing here, so we simplify. +void mlir::torch::onnx_c::populateDefaultDomainQtoZ( + OnnxCustomOpConversionPattern &patterns) {} diff --git a/lib/Conversion/TorchOnnxToTorch/PassDetail.h b/lib/Conversion/TorchOnnxToTorch/PassDetail.h new file mode 100644 index 000000000000..bbcd3413c59c --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/PassDetail.h @@ -0,0 +1,24 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H +#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::torch::onnx_c { + +#define GEN_PASS_CLASSES +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H diff --git a/lib/Conversion/TorchOnnxToTorch/Passes.cpp b/lib/Conversion/TorchOnnxToTorch/Passes.cpp new file mode 100644 index 000000000000..1f8cb05fa02c --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Passes.cpp @@ -0,0 +1,19 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" +} // end namespace + +void mlir::torch::onnx_c::registerTorchOnnxToTorchPasses() { + ::registerPasses(); +} diff --git a/lib/Conversion/TorchOnnxToTorch/Patterns.cpp b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp new file mode 100644 index 000000000000..6ca7824165d3 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp @@ -0,0 +1,57 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using llvm::dbgs; +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +#define DEBUG_TYPE "torch-onnx" + +LogicalResult OnnxCustomOpConversionPattern::matchAndRewrite( + Torch::OperatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto foundIt = namedHandlers.find(op.getNameAttr()); + if (foundIt == namedHandlers.end()) + return failure(); + auto ®gies = foundIt->second; + for (const HandlerReg ® : reggies) { + if (domainVersion < reg.sinceVersion) { + LLVM_DEBUG(dbgs() << ": skipping conversion " << foundIt->first + << ", sinceVersion=" << reg.sinceVersion + << ", for domainVersion=" << domainVersion << "\n"); + continue; + } + if (succeeded(reg.callback(OpBinder(op), rewriter))) { + return success(); + } else { + LLVM_DEBUG(dbgs() << ": conversion failed to apply: " << foundIt->first + << ", sinceVersion=" << reg.sinceVersion << "\n"); + } + } + return rewriter.notifyMatchFailure(op, "no matching versioned converter"); +} + +void OnnxCustomOpConversionPattern::populateLegalizedNames( + DenseSet &legalizedNames) { + for (auto it : namedHandlers) + legalizedNames.insert(it.first); +} + +void OnnxCustomOpConversionPattern::onOp(StringRef name, int64_t sinceVersion, + HandlerFn callback) { + SmallString<64> fullName(domainPrefix); + fullName.append(name); + StringAttr nameAttr = StringAttr::get(getContext(), fullName); + namedHandlers[nameAttr].push_back(HandlerReg(callback, sinceVersion)); +} diff --git a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp new file mode 100644 index 000000000000..ea890bf0f4b6 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp @@ -0,0 +1,87 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "./PassDetail.h" +#include "mlir/Support/LLVM.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using llvm::dbgs; +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +#define DEBUG_TYPE "torch-onnx" + +namespace { + +int64_t getDefaultOpsetVersion(Operation *containerOp) { + auto attr = + containerOp->getAttrOfType("torch.onnx_meta.opset_version"); + if (!attr) + return 0; + if (auto type = dyn_cast(attr.getType())) { + if (!type || !type.isSigned()) + return 0; + } + return attr.getSInt(); +} + +class ConvertTorchOnnxToTorch + : public ConvertTorchOnnxToTorchBase { +public: + ConvertTorchOnnxToTorch() = default; + void runOnOperation() override { + MLIRContext *context = &getContext(); + + // Populate our patterns for each handled domain. + int64_t defaultOpsetVersion = getDefaultOpsetVersion(getOperation()); + if (defaultOpsetVersion == 0) { + emitError(getOperation().getLoc()) + << "function is missing onnx opset version attribute " + "(torch.onnx_meta.opset_version)"; + return signalPassFailure(); + } + + auto defaultDomainPatterns = + std::make_unique( + context, "onnx.", + /*domainVersion=*/defaultOpsetVersion); + populateDefaultDomainAtoF(*defaultDomainPatterns); + populateDefaultDomainGtoP(*defaultDomainPatterns); + populateDefaultDomainQtoZ(*defaultDomainPatterns); + + // Ask each domain for its handled names and configure the + // conversion target. + ConversionTarget target(*context); + DenseSet legalizedNames; + defaultDomainPatterns->populateLegalizedNames(legalizedNames); + target.addLegalDialect(); + target.addDynamicallyLegalOp([&](Torch::OperatorOp op) { + return !legalizedNames.contains(op.getNameAttr()); + }); + + RewritePatternSet patterns(context); + patterns.insert(std::move(defaultDomainPatterns)); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::onnx_c::createTorchOnnxToTorchPass() { + return std::make_unique(); +} diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 0be0ec8ba3ea..ace6c1a40e74 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -22,6 +22,7 @@ #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" #include "torch-mlir/Conversion/Passes.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" @@ -47,8 +48,8 @@ void mlir::torch::registerOptionalInputDialects( void mlir::torch::registerAllPasses() { mlir::torch::registerTorchPasses(); mlir::torch::registerTorchConversionPasses(); - mlir::torch::registerConversionPasses(); + mlir::torch::onnx_c::registerTorchOnnxToTorchPasses(); mlir::torch::TMTensor::registerPasses(); #ifdef TORCH_MLIR_ENABLE_REFBACKEND diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir new file mode 100644 index 000000000000..e2123ac5e057 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -0,0 +1,97 @@ +// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s +// Generally, the test cases accumulated here come from running the importer +// over all included backend tests that involve simple ops with no model +// level constants. This is a pragmatic choice which lets us have a lot +// of tests in this file, whereas the others tend to be more bespoke. + +// CHECK-LABEL: func.func @test_abs +func.func @test_abs(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.abs %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Abs"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_add +func.func @test_add(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_add_bcast +func.func @test_add_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_add_uint8 +func.func @test_add_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %[[INT1]] : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>, !torch.int -> !torch.vtensor<[3,4,5],ui8> + %0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> + return %0 : !torch.vtensor<[3,4,5],ui8> +} + +// CHECK-LABEL: @test_and_bcast3v1d +func.func @test_and_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.And"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: @test_argmax_default_axis_example +func.func @test_argmax_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 0 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmax %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,2],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> + return %0 : !torch.vtensor<[1,2],si64> +} + +// CHECK-LABEL: @test_argmax_negative_axis_keepdims_example +func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmax %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,1],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> + return %0 : !torch.vtensor<[2,1],si64> +} + +// CHECK-LABEL: @test_argmax_no_keepdims_example +func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool false + // CHECK: torch.aten.argmax %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + +// CHECK-LABEL: @test_argmin_default_axis_example +func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 0 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmin %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,2],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> + return %0 : !torch.vtensor<[1,2],si64> +} + +// CHECK-LABEL: @test_argmin_negative_axis_keepdims_example +func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmin %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,1],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> + return %0 : !torch.vtensor<[2,1],si64> +} + +// CHECK-LABEL: @test_argmin_no_keepdims_example +func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool false + // CHECK: torch.aten.argmin %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir new file mode 100644 index 000000000000..22d5e2d35183 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir @@ -0,0 +1,18 @@ +// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch + +module { + func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // TODO: Unsupported torch.onnx.select_last_index + // expected-error @+1 {{failed to legalize operation 'torch.operator'}} + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> + return %0 : !torch.vtensor<[2,4],si64> + } +} + +// ----- +func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // TODO: Unsupported torch.onnx.select_last_index + // expected-error @+1 {{failed to legalize operation 'torch.operator'}} + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +}