
Commit 7a5940c

jax-triton-dev authored and karupayun committed
OpenXLA-specific changes
1 parent b2de88f · commit 7a5940c


45 files changed: +2187 -114 lines changed. Not all file diffs are rendered below; large commits hide some content by default.

BUILD (+900)

Large diff not rendered by default.

include/triton/Analysis/Alias.h (+3 -4)

@@ -85,10 +85,9 @@ class SharedMemoryAliasAnalysis
   }

   /// Computes if the alloc set of the results are changed.
-  void
-  visitOperation(Operation *op,
-                 ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
-                 ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
+  LogicalResult visitOperation(
+      Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
+      ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
 };

 } // namespace mlir

include/triton/Dialect/Triton/IR/TritonTypes.td (+1 -1)

@@ -15,7 +15,7 @@ class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
 }

 // Floating-point Type
-def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
+def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
 def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
 def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

include/triton/Dialect/Triton/IR/Utility.h (+5 -1)

@@ -31,7 +31,11 @@ template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }

 /// Get the highest power of 2 divisor of an integer.
 template <typename T> T highestPowOf2Divisor(T n) {
-  if (n == 0) {
+  // When n is 0 or min, return the highest power of 2. The min case is handled
+  // separately to avoid underflow when T is a signed integer. Technically
+  // in that case the correct divisor is -n, but this value is outside the
+  // range of possible values, so we take the next best alternative.
+  if (n == 0 || n == std::numeric_limits<T>::min()) {
     return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
   }
   return (n & (~(n - 1)));
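For reference, the guarded helper can be exercised in isolation. The standalone driver below is not part of the commit; the main function and sample values are illustrative assumptions. It shows the lowest-set-bit trick and the new zero/INT_MIN fallback.

// Standalone sketch mirroring the patched helper; main() and the sample
// values are illustrative only, not part of the commit.
#include <cstdint>
#include <iostream>
#include <limits>

template <typename T> T highestPowOf2Divisor(T n) {
  // 0 and the most negative value both fall back to the largest power of 2
  // representable in T; for INT_MIN this avoids the signed underflow that
  // computing n - 1 would otherwise trigger.
  if (n == 0 || n == std::numeric_limits<T>::min()) {
    return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
  }
  // Classic bit trick: keep only the lowest set bit of n.
  return (n & (~(n - 1)));
}

int main() {
  std::cout << highestPowOf2Divisor<int32_t>(24) << "\n";  // 8
  std::cout << highestPowOf2Divisor<int32_t>(0) << "\n";   // 1073741824 (2^30)
  std::cout << highestPowOf2Divisor<int32_t>(
                   std::numeric_limits<int32_t>::min())
            << "\n";                                       // 1073741824 (2^30)
}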

lib/Analysis/Alias.cpp (+5 -3)

@@ -21,7 +21,7 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
   return ret;
 }

-void SharedMemoryAliasAnalysis::visitOperation(
+LogicalResult SharedMemoryAliasAnalysis::visitOperation(
     Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
     ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
   AliasInfo aliasInfo;
@@ -31,7 +31,7 @@ void SharedMemoryAliasAnalysis::visitOperation(
   if (auto memdescTy = dyn_cast<triton::MemDescType>(result.getType())) {
     if (!isa_and_nonnull<triton::gpu::SharedMemorySpaceAttr>(
             memdescTy.getMemorySpace()))
-      return;
+      return mlir::success();
   }

   // Only LocalAllocOp creates a new buffer.
@@ -49,11 +49,13 @@ void SharedMemoryAliasAnalysis::visitOperation(
   }

   if (pessimistic) {
-    return setAllToEntryStates(results);
+    setAllToEntryStates(results);
+    return mlir::success();
   }
   // Join all lattice elements
   for (auto *result : results)
     propagateIfChanged(result, result->join(aliasInfo));
+  return mlir::success();
 }

 AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
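The same void-to-LogicalResult visitor change recurs in lib/Analysis/AxisInfo.cpp below. A minimal standalone sketch of the pattern, using plain C++ stand-ins rather than the real MLIR LogicalResult and Lattice types, shows why each early "return setAllToEntryStates(results);" becomes a call followed by an explicit success.

// Standalone sketch of the visitor return-type migration; LogicalResult,
// Lattice, and the driver below are simplified stand-ins, not MLIR types.
#include <iostream>
#include <vector>

struct LogicalResult {
  bool ok;
  static LogicalResult success() { return {true}; }
  static LogicalResult failure() { return {false}; }
};

struct Lattice {
  int state = -1;  // -1 == uninitialized, 0 == entry state, 1 == joined
};

void setAllToEntryStates(std::vector<Lattice *> &results) {
  for (Lattice *r : results)
    r->state = 0;
}

// Old style: void visit(...) ended with "return setAllToEntryStates(results);".
// New style: the status is reported back to the dataflow driver, so the early
// exit becomes a call followed by an explicit success().
LogicalResult visitOperation(bool pessimistic, std::vector<Lattice *> &results) {
  if (pessimistic) {
    setAllToEntryStates(results);
    return LogicalResult::success();
  }
  for (Lattice *r : results)
    r->state = 1;  // stand-in for propagateIfChanged(result, result->join(...))
  return LogicalResult::success();
}

int main() {
  Lattice a, b;
  std::vector<Lattice *> results{&a, &b};
  visitOperation(/*pessimistic=*/true, results);
  std::cout << a.state << " " << b.state << "\n";  // 0 0
}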

lib/Analysis/AxisInfo.cpp (+9 -6)

@@ -195,9 +195,9 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
                           dataflow::Lattice<AxisInfo>>::getLatticeElement;
   using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;

-  void visitOperation(Operation *op,
-                      ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
-                      ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
+  LogicalResult visitOperation(
+      Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
+      ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
   void
   visitForOpInductionVar(scf::ForOp op,
                          ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices);
@@ -1039,7 +1039,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
   visitors.append<LoadOpAxisInfoVisitor>();
 }

-void AxisInfoAnalysis::visitOperation(
+LogicalResult AxisInfoAnalysis::visitOperation(
     Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
     ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
   // TODO: For sure not the right way to do this
@@ -1048,8 +1048,10 @@ void AxisInfoAnalysis::visitOperation(
     if (op->getValue().getRank() == 0)
       setToEntryState((dataflow::Lattice<AxisInfo> *)op);
   AxisInfo curr = visitors.apply(op, operands);
-  if (curr.getRank() == 0)
-    return setAllToEntryStates(results);
+  if (curr.getRank() == 0) {
+    setAllToEntryStates(results);
+    return mlir::success();
+  }
   // override with hint
   auto newContiguity = curr.getContiguity();
   auto newDivisibility = curr.getDivisibility();
@@ -1071,6 +1073,7 @@ void AxisInfoAnalysis::visitOperation(
   // join all lattice elements
   for (auto *result : results)
     propagateIfChanged(result, result->join(curr));
+  return mlir::success();
 }

 void AxisInfoAnalysis::visitForOpInductionVar(

lib/Analysis/Utility.cpp (+4 -2)

@@ -425,6 +425,7 @@ bool supportMFMATypes(Type a, Type b) {
   if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
     return false;

+  auto F8E4M3FN = TypeID::get<Float8E4M3FNType>();
   auto F8E5M2 = TypeID::get<Float8E5M2Type>();
   auto F8E4M3FNUZ = TypeID::get<Float8E4M3FNUZType>();
   auto F8E5M2FNUZ = TypeID::get<Float8E5M2FNUZType>();
@@ -436,6 +437,7 @@ bool supportMFMATypes(Type a, Type b) {
       {F32, F32},
       {F16, F16},
       {BF16, BF16},
+      {F8E4M3FN, F8E4M3FN},
       {F8E5M2, F8E5M2},
       {F8E4M3FNUZ, F8E4M3FNUZ},
       {F8E4M3FNUZ, F8E5M2FNUZ},
@@ -495,14 +497,14 @@ bool supportMMA(triton::DotOp op, int version) {
     return false;
   if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
         retShapePerCTA[rank - 1] % 8 == 0 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
+        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
         aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
         aElemTy.isF32()))) {
     return false;
   }
   // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
   if (op.getMaxNumImpreciseAcc() < 32 &&
-      (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) &&
+      (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
       cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
     return false;
   }

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp (+2 -1)

@@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   auto ouEltTy = ouTensorTy.getElementType();
   if (inBitWidth == ouBitWidth)
     return values;
-  if (inBitWidth == 16 && ouBitWidth == 32) {
+  if ((inBitWidth == 16 && ouBitWidth == 32) ||
+      (inBitWidth == 32 && ouBitWidth == 16)) {
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 8) {
       ret.push_back(values[i]);

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp (+3)

@@ -34,6 +34,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
   addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 8);
   });
+  addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
+    return IntegerType::get(type.getContext(), 8);
+  });
   addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
     return IntegerType::get(type.getContext(), 8);
   });

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp (+2 -1)

@@ -87,8 +87,9 @@ struct ArithConstantSplatOpConversion
     // LLVM IR.
     if (type::isFloat8(elemType))
       elemType = rewriter.getIntegerType(8);
-    auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
     auto typeConverter = getTypeConverter();
+    auto constOp = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(elemType), val);
     auto llStruct = SplatOpConversion::convertSplatLikeOp(
         elemType, op.getType(), constOp, typeConverter, rewriter, loc);
     rewriter.replaceOp(op, llStruct);

lib/Dialect/TritonGPU/IR/Dialect.cpp (+5)

@@ -2721,6 +2721,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp (+25 -1)

@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }

@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to Predicates are not currently supported
+    // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+    // condition limits visibility of the original bit-width so that predicate
+    // are not considered, hence, kwidth can never be = 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
@@ -357,7 +381,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
+      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp (+16 -1)

@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                                 PatternRewriter &rewriter) const override {
     // Only consider conversions to dot operand.
     auto cvtTy = cast<RankedTensorType>(cvt.getType());
-    if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
+    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
+    if (!dotOpEnc)
       return failure();

     auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                      [](Type ty) { return isa<RankedTensorType>(ty); }))
       return failure();

+    // Quick handling to fix loading issues when computing the original
+    // bitwidth is unable to realize that there is a mixed-precision dot
+    // (hence kWidth = 1) but wants to hoist through the type conversion.
+    if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
+      return failure();
+
     // Only consider custom conversions or arith ops.
     // TODO(jlebar): Is this too restrictive?
     if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
     if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
       return failure();

+    // Don't hoist through u1 -> fp casts as they aren't supported in
+    // ElementwiseOpToLLVM::reorderValues().
+    if (isa<arith::UIToFPOp>(src)) {
+      Type srcType = getElementTypeOrSelf(src->getOperand(0));
+      if (srcType.isInteger(1))
+        return failure();
+    }
+
     // Check that the conversion is transitively dependent on a load, and all
     // operations between the load and the conversion are layout preserving.
     //

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp (+16 -1)

@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                          type.getMemorySpace()),
       v, offsetsVal);

+  // We need to assign kwidth to zero in the case where the parent layout is
+  // Blocked, otherwise the verifier emits a failure. The parent layout is
+  // Blocked only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
       newSmem);
@@ -187,6 +193,15 @@ LogicalResult Prefetcher::initialize() {
       break;
     if (!op->getResult(0).hasOneUse())
       break;
+    // Similar to issues faced in HoistLayoutConversion pattern in
+    // OptimizeDotOperands.cpp, we can't propagate through type casts from
+    // predicates as they aren't supported in Triton when encoded with dot_op
+    // layout.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        break;
+    }
    rets.push_back(op->getOperand(0));
    if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
      foundConvertFromShared = true;
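A small standalone sketch of the kwidth rule added in generatePrefetch: a Blocked parent layout forces kwidth to 0 so the DotOperandEncoding verifier accepts it, while the Tensor Core path keeps prefetchWidth / 8. The pickKWidth helper and the sample prefetch width below are illustrative assumptions, not Triton code.

// Illustrative sketch only; pickKWidth and the sample value are assumptions,
// not Triton APIs.
#include <iostream>

int pickKWidth(bool parentLayoutIsBlocked, int prefetchWidth) {
  // Blocked parent layout (Tensor Cores disabled): the verifier requires
  // kwidth == 0. Otherwise keep the original prefetchWidth / 8.
  return parentLayoutIsBlocked ? 0 : prefetchWidth / 8;
}

int main() {
  std::cout << pickKWidth(false, 32) << "\n";  // 4
  std::cout << pickKWidth(true, 32) << "\n";   // 0
}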
