Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 142 additions & 8 deletions lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
// from_memref accordingly.
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
Expand All @@ -30,6 +31,8 @@

#include "triton-shared/Conversion/TritonToLinalgExperimental/ReconcilePtrCasts.h"
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

using namespace mlir;
using namespace triton;
Expand Down Expand Up @@ -82,6 +85,7 @@ struct FromMemrefConverter
return failure();
}

Location loc = op.getLoc();
auto input = op.getInputs().front();
auto unrankedInput = dyn_cast<UnrankedMemRefType>(input.getType());
auto output = op.getResult(0);
Expand All @@ -90,9 +94,30 @@ struct FromMemrefConverter
if (unrankedInput && isa<triton::PointerType, ptr::PtrType>(outType)) {
// from_memref only takes ranked memref, cast the unranked memref to
// ranked memref first.
auto rankedMemref = rewriter.create<memref::CastOp>(
Value rankedMemref = rewriter.create<memref::CastOp>(
op.getLoc(), MemRefType::get({1}, unrankedInput.getElementType()),
input);

// The output of the cast may have a different element type than the
// input memref. In that case we should use an unrealized conversion cast
// to match the element type.
Type elementTy;
if (auto ttPtr = dyn_cast<triton::PointerType>(outType))
elementTy = ttPtr.getPointeeType();
else if (auto genericPtr = dyn_cast<ptr::PtrType>(outType))
elementTy = genericPtr.getElementType();
else
return failure();

if (elementTy && elementTy != unrankedInput.getElementType()) {
// Insert an unrealized conversion cast to match element type
auto castOp = rewriter.create<UnrealizedConversionCastOp>(
loc, MemRefType::get({1}, elementTy), rankedMemref);

// Use the result of the cast op
rankedMemref = castOp.getResult(0);
}

auto memrefToPtr = rewriter.create<tptr::FromMemrefOp>(
op->getLoc(),
ptr::PtrType::get(
Expand All @@ -110,6 +135,90 @@ struct FromMemrefConverter
}
};

static std::optional<arith::AtomicRMWKind> mapTritonToMLIR(uint32_t triKind) {
switch (triKind) {
case 1:
return arith::AtomicRMWKind::andi; // Triton AND
case 2:
return arith::AtomicRMWKind::ori; // Triton OR
case 3:
return std::nullopt; // Triton XOR not supported
case 4:
return arith::AtomicRMWKind::addi; // Triton ADD
case 5:
return arith::AtomicRMWKind::addf; // Triton FADD
case 6:
return arith::AtomicRMWKind::maxs; // Triton signed max
case 7:
return arith::AtomicRMWKind::mins; // Triton signed min
case 8:
return arith::AtomicRMWKind::maxu; // Triton unsigned max
case 9:
return arith::AtomicRMWKind::minu; // Triton unsigned min
case 10:
return arith::AtomicRMWKind::assign; // Triton XCHG → assign
default:
return std::nullopt;
}
}

/// Rewrites a scalar `triton::AtomicRMWOp` into `memref.atomic_rmw`.
///
/// The pattern only applies when the pointer operand is produced directly by a
/// `tptr::FromMemrefOp`; the memref feeding that op becomes the target of the
/// atomic, always accessed at index 0. When the op carries a mask, the atomic
/// is wrapped in an `scf.if` on that mask and the `scf.if` result replaces the
/// original op.
///
/// The rewrite fails (without modifying IR) when:
///   * the pointer does not come from `tptr::FromMemrefOp`,
///   * the value is not a scalar integer/float or the mask is not an `i1`,
///   * the `atomic_rmw_op` attribute is missing or has no arith equivalent.
struct AtomicrmwConverter : public OpRewritePattern<triton::AtomicRMWOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(triton::AtomicRMWOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value val = op.getVal();
    Value mask = op.getMask();
    auto tritonPtr = op.getPtr();

    // Recover the memref behind the Triton pointer. Without it there is
    // nothing to build a memref.atomic_rmw on, so bail out instead of
    // creating an op with a null operand.
    auto fromMemref = tritonPtr.getDefiningOp<tptr::FromMemrefOp>();
    if (!fromMemref)
      return rewriter.notifyMatchFailure(
          op, "pointer is not produced by tptr.from_memref");
    Value memref = fromMemref.getOperand();

    // memref.atomic_rmw operates on a single integer/float element and scf.if
    // needs a scalar i1 condition; reject anything else up front.
    Type valueTy = val.getType();
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "expected scalar integer or float value");
    if (mask && !mask.getType().isInteger(1))
      return rewriter.notifyMatchFailure(op, "expected scalar i1 mask");

    // Get Triton's atomic kind integer
    auto kindIntAttr = op->getAttrOfType<IntegerAttr>("atomic_rmw_op");
    if (!kindIntAttr)
      return rewriter.notifyMatchFailure(op, "missing Triton atomic kind");

    uint32_t triKind = kindIntAttr.getInt();
    auto mlirKindOpt = mapTritonToMLIR(triKind);
    if (!mlirKindOpt)
      return rewriter.notifyMatchFailure(op, "unsupported Triton atomic kind");

    auto kindAttr =
        arith::AtomicRMWKindAttr::get(rewriter.getContext(), *mlirKindOpt);

    // Use index 0 for rank-1 memrefs for now
    Value idx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value, 1> indices{idx};

    if (!mask) {
      // Unmasked: replace directly with memref.atomic_rmw.
      rewriter.replaceOpWithNewOp<memref::AtomicRMWOp>(op, kindAttr, val,
                                                       memref, indices);
      return success();
    }

    // Masked: perform the atomic only when the mask is set. The atomic's
    // result lives inside the `then` region, so it must be yielded and the
    // scf.if result used as the replacement; referencing the inner value from
    // outside the region would violate dominance.
    auto ifOp = rewriter.create<scf::IfOp>(
        loc, mask,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location l) {
          Value updated = b.create<memref::AtomicRMWOp>(l, kindAttr, val,
                                                        memref, indices);
          b.create<scf::YieldOp>(l, ValueRange{updated});
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location l) {
          // No memory access happens on this path, so there is no "old value"
          // to return. A zero of the value type is yielded as a placeholder.
          // NOTE(review): confirm the expected result for masked-off atomics.
          Value placeholder = b.create<arith::ConstantOp>(
              l, valueTy, b.getZeroAttr(valueTy));
          b.create<scf::YieldOp>(l, ValueRange{placeholder});
        });

    rewriter.replaceOp(op, ifOp.getResults());
    return success();
  }
};

struct ToMemrefConverter : public OpRewritePattern<UnrealizedConversionCastOp> {
ToMemrefConverter(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<UnrealizedConversionCastOp>(context, benefit) {}
Expand Down Expand Up @@ -151,17 +260,42 @@ class ReconcilePtrCastsPass

public:
/// Declares every dialect whose operations this pass may create, so they are
/// loaded before the pass runs.
///
/// Besides the pointer/memref dialects used by the cast reconciliation, the
/// atomic RMW rewrite creates `arith` constants and, for masked atomics,
/// `scf.if` / `scf.yield`, so both `arith` and `scf` must be listed here.
void getDependentDialects(DialectRegistry &registry) const override {
  registry.insert<tptr::TPtrDialect, memref::MemRefDialect, BuiltinDialect,
                  arith::ArithDialect, scf::SCFDialect>();
}

/// Runs the pass in two phases over the operation returned by
/// `getOperation()`:
///   1. a greedy rewrite that applies the unrealized-cast simplification and
///      the from/to-memref converters;
///   2. a partial dialect conversion that rewrites `triton::AtomicRMWOp`
///      (marked illegal) into arith/memref/scf operations.
/// The pass is marked as failed, and later phases are skipped, as soon as one
/// phase fails.
void runOnOperation() override {
  auto moduleOp = getOperation();

  // === Phase 1: Greedy rewrites ===
  {
    RewritePatternSet greedyPatterns(&getContext());
    greedyPatterns
        .add<SimplifyUnrealizedCast, FromMemrefConverter, ToMemrefConverter>(
            &getContext());

    if (failed(applyPatternsGreedily(moduleOp, std::move(greedyPatterns)))) {
      signalPassFailure();
      return;
    }
  }

  // === Phase 2: Conversion patterns ===
  {
    RewritePatternSet conversionPatterns(&getContext());
    conversionPatterns.add<AtomicrmwConverter>(&getContext());

    // Only the Triton atomic is required to disappear; ops from the dialects
    // the rewrite produces are explicitly legal.
    ConversionTarget target(getContext());
    target.addIllegalOp<triton::AtomicRMWOp>();
    target.addLegalDialect<arith::ArithDialect>();
    target.addLegalDialect<memref::MemRefDialect>();
    target.addLegalDialect<scf::SCFDialect>();

    if (failed(applyPartialConversion(moduleOp, target,
                                      std::move(conversionPatterns)))) {
      signalPassFailure();
      return;
    }
  }
}
};
Expand Down