Skip to content

Commit 96fcde4

Browse files
authored
[Torch Dialect] Support Einsum Op (#2230)
As title, support torch.aten.einsum op Right now only support Static Shape, because of the known issue, the fixed solution is here: #2154 Co-authored-by: Jiawei Wu [[email protected]](mailto:[email protected])
1 parent 07c3e11 commit 96fcde4

File tree

8 files changed

+554
-0
lines changed

8 files changed

+554
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8447,6 +8447,31 @@ def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [
84478447
}];
84488448
}
84498449

8450+
def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [
8451+
AllowsTypeRefinement,
8452+
HasValueSemantics,
8453+
ReadOnly
8454+
]> {
8455+
let summary = "Generated op for `aten::einsum : (str, Tensor[], int[]?) -> (Tensor)`";
8456+
let arguments = (ins
8457+
Torch_StringType:$equation,
8458+
AnyTorchListOfTensorType:$tensors,
8459+
AnyTorchOptionalListOfTorchIntType:$path
8460+
);
8461+
let results = (outs
8462+
AnyTorchTensorType:$result
8463+
);
8464+
let hasCustomAssemblyFormat = 1;
8465+
let extraClassDefinition = [{
8466+
ParseResult AtenEinsumOp::parse(OpAsmParser &parser, OperationState &result) {
8467+
return parseDefaultTorchOp(parser, result, 3, 1);
8468+
}
8469+
void AtenEinsumOp::print(OpAsmPrinter &printer) {
8470+
printDefaultTorchOp(printer, *this, 3, 1);
8471+
}
8472+
}];
8473+
}
8474+
84508475
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
84518476
AllowsTypeRefinement,
84528477
HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11321,6 +11321,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1132111321
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
1132211322
" return %5 : !torch.int\n"
1132311323
" }\n"
11324+
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
11325+
" %true = torch.constant.bool true\n"
11326+
" %none = torch.constant.none\n"
11327+
" %str = torch.constant.str \"AssertionError: \"\n"
11328+
" %int0 = torch.constant.int 0\n"
11329+
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
11330+
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
11331+
" %2 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
11332+
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
11333+
" torch.prim.If %3 -> () {\n"
11334+
" torch.prim.If.yield\n"
11335+
" } else {\n"
11336+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
11337+
" torch.prim.If.yield\n"
11338+
" }\n"
11339+
" %4 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
11340+
" torch.prim.Loop %4, %true, init() {\n"
11341+
" ^bb0(%arg3: !torch.int):\n"
11342+
" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
11343+
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
11344+
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
11345+
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11346+
" torch.prim.Loop.condition %true, iter()\n"
11347+
" } : (!torch.int, !torch.bool) -> ()\n"
11348+
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
11349+
" return %5 : !torch.int\n"
11350+
" }\n"
1132411351
" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
1132511352
" %int4 = torch.constant.int 4\n"
1132611353
" return %int4 : !torch.int\n"

0 commit comments

Comments
 (0)