@@ -11321,6 +11321,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
11321
11321
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
11322
11322
" return %5 : !torch.int\n"
11323
11323
" }\n"
11324
+ " func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
11325
+ " %true = torch.constant.bool true\n"
11326
+ " %none = torch.constant.none\n"
11327
+ " %str = torch.constant.str \"AssertionError: \"\n"
11328
+ " %int0 = torch.constant.int 0\n"
11329
+ " %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
11330
+ " %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
11331
+ " %2 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
11332
+ " %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
11333
+ " torch.prim.If %3 -> () {\n"
11334
+ " torch.prim.If.yield\n"
11335
+ " } else {\n"
11336
+ " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
11337
+ " torch.prim.If.yield\n"
11338
+ " }\n"
11339
+ " %4 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
11340
+ " torch.prim.Loop %4, %true, init() {\n"
11341
+ " ^bb0(%arg3: !torch.int):\n"
11342
+ " %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
11343
+ " %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
11344
+ " %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
11345
+ " %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
11346
+ " torch.prim.Loop.condition %true, iter()\n"
11347
+ " } : (!torch.int, !torch.bool) -> ()\n"
11348
+ " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
11349
+ " return %5 : !torch.int\n"
11350
+ " }\n"
11324
11351
" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
11325
11352
" %int4 = torch.constant.int 4\n"
11326
11353
" return %int4 : !torch.int\n"
0 commit comments