From 8f70d55d0d11b1c369382dc645e646942203b7a0 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 9 Sep 2025 12:30:35 +0200 Subject: [PATCH] Add support for tt.dot using 3d tensors --- .../ConversionPatterns.hpp | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index 1d9244df..258d5a7c 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -1182,10 +1182,25 @@ struct MatmulConverter : public OpConversionPattern { rewriter.create(loc, ValueRange{zero}, ValueRange{init}) .result(); - auto res = rewriter - .create(loc, ValueRange{opa, opb}, - ValueRange{zeroes}) - .getResult(0); + Value res; + auto rank = dstType.getRank(); + + if (rank == 2) { + // Standard matmul + res = rewriter + .create(loc, ValueRange{opa, opb}, + ValueRange{zeroes}) + .getResult(0); + } else if (rank == 3) { + // Batched matmul + res = rewriter + .create(loc, ValueRange{opa, opb}, + ValueRange{zeroes}) + .getResult(0); + } else { + return rewriter.notifyMatchFailure( + op, "Only 2D or 3D inputs supported for tt.dot lowering"); + } if (!skipC) { if (integers) {