Skip to content

Commit 202c92f

Browse files
committed
OpenXLA-specific changes
1 parent b01bb25 commit 202c92f

File tree

41 files changed

+3692
-1011
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+3692
-1011
lines changed

BUILD

+934
Large diffs are not rendered by default.

include/triton/Conversion/MLIRTypes.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ inline Type u1Ty(MLIRContext *ctx) {
2121
}
2222

2323
// Float types
24-
inline Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
25-
inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
26-
inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
27-
inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
24+
inline Type f16Ty(MLIRContext *ctx) { return Float16Type::get(ctx); }
25+
inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
26+
inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
27+
inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
2828

2929
inline bool isFloat(Type type) {
3030
return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||

lib/Analysis/AxisInfo.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
937937
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
938938
lhsDivisibility = 1;
939939
}
940-
return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
940+
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
941941
}
942942

943943
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
5757
addArgumentMaterialization([&](OpBuilder &builder,
5858
RankedTensorType tensorType, ValueRange inputs,
5959
Location loc) -> Value {
60+
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
61+
// remaining arguments that have been converted to a new type.
62+
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
63+
// 'convert-triton-to-tritongpu'.
64+
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
65+
inputs);
6066
llvm_unreachable("Argument rematerialization should not happen in Triton "
6167
"-> TritonGPU conversion");
6268
return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
6672
// convert origValue to newValue
6773
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
6874
ValueRange inputs, Location loc) -> Value {
75+
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
76+
// remaining uses of values that have been converted to a new type.
77+
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
78+
// 'convert-triton-to-tritongpu'.
79+
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
80+
inputs);
6981
llvm_unreachable("Source rematerialization should not happen in Triton -> "
7082
"TritonGPU Conversion");
7183
return {};

lib/Dialect/TritonGPU/IR/Ops.cpp

+51-70
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ namespace mlir::triton::gpu {
1212

1313
namespace {
1414

15-
template <typename T> bool hasEncoding(Value value) {
15+
template <typename T>
16+
bool hasEncoding(Value value) {
1617
auto type = value.getType();
1718
if (auto tensorType = dyn_cast<TensorOrMemDesc>(type)) {
1819
auto encoding = tensorType.getEncoding();
@@ -25,7 +26,7 @@ bool hasDotOperandEncoding(Value value) {
2526
return hasEncoding<triton::gpu::DotOperandEncodingAttr>(value);
2627
}
2728

28-
} // namespace
29+
} // namespace
2930

3031
//===----------------------------------------------------------------------===//
3132
// Canonicalizer
@@ -36,16 +37,13 @@ struct CanonicalizeConvertFromReshape
3637
: public mlir::OpRewritePattern<triton::ReshapeOp> {
3738
using OpRewritePattern::OpRewritePattern;
3839

39-
mlir::LogicalResult
40-
matchAndRewrite(triton::ReshapeOp op,
41-
PatternRewriter &rewriter) const override {
40+
mlir::LogicalResult matchAndRewrite(
41+
triton::ReshapeOp op, PatternRewriter &rewriter) const override {
4242
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
43-
if (!convert)
44-
return failure();
43+
if (!convert) return failure();
4544
if (isExpensiveView(convert.getSrc().getType(), op.getType()))
4645
return failure();
47-
if (!op.getAllowReorder() || op.getEfficientLayout())
48-
return failure();
46+
if (!op.getAllowReorder() || op.getEfficientLayout()) return failure();
4947

5048
rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
5149
op, op.getType(), convert.getSrc(), op.getAllowReorder());
@@ -58,12 +56,10 @@ struct CanonicalizeConvertFromHistogram
5856
: public mlir::OpRewritePattern<triton::HistogramOp> {
5957
using OpRewritePattern::OpRewritePattern;
6058

61-
mlir::LogicalResult
62-
matchAndRewrite(triton::HistogramOp op,
63-
PatternRewriter &rewriter) const override {
59+
mlir::LogicalResult matchAndRewrite(
60+
triton::HistogramOp op, PatternRewriter &rewriter) const override {
6461
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
65-
if (!convert)
66-
return failure();
62+
if (!convert) return failure();
6763
rewriter.replaceOpWithNewOp<triton::HistogramOp>(
6864
op, op->getResult(0).getType(), convert.getSrc());
6965
return mlir::success();
@@ -79,15 +75,13 @@ struct CanonicalizeConvertFromHistogram
7975
struct CanonicalizeConvertFromGatherSource : public OpRewritePattern<GatherOp> {
8076
using OpRewritePattern::OpRewritePattern;
8177

82-
mlir::LogicalResult
83-
matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override {
78+
mlir::LogicalResult matchAndRewrite(
79+
GatherOp op, PatternRewriter &rewriter) const override {
8480
// Don't do this if the compiler picked an optimized layout.
85-
if (op.getEfficientLayout())
86-
return failure();
81+
if (op.getEfficientLayout()) return failure();
8782

8883
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
89-
if (!convert)
90-
return failure();
84+
if (!convert) return failure();
9185

9286
rewriter.replaceOpWithNewOp<GatherOp>(op, convert.getSrc(), op.getIndices(),
9387
op.getAxis());
@@ -100,13 +94,15 @@ struct CanonicalizeConvertFromAlloc
10094
: public mlir::OpRewritePattern<triton::gpu::LocalAllocOp> {
10195
using OpRewritePattern::OpRewritePattern;
10296

103-
mlir::LogicalResult
104-
matchAndRewrite(triton::gpu::LocalAllocOp op,
105-
PatternRewriter &rewriter) const override {
106-
if (!op.getSrc())
107-
return failure();
97+
mlir::LogicalResult matchAndRewrite(
98+
triton::gpu::LocalAllocOp op, PatternRewriter &rewriter) const override {
99+
if (!op.getSrc()) return failure();
108100
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
109-
if (!convert)
101+
if (!convert) return failure();
102+
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
103+
// to SharedEncoding, so we want to keep this layout conversion.
104+
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
105+
convert.getSrc().getType().getEncoding()))
110106
return failure();
111107
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
112108
op, op->getResult(0).getType(), convert.getSrc());
@@ -119,12 +115,10 @@ struct CanonicalizeConvertFromLocalStore
119115
: public mlir::OpRewritePattern<triton::gpu::LocalStoreOp> {
120116
using OpRewritePattern::OpRewritePattern;
121117

122-
mlir::LogicalResult
123-
matchAndRewrite(triton::gpu::LocalStoreOp op,
124-
PatternRewriter &rewriter) const override {
118+
mlir::LogicalResult matchAndRewrite(
119+
triton::gpu::LocalStoreOp op, PatternRewriter &rewriter) const override {
125120
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
126-
if (!convert)
127-
return failure();
121+
if (!convert) return failure();
128122
rewriter.replaceOpWithNewOp<triton::gpu::LocalStoreOp>(op, convert.getSrc(),
129123
op.getDst());
130124
return mlir::success();
@@ -135,19 +129,16 @@ struct CanonicalizeConvertFromSplit
135129
: public mlir::OpRewritePattern<triton::SplitOp> {
136130
using OpRewritePattern::OpRewritePattern;
137131

138-
mlir::LogicalResult
139-
matchAndRewrite(triton::SplitOp op,
140-
PatternRewriter &rewriter) const override {
132+
mlir::LogicalResult matchAndRewrite(
133+
triton::SplitOp op, PatternRewriter &rewriter) const override {
141134
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
142-
if (!convert)
143-
return failure();
135+
if (!convert) return failure();
144136
auto srcEncoding = convert.getSrc().getType().getEncoding();
145137
// Multiple source layout can give the same output layout, if the source
146138
// layout of the convert gives the same destination layout we can skip the
147139
// convert.
148140
auto dstEncoding = inferDstEncoding(op, srcEncoding);
149-
if (dstEncoding != op.getOutLHS().getType().getEncoding())
150-
return failure();
141+
if (dstEncoding != op.getOutLHS().getType().getEncoding()) return failure();
151142
rewriter.replaceOpWithNewOp<triton::SplitOp>(op, convert.getSrc());
152143
return mlir::success();
153144
}
@@ -157,9 +148,8 @@ struct CanonicalizeConvertFromConvert
157148
: public OpRewritePattern<ConvertLayoutOp> {
158149
using OpRewritePattern::OpRewritePattern;
159150

160-
mlir::LogicalResult
161-
matchAndRewrite(ConvertLayoutOp op,
162-
PatternRewriter &rewriter) const override {
151+
mlir::LogicalResult matchAndRewrite(
152+
ConvertLayoutOp op, PatternRewriter &rewriter) const override {
163153
// Convert to the same layout is redundant.
164154
if (op->getResultTypes() == op->getOperandTypes()) {
165155
rewriter.replaceOp(op, op->getOperands());
@@ -170,22 +160,21 @@ struct CanonicalizeConvertFromConvert
170160
// heuristic to accommodate fused attention.
171161
auto srcType = op.getSrc().getType();
172162
auto dstType = op.getType();
173-
if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
174-
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
163+
if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
164+
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
175165
return failure();
176166

177167
// for hopper MMAv3
178-
if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
179-
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
168+
if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
169+
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
180170
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
181171
return dot->hasTrait<OpTrait::DotLike>();
182172
})) {
183173
return failure();
184174
}
185175

186176
Operation *arg = op.getSrc().getDefiningOp();
187-
if (!arg)
188-
return failure();
177+
if (!arg) return failure();
189178

190179
// cvt(reshape) -> reshape
191180
if (auto reshape = dyn_cast<ReshapeOp>(arg)) {
@@ -233,8 +222,7 @@ struct CanonicalizeConvertFromConvert
233222

234223
// cvt(cat) -> cat
235224
if (auto cat = dyn_cast<CatOp>(arg)) {
236-
if (isExpensiveCat(cat, op.getType().getEncoding()))
237-
return failure();
225+
if (isExpensiveCat(cat, op.getType().getEncoding())) return failure();
238226

239227
rewriter.replaceOpWithNewOp<CatOp>(op, op->getResult(0).getType(),
240228
cat.getOperands());
@@ -291,15 +279,14 @@ LogicalResult UpcastMXFPOp::verify() {
291279

292280
auto xTy = getSrc().getType();
293281
auto scaleTy = getScale().getType();
294-
Builder b(getContext());
295-
if (xTy.getElementType() != b.getBF16Type() &&
296-
xTy.getElementType() != b.getF16Type() &&
297-
xTy.getElementType() != b.getI8Type()) {
298-
return emitOpError(
299-
"element type of the first operand must be bf16/fp16 or i8");
282+
283+
if (xTy.getElementType() != BFloat16Type::get(getContext()) &&
284+
xTy.getElementType() != Float16Type::get(getContext()) &&
285+
xTy.getElementType() != IntegerType::get(getContext(), 8)) {
286+
return emitOpError("element type of the first operand must be bf16 or i8");
300287
}
301288

302-
if (scaleTy.getElementType() != b.getI8Type()) {
289+
if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
303290
return emitOpError("element type of the second operand must be uint8");
304291
}
305292

@@ -373,14 +360,12 @@ LogicalResult UpcastMXFPOp::verify() {
373360
return success();
374361
}
375362

376-
RankedTensorType
377-
UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
378-
ScaleDotElemType inputElemType,
379-
Type outputElemType) {
363+
RankedTensorType UpcastMXFPOp::deduceOutputType(
364+
TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType,
365+
Type outputElemType) {
380366
MLIRContext *ctx = inputTensor.getContext();
381367
auto xTy = inputTensor.getType();
382-
if (inputElemType != ScaleDotElemType::E2M1)
383-
return xTy;
368+
if (inputElemType != ScaleDotElemType::E2M1) return xTy;
384369

385370
auto xShape = xTy.getShape();
386371
auto newShape = llvm::to_vector(xShape);
@@ -466,17 +451,13 @@ void LocalAllocOp::getEffects(
466451
}
467452

468453
OpFoldResult LocalAllocOp::fold(FoldAdaptor adaptor) {
469-
if (getType().getMutableMemory())
470-
return {};
454+
if (getType().getMutableMemory()) return {};
471455
auto src = getSrc();
472-
if (!src)
473-
return {};
456+
if (!src) return {};
474457
auto localLoadOp = src.getDefiningOp<LocalLoadOp>();
475-
if (!localLoadOp)
476-
return {};
458+
if (!localLoadOp) return {};
477459
auto loadSrc = localLoadOp.getSrc();
478-
if (loadSrc.getType() != getType())
479-
return {};
460+
if (loadSrc.getType() != getType()) return {};
480461
return loadSrc;
481462
}
482463

0 commit comments

Comments
 (0)