diff --git a/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp b/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp index e8a3b7f6..1c1c2195 100644 --- a/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp +++ b/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp @@ -14,6 +14,7 @@ // from_memref accordingly. //===----------------------------------------------------------------------===// +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" @@ -30,6 +31,8 @@ #include "triton-shared/Conversion/TritonToLinalgExperimental/ReconcilePtrCasts.h" #include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" using namespace mlir; using namespace triton; @@ -82,6 +85,7 @@ struct FromMemrefConverter return failure(); } + Location loc = op.getLoc(); auto input = op.getInputs().front(); auto unrankedInput = dyn_cast(input.getType()); auto output = op.getResult(0); @@ -90,9 +94,30 @@ struct FromMemrefConverter if (unrankedInput && isa(outType)) { // from_memref only takes ranked memref, cast the unranked memref to // ranked memref first. - auto rankedMemref = rewriter.create( + Value rankedMemref = rewriter.create( op.getLoc(), MemRefType::get({1}, unrankedInput.getElementType()), input); + + // The output of the cast may have a different element type than the + // input memref. In that case we should use an unrealized conversion cast + // to match the element type. + Type elementTy; + if (auto ttPtr = dyn_cast(outType)) + elementTy = ttPtr.getPointeeType(); + else if (auto genericPtr = dyn_cast(outType)) + elementTy = genericPtr.getElementType(); + else + return failure(); + + if (elementTy && elementTy != unrankedInput.getElementType()) { + // Insert an unrealized conversion cast to match element type + auto castOp = rewriter.create( + loc, MemRefType::get({1}, elementTy), rankedMemref); + + // Use the result of the cast op + rankedMemref = castOp.getResult(0); + } + auto memrefToPtr = rewriter.create( op->getLoc(), ptr::PtrType::get( @@ -110,6 +135,90 @@ struct FromMemrefConverter } }; +static std::optional mapTritonToMLIR(uint32_t triKind) { + switch (triKind) { + case 1: + return arith::AtomicRMWKind::andi; // Triton AND + case 2: + return arith::AtomicRMWKind::ori; // Triton OR + case 3: + return std::nullopt; // Triton XOR not supported + case 4: + return arith::AtomicRMWKind::addi; // Triton ADD + case 5: + return arith::AtomicRMWKind::addf; // Triton FADD + case 6: + return arith::AtomicRMWKind::maxs; // Triton signed max + case 7: + return arith::AtomicRMWKind::mins; // Triton signed min + case 8: + return arith::AtomicRMWKind::maxu; // Triton unsigned max + case 9: + return arith::AtomicRMWKind::minu; // Triton unsigned min + case 10: + return arith::AtomicRMWKind::assign; // Triton XCHG → assign + default: + return std::nullopt; + } +} + +struct AtomicrmwConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto val = op.getVal(); + auto tritonPtr = op.getPtr(); + auto mask = op.getMask(); + + // Recover memref from Triton pointer + Value memref; + if (auto fromMemref = tritonPtr.getDefiningOp()) + memref = fromMemref.getOperand(); + + // Get Triton's atomic kind integer + auto kindIntAttr = op->getAttrOfType("atomic_rmw_op"); + if (!kindIntAttr) + return rewriter.notifyMatchFailure(op, "missing Triton atomic kind"); + + uint32_t triKind = kindIntAttr.getInt(); + auto mlirKindOpt = mapTritonToMLIR(triKind); + if (!mlirKindOpt) + return rewriter.notifyMatchFailure(op, "unsupported Triton atomic kind"); + + auto kindAttr = + arith::AtomicRMWKindAttr::get(rewriter.getContext(), *mlirKindOpt); + + // Use index 0 for rank-1 memrefs for now + Value idx = rewriter.create(loc, 0); + SmallVector indices{idx}; + + Value atomic; + + if (mask) { + auto ifOp = rewriter.create( + loc, mask, + /*thenBuilder=*/ + [&](OpBuilder &b, Location l) { + atomic = b.create(l, kindAttr, val, memref, + indices); + b.create(l); + }, + /*elseBuilder=*/ + [&](OpBuilder &b, Location l) { b.create(l); }); + + rewriter.replaceOp(op, atomic); + } else { + // Replace with memref.atomic_rmw + rewriter.replaceOpWithNewOp(op, kindAttr, val, + memref, indices); + } + + return success(); + } +}; + struct ToMemrefConverter : public OpRewritePattern { ToMemrefConverter(MLIRContext *context, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit) {} @@ -151,17 +260,42 @@ class ReconcilePtrCastsPass public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override { auto moduleOp = getOperation(); - RewritePatternSet patterns(&getContext()); - patterns - .add( - &getContext()); - if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { - signalPassFailure(); + + // === Phase 1: Greedy rewrites === + { + RewritePatternSet greedyPatterns(&getContext()); + greedyPatterns + .add( + &getContext()); + + if (failed(applyPatternsGreedily(moduleOp, std::move(greedyPatterns)))) { + signalPassFailure(); + return; + } + } + + // === Phase 2: Conversion patterns === + { + RewritePatternSet conversionPatterns(&getContext()); + conversionPatterns.add(&getContext()); + + ConversionTarget target(getContext()); + target.addIllegalOp(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + if (failed(applyPartialConversion(moduleOp, target, + std::move(conversionPatterns)))) { + signalPassFailure(); + return; + } } } };