diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index 1d9244df..258d5a7c 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -1182,10 +1182,25 @@ struct MatmulConverter : public OpConversionPattern { rewriter.create(loc, ValueRange{zero}, ValueRange{init}) .result(); - auto res = rewriter - .create(loc, ValueRange{opa, opb}, - ValueRange{zeroes}) - .getResult(0); + Value res; + auto rank = dstType.getRank(); + + if (rank == 2) { + // Standard matmul + res = rewriter + .create(loc, ValueRange{opa, opb}, + ValueRange{zeroes}) + .getResult(0); + } else if (rank == 3) { + // Batched matmul + res = rewriter + .create(loc, ValueRange{opa, opb}, + ValueRange{zeroes}) + .getResult(0); + } else { + return rewriter.notifyMatchFailure( + op, "Only 2D or 3D inputs supported for tt.dot lowering"); + } if (!skipC) { if (integers) {