
Commit aedc018

Aliia Khasanova authored and committed
OpenXLA-specific changes
1 parent 99cff45 commit aedc018


58 files changed: +3633, -982 lines


BUILD (+891)

Large diff not rendered by default.

include/triton/Conversion/MLIRTypes.h (+7, -6)

@@ -28,15 +28,16 @@ inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
 inline bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+         type.isBF16() ||
+         llvm::isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType,
+                   mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
+                   mlir::Float8E5M2FNUZType>(type);
 }
 
 inline bool isFloat8(Type type) {
-  return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+  return llvm::isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType,
+                   mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
+                   mlir::Float8E5M2FNUZType>(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
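Aside (not part of the commit): the hunk above swaps per-type `Type::isFloat8…()` predicates for the variadic form of `llvm::isa<>`, presumably to track upstream MLIR dropping those convenience helpers. A minimal usage sketch, assuming an MLIR checkout where these builtin float-8 type classes are available:

// Sketch, not from the commit: variadic llvm::isa<> returns true when the
// value is an instance of any of the listed types.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/Casting.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Type t = mlir::Float8E5M2Type::get(&ctx);
  bool fp8 = llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(t);
  bool f32 = llvm::isa<mlir::Float32Type>(t);
  return (fp8 && !f32) ? 0 : 1; // fp8 is true, f32 is false
}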

include/triton/Dialect/Triton/IR/TritonOps.td (+6, -1)

@@ -1105,7 +1105,12 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
     MutableOperandRange getArgOperandsMutable() {
       return getOperandsMutable();
     }
-
+    Attribute removeArgAttrsAttr() { return nullptr; }
+    Attribute removeResAttrsAttr() { return nullptr; }
+    ArrayAttr getArgAttrsAttr() { return nullptr; }
+    ArrayAttr getResAttrsAttr() { return nullptr; }
+    void setArgAttrsAttr(ArrayAttr) { return; }
+    void setResAttrsAttr(ArrayAttr) { return; }
   }];
 
   let assemblyFormat = [{

lib/Analysis/Allocation.cpp (+6)

@@ -123,6 +123,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
 
   std::tie(scratchConfig.inVec, scratchConfig.outVec) =
       getScratchCvtInOutVecLengths(srcTy, dstTy);
+  // We can't write a longer vector than the shape of shared memory.
+  // This shape might be smaller than the tensor shape in case we decided to
+  // do the conversion in multiple iterations.
+  unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
+  scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
+  scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
 
   // No padding is required if the tensor is 1-D, or if all dimensions except
   // the first accessed dimension have a size of 1.
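A toy illustration of the clamp (every number here is hypothetical, not taken from the commit): if the shared-memory tile is only 8 elements wide along its contiguous dimension, a 16-wide vector would overrun it, so both vector lengths are capped at the tile width.

#include <algorithm>
#include <array>
#include <iostream>

int main() {
  std::array<unsigned, 2> repShape = {128, 8}; // hypothetical tile shape
  std::array<unsigned, 2> order = {1, 0};      // dim 1 is the contiguous dim
  unsigned inVec = 16, outVec = 4;             // hypothetical vector lengths

  unsigned contiguousShapeDim = repShape[order[0]];
  inVec = std::min(inVec, contiguousShapeDim);   // 16 -> 8
  outVec = std::min(outVec, contiguousShapeDim); // stays 4
  std::cout << inVec << " " << outVec << "\n";   // prints "8 4"
}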

lib/Analysis/AxisInfo.cpp (+1, -1)

@@ -935,7 +935,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
-    return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+    return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }
 
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
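Why the widened literal matters, as a standalone sketch (the shift amounts below are made up): the literal `1` is a 32-bit `int`, so `1 << shift` is undefined behavior once `shift` reaches 32, whereas `int64_t(1) << shift` stays well defined up to 63 bits.

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  int shift = 40;                             // hypothetical shift amount
  int64_t lhsDivisibility = int64_t(1) << 50; // hypothetical divisibility
  // (1 << shift) would be UB here; the widened literal keeps it defined.
  int64_t d = std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
  std::cout << d << "\n"; // prints 1024, i.e. 2^10
}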

lib/Analysis/Utility.cpp (+5, -4)

@@ -750,14 +750,14 @@ bool supportMMA(triton::DotOp op, int version) {
     return false;
   if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
         retShapePerCTA[rank - 1] % 8 == 0 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
+        (llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(aElemTy) ||
          aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
          aElemTy.isF32()))) {
     return false;
   }
   // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
   if (op.getMaxNumImpreciseAcc() < 32 &&
-      (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+      (llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(aElemTy)) &&
       cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
     return false;
   }
@@ -778,8 +778,9 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
-               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+  bool isFP8 =
+      llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
+                mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp (+12)

@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  RankedTensorType tensorType, ValueRange inputs,
                                  Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion
+    // for remaining arguments that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                        inputs);
     llvm_unreachable("Argument rematerialization should not happen in Triton "
                      "-> TritonGPU conversion");
     return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // convert origValue to newValue
   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion
+    // for remaining uses of values that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                        inputs);
     llvm_unreachable("Source rematerialization should not happen in Triton -> "
                      "TritonGPU Conversion");
     return {};

lib/Dialect/Triton/IR/Ops.cpp (+1, -1)

@@ -899,7 +899,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }

lib/Dialect/TritonGPU/IR/Ops.cpp (+9, -4)

@@ -151,6 +151,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();
@@ -213,13 +218,13 @@ struct CanonicalizeConvertFromConvert
     // heuristic to accommodate fused attention.
     auto srcType = op.getSrc().getType();
     auto dstType = op.getType();
-    if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
-        mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+    if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+        mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
       return failure();
 
     // for hopper MMAv3
-    if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
-        mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+    if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+        mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
         llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
           return dot->hasTrait<OpTrait::DotLike>();
         })) {
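The switch to `isa_and_nonnull` matters because a tensor type's encoding attribute may be absent. A toy sketch of the distinction (the class hierarchy below is invented, not MLIR's): `llvm::isa<>` asserts on a null input, while `llvm::isa_and_nonnull<>` simply returns false.

#include "llvm/Support/Casting.h"

struct Attr {
  enum Kind { Blocked, DotOperand } kind;
  explicit Attr(Kind k) : kind(k) {}
};
struct DotOperandAttr : Attr {
  DotOperandAttr() : Attr(DotOperand) {}
  static bool classof(const Attr *a) { return a->kind == DotOperand; }
};

bool hasDotOperandEncoding(const Attr *encoding) {
  // Safe even when `encoding` is null; llvm::isa<> would assert instead.
  return llvm::isa_and_nonnull<DotOperandAttr>(encoding);
}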

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp (+41, -9)

@@ -21,8 +21,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -170,11 +168,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -188,6 +201,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
 }
 
 static bool bwdFilter(Operation *op) {
+  // Assigning dot operand layouts to predicates (i1) is not currently
+  // supported when lowering from TritonGPU to LLVM for MMA cases. This
+  // condition limits the visibility of the original bit-width so that
+  // predicates are not considered; hence, kwidth can never be 32.
+  if (isa<arith::UIToFPOp>(op)) {
+    Type srcType = getElementTypeOrSelf(op->getOperand(0));
+    if (srcType.isInteger(1))
+      return false;
+  }
+
   return op->getNumOperands() == 1 &&
          (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
           isPureUnaryInlineAsm(op) ||
@@ -207,7 +230,7 @@ static bool bwdFilter(Operation *op) {
 // result, kwidth can be the bitwidth of the lower precision primitive.
 // Conversely, in the downcasting scenario, no reordering is performed,
 // making it directly use the lower precision primitive.
-static int computeOrigBitWidth(Value x) {
+int computeOrigBitWidth(Value x) {
   int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
   int origBitWidth = finalBitWidth;
   SetVector<Operation *> slice;
@@ -227,6 +250,9 @@ static int computeOrigBitWidth(Value x) {
   }
   return origBitWidth;
 }
+// Move the anonymous namespace down so that getWarpsPerTile stays visible to
+// the sparsity extension.
+namespace {
 
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
@@ -632,7 +658,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
   NvidiaMmaEncodingAttr mmaLayout =
       dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
   if (mmaLayout) {
-    bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+    bool isNativeFP8 =
+        llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(AElType);
     // promote operands for sm < 89 since fp8 mma is not natively supported
     // promote operands for sm >= 90 when mma is not v3
     if (!isNativeFP8 ||
@@ -1018,6 +1045,11 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
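The commit's own comment explains the namespace move; as a toy reminder of the linkage rule involved (the names below are invented, not from the commit), anything inside an anonymous namespace has internal linkage and cannot be referenced from another translation unit such as the sparsity extension:

namespace {
int hiddenHelper() { return 1; }   // internal linkage: this TU only
} // namespace

int exportedHelper() { return 2; } // external linkage: another TU may
                                   // declare `int exportedHelper();` and call it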

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp (+1)

@@ -285,6 +285,7 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
     if (!foundInitializer)
       return failure();
 
+    rewriter.setInsertionPointAfter(src);
     SmallVector<ConvertLayoutOp> newOperands;
     for (auto operand : src->getOperands()) {
       // We checked earlier that all operands are ranked tensors.

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp (+8, -1)

@@ -132,6 +132,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
 
   Value zero = builder.createWithStage<arith::ConstantIntOp>(
       forOp.getLoc(), stage, clusterId, 0, 32);
+
   // Replace the load with insert/extract slice.
   builder.setInsertionPoint(loadOp);
   Location loc = loadOp.getLoc();
@@ -527,7 +528,8 @@ assignMemoryLayouts(scf::ForOp &forOp,
 
     bool isTMALoad = isa<tt::ExperimentalDescriptorLoadOp,
                          tt::ExperimentalDescriptorGatherOp>(op);
-    loadsToPipeline.insert(&op);
+    // TODO: b/381421713 - Uncomment this once pipelining is fixed.
+    // loadsToPipeline.insert(&op);
     LoadInfo loadInfo;
     for (auto use : users) {
       if (use->hasTrait<OpTrait::DotLike>()) {
@@ -566,6 +568,11 @@ assignMemoryLayouts(scf::ForOp &forOp,
           getBlockedEncoding(loadOp, axisInfoAnalysis);
       }
     }
+
+    // TODO: b/381421713 - Remove this once pipelining is fixed.
+    if (!loadInfo.sharedEncoding) continue;
+    loadsToPipeline.insert(&op);
+
     loadToInfo[&op] = loadInfo;
   }
   // Make sure all loads in loadsToPipeline are in loadToInfo.

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (+1, -1)

@@ -255,7 +255,7 @@ mlir::triton::maybeGetStageCluster(Operation *op) {
 }
 std::pair<int, int> mlir::triton::getStageCluster(Operation *op) {
   auto res = maybeGetStageCluster(op);
-  assert(res.has_value() || "Operation is missing stage & cluster attribute");
+  assert(res.has_value() && "Operation is missing stage & cluster attribute");
   return *res;
 }
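The one-character fix above is easy to gloss over, so here is a standalone sketch (not from the commit) of why it matters: a string literal decays to a non-null pointer, so `cond || "msg"` is always true and the assertion can never fire, while `cond && "msg"` fails exactly when `cond` is false.

#include <cassert>
#include <optional>

int main() {
  std::optional<int> res; // empty, like a missing stage/cluster attribute
  assert(res.has_value() || "always passes; the || form masks the bug");
  assert(res.has_value() && "fires as intended when res is empty");
}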

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp (+24, -2)

@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
   // opIdx: 0 => a, 1 => b
   auto type = cast<triton::gpu::MemDescType>(v.getType());
   SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
-  SmallVector<int64_t> offset{0, 0};
+  SmallVector<int64_t> offset(shape.size(), 0);
   Type elementType = type.getElementType();
 
   // k => (prefetchWidth, k - prefetchWidth)
@@ -141,8 +141,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                       type.getMutableMemory(), type.getAllocShape()),
       v, offsetsVal);
 
+  // We need to assign kwidth to zero in the case where the parent layout is
+  // Blocked, otherwise the verifier emits a failure. The parent layout is
+  // Blocked only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
       newSmem);
@@ -191,6 +197,22 @@ LogicalResult Prefetcher::initialize() {
       break;
     if (!op->getResult(0).hasOneUse())
       break;
+    // Similar to the issues faced by the HoistLayoutConversion pattern in
+    // OptimizeDotOperands.cpp, we can't propagate through type casts from
+    // predicates, as they aren't supported in Triton when encoded with the
+    // dot_op layout.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        break;
+    }
+    // Propagation through ExpandDims is currently not supported: it blindly
+    // replaces the encoding with a dot encoding, but ExpandDims requires a
+    // SliceEncoding. This could be rewritten to support the case, but it
+    // isn't trivial and currently crashes.
+    if (isa<ExpandDimsOp>(op)) {
+      break;
+    }
     rets.push_back(op->getOperand(0));
     if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
       foundConvertFromShared = true;
