Commit 8a371e2

OpenXLA-specific changes
1 parent 5aa4af9 commit 8a371e2

File tree

50 files changed: +3787 -1046 lines


BUILD

+928 (large diff not rendered)

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

+4 -2

@@ -502,15 +502,17 @@ We call each individual tile "rep".
                     "unsigned",
                     "getTotalElemsPerThread",
                     (ins "ArrayRef<int64_t>":$shape),
+                    /*methodBody=*/[{}],
                     /*defaultImplementation=*/[{
-                      return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape);
+                      return toLinearEncoding($_attr, shape).getTotalElemsPerThread(shape);
                     }]>,
     InterfaceMethod<"Return element size per thread in each dimension.",
                     "SmallVector<unsigned>",
                     "getElemsPerThread",
                     (ins "ArrayRef<int64_t>":$shape),
+                    /*methodBody=*/[{}],
                     /*defaultImplementation=*/[{
-                      return toLinearEncoding($_self, shape).getElemsPerThread(shape);
+                      return toLinearEncoding($_attr, shape).getElemsPerThread(shape);
                     }]>,
     // Interface for the meta information about the multiple thread hierarchy.
     InterfaceMethod<"Get the shape of the warps per CTA.",

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

+12

@@ -58,6 +58,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  RankedTensorType tensorType, ValueRange inputs,
                                  Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+    // remaining arguments that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                         inputs);
     llvm_unreachable("Argument rematerialization should not happen in Triton "
                      "-> TritonGPU conversion");
     return {};
@@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // convert origValue to newValue
   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+    // remaining uses of values that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                         inputs);
     llvm_unreachable("Source rematerialization should not happen in Triton -> "
                      "TritonGPU Conversion");
     return {};

lib/Dialect/TritonGPU/IR/Dialect.cpp

+4 -2

@@ -39,12 +39,14 @@ LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape) {
 }
 
 unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
-  return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape);
+  auto distributedEncoding = mlir::cast<DistributedEncodingTrait>(layout);
+  return distributedEncoding.getTotalElemsPerThread(shape);
 }
 
 SmallVector<unsigned> getElemsPerThread(Attribute layout,
                                         ArrayRef<int64_t> shape) {
-  return toLinearEncoding(layout, shape).getElemsPerThread(shape);
+  auto distributedEncoding = mlir::cast<DistributedEncodingTrait>(layout);
+  return distributedEncoding.getElemsPerThread(shape);
 }
 
 SmallVector<unsigned> getElemsPerThread(Type type) {
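Together with the TritonGPUAttrDefs.td change above, these two functions now dispatch through the encoding attribute's own interface implementation (DistributedEncodingTrait) instead of always lowering the layout to a linear encoding first, which lets an out-of-tree encoding such as the sparse-dot metadata encoding added elsewhere in this commit supply its own answer. A rough standalone analogy in plain C++ follows; the class names are illustrative, not Triton's:

// Standalone analogy (plain C++, not the MLIR code): dispatching through the
// encoding's own virtual method lets an extension encoding override the
// default, whereas always converting to a "linear" form would ignore it.
#include <cstdint>
#include <iostream>
#include <vector>

struct Encoding {
  virtual ~Encoding() = default;
  // Default implementation, analogous to defaultImplementation in the .td.
  virtual unsigned getTotalElemsPerThread(const std::vector<int64_t> &shape) const {
    unsigned total = 1;
    for (int64_t d : shape) total *= static_cast<unsigned>(d);
    return total / 32;  // pretend 32 threads share the tensor
  }
};

// An extension encoding (think: sparse-dot metadata) with its own answer.
struct SparseMetaEncoding : Encoding {
  unsigned getTotalElemsPerThread(const std::vector<int64_t> &shape) const override {
    return Encoding::getTotalElemsPerThread(shape) / 2;  // metadata is denser
  }
};

unsigned getTotalElemsPerThread(const Encoding &enc, const std::vector<int64_t> &shape) {
  return enc.getTotalElemsPerThread(shape);  // virtual dispatch, like the trait cast
}

int main() {
  std::vector<int64_t> shape{64, 64};
  Encoding base;
  SparseMetaEncoding sparse;
  std::cout << getTotalElemsPerThread(base, shape) << "\n";    // 128
  std::cout << getTotalElemsPerThread(sparse, shape) << "\n";  // 64
}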

lib/Dialect/TritonGPU/IR/Ops.cpp

+7 -2

@@ -159,6 +159,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();
@@ -221,8 +226,8 @@ struct CanonicalizeConvertFromConvert
   // heuristic to accommodate fused attention.
   auto srcType = op.getSrc().getType();
   auto dstType = op.getType();
-  if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
-      mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+  if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+      mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
     return failure();
 
   Operation *arg = op.getSrc().getDefiningOp();
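The isa to isa_and_nonnull switch matters when a tensor type carries no encoding attribute: getEncoding() then yields a null attribute, and mlir::isa asserts on null while mlir::isa_and_nonnull simply answers false. A minimal standalone analogue of that difference in plain C++ (not the LLVM implementation):

#include <cassert>
#include <iostream>

struct Attr { virtual ~Attr() = default; };
struct DotOperandEncoding : Attr {};
struct BlockedEncoding : Attr {};

// Analogue of isa<T>: requires a non-null value.
template <typename T>
bool isaAttr(const Attr *a) {
  assert(a && "isa<> used on a null pointer");
  return dynamic_cast<const T *>(a) != nullptr;
}

// Analogue of isa_and_nonnull<T>: null simply means "not that kind".
template <typename T>
bool isaAttrAndNonnull(const Attr *a) {
  return a != nullptr && isaAttr<T>(a);
}

int main() {
  DotOperandEncoding dot;
  const Attr *missing = nullptr;  // e.g. a tensor type without an encoding

  std::cout << isaAttrAndNonnull<DotOperandEncoding>(&dot) << "\n";     // 1
  std::cout << isaAttrAndNonnull<DotOperandEncoding>(missing) << "\n";  // 0
  // isaAttr<DotOperandEncoding>(missing) would trip the assert instead.
}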

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

+39 -8

@@ -21,8 +21,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
               const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -181,6 +179,21 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
@@ -204,7 +217,7 @@ getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) {
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation* dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -218,6 +231,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
 }
 
 static bool bwdFilter(Operation *op) {
+  // Dot operand layout assignment to Predicates are not currently supported
+  // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+  // condition limits visibility of the original bit-width so that predicate
+  // are not considered, hence, kwidth can never be = 32.
+  if (isa<arith::UIToFPOp>(op)) {
+    Type srcType = getElementTypeOrSelf(op->getOperand(0));
+    if (srcType.isInteger(1))
+      return false;
+  }
+
   return op->getNumOperands() == 1 &&
          (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
           isPureUnaryInlineAsm(op) ||
@@ -237,7 +260,7 @@ static bool bwdFilter(Operation *op) {
 // result, kwidth can be the bitwidth of the lower precision primitive.
 // Conversely, in the downcasting scenario, no reordering is performed,
 // making it directory use the lower precision primitive.
-static int computeOrigBitWidth(Value x) {
+int computeOrigBitWidth(Value x) {
   int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
   int origBitWidth = finalBitWidth;
   SetVector<Operation *> slice;
@@ -257,6 +280,9 @@ static int computeOrigBitWidth(Value x) {
   }
   return origBitWidth;
 }
+// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity
+// extension.
+namespace {
 
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
@@ -1147,6 +1173,11 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
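Moving the anonymous-namespace brace below computeOrigBitWidth matters for linkage: symbols inside an anonymous namespace have internal linkage and cannot be referenced from another translation unit, so helpers like getWarpsPerTile stay externally visible for the sparsity extension mentioned in the comment. A minimal plain-C++ sketch of that linkage rule (names here are illustrative only):

// helpers.cpp, illustrative only.
#include <iostream>

// External linkage: a different .cpp (e.g. a "sparsity extension") can declare
//   int visibleHelper(int);
// and link against this definition.
int visibleHelper(int x) { return x + 1; }

namespace {
// Internal linkage: only this translation unit can name it; a reference from
// another file would fail at link time.
int hiddenHelper(int x) { return x * 2; }
}  // namespace

int main() {
  std::cout << visibleHelper(1) << " " << hiddenHelper(2) << "\n";  // 2 4
}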

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

+8 -1

@@ -131,6 +131,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
 
   Value zero = builder.createWithStage<arith::ConstantIntOp>(
       forOp.getLoc(), stage, clusterId, 0, 32);
+
   // Replace the load with insert/extract slice.
   builder.setInsertionPoint(loadOp);
   Location loc = loadOp.getLoc();
@@ -524,7 +525,8 @@ assignMemoryLayouts(scf::ForOp &forOp,
 
     bool isTMALoad = isa<tt::ExperimentalDescriptorLoadOp,
                          tt::ExperimentalDescriptorGatherOp>(op);
-    loadsToPipeline.insert(&op);
+    // TODO: b/381421713 - Uncomment this once pipelining is fixed.
+    // loadsToPipeline.insert(&op);
     LoadInfo loadInfo;
     for (auto use : users) {
       if (isa<mlir::triton::DotOpInterface>(use)) {
@@ -562,6 +564,11 @@ assignMemoryLayouts(scf::ForOp &forOp,
             getBlockedEncoding(loadOp, axisInfoAnalysis);
       }
     }
+
+    // TODO: b/381421713 - Remove this once pipelining is fixed.
+    if (!loadInfo.sharedEncoding) continue;
+    loadsToPipeline.insert(&op);
+
     loadToInfo[&op] = loadInfo;
   }
   // Make sure all loads in loadsToPipeline are in loadToInfo.

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

+1 -1

@@ -255,7 +255,7 @@ mlir::triton::maybeGetStageCluster(Operation *op) {
 }
 std::pair<int, int> mlir::triton::getStageCluster(Operation *op) {
   auto res = maybeGetStageCluster(op);
-  assert(res.has_value() || "Operation is missing stage & cluster attribute");
+  assert(res.has_value() && "Operation is missing stage & cluster attribute");
   return *res;
 }
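The one-character change is a real fix, not style: with ||, the string literal decays to a non-null const char*, so the whole condition is always true and the assertion can never fire. A self-contained illustration of the pitfall in plain C++ (not the Triton code):

#include <cassert>
#include <optional>

int main() {
  std::optional<int> res;  // empty, like a missing stage/cluster attribute

  // Always passes: (false || "message") treats the string literal as a
  // non-null pointer, i.e. true, so the check is silently disabled.
  assert(res.has_value() || "Operation is missing stage & cluster attribute");

  // Correct idiom: the message only documents the check; with an empty
  // optional this aborts as intended (compile without -DNDEBUG to see it).
  assert(res.has_value() && "Operation is missing stage & cluster attribute");
  return 0;
}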

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

+24 -2

@@ -121,7 +121,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
   // opIdx: 0 => a, 1 => b
   auto type = cast<triton::gpu::MemDescType>(v.getType());
   SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
-  SmallVector<int64_t> offset{0, 0};
+  SmallVector<int64_t> offset(shape.size(), 0);
   Type elementType = type.getElementType();
 
   // k => (prefetchWidth, k - prefetchWidth)
@@ -146,8 +146,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                        type.getMutableMemory(), type.getAllocShape()),
       v, offsetsVal);
 
+  // We need to assign kwidth to zero in the case where the parent layout is
+  // Blocked, otherwise the verifier emits a failure. The parent layout is
+  // Blocked only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
       newSmem);
@@ -197,6 +203,22 @@ LogicalResult Prefetcher::initialize() {
         break;
       if (!op->getResult(0).hasOneUse())
         break;
+      // Similar to issues faced in HoistLayoutConversion pattern in
+      // OptimizeDotOperands.cpp, we can't propagate through type casts from
+      // predicates as they aren't supported in Triton when encoded with dot_op
+      // layout.
+      if (isa<arith::UIToFPOp>(op)) {
+        Type srcType = getElementTypeOrSelf(op->getOperand(0));
+        if (srcType.isInteger(1))
+          break;
+      }
+      // Propagation through ExpandDims is currently not supported. This blindly
+      // replaces the encoding with dot encoding & but ExpandDims requires a
+      // SliceEncoding. This could be rewritten to support it somehow, but I
+      // don't think it's trivial & it's currently crashing.
+      if (isa<ExpandDimsOp>(op)) {
+        break;
+      }
       rets.push_back(op->getOperand(0));
       if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
         // NYI for other encodings, for example if we have transpose
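The offset change in the first hunk is a rank fix: brace initialization {0, 0} always produces a two-element vector, while the (count, value) constructor produces one zero per dimension of the memdesc, so rank-3 (batched) operands get a correctly sized offset list. llvm::SmallVector provides the same two constructor forms as the standard containers used in this plain-C++ sketch:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::size_t rank = 3;  // e.g. a batched operand: [batch, M, K]

  std::vector<long> fixedOffset{0, 0};      // list-init: always 2 elements
  std::vector<long> rankedOffset(rank, 0);  // (count, value): one 0 per dim

  std::cout << fixedOffset.size() << "\n";   // 2
  std::cout << rankedOffset.size() << "\n";  // 3
  return 0;
}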

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

+28 -5

@@ -165,6 +165,7 @@ class LayoutRematerialization {
   SetVector<Operation *> opToDelete;
   FuncOp funcOp;
   DominanceInfo domInfo;
+  PostDominanceInfo postDomInfo;
 };
 
 void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
@@ -1120,13 +1121,35 @@ void LayoutRematerialization::hoistConvertDotOperand(
     ConvertLayoutOp convertOp) {
   auto targetType = convertOp.getType();
   // The pass is targeted to Nvidia mma/wgmma dot operands
+
+  // Partial cherry-pick of https://github.com/triton-lang/triton/pull/5475.
+  // Path 2 in b/391692127#comment28. Added check for parent being a for loop.
+  auto canBePipelined = [&](ConvertLayoutOp convertOp) {
+    auto parent = dyn_cast<scf::ForOp>(convertOp->getParentOp());
+    if (!parent) return false;
+
+    // Find all the dot-like ops in the for loop that have a nvidia dot operand
+    // encoding on the lhs and check if any of them post-dominates the load +
+    // cvt
+    SmallVector<Operation *> dotLikeOps;
+    parent->walk([&](Operation *op) {
+      if (!isa<mlir::triton::DotOpInterface>(op)) return;
+      auto opType = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+      if (!opType) return;
+      auto dotEnc = dyn_cast<DotOperandEncodingAttr>(opType.getEncoding());
+      if (!dotEnc) return;
+      if (isa<NvidiaMmaEncodingAttr>(dotEnc.getParent()))
+        dotLikeOps.push_back(op);
+    });
+    if (dotLikeOps.empty()) return false;
+    return llvm::any_of(dotLikeOps, [&](Operation *dot) {
+      return postDomInfo.postDominates(dot, convertOp);
+    });
+  };
+
   // We move convert #dot_operand next to their loads. This is done
   // so that it's then easy to pipeline these loads
-  // TODO: Perhaps we should do this whenever convertOp is within a loop
-
-  auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
-  if (!(dotEnc && isa<NvidiaMmaEncodingAttr>(dotEnc.getParent())))
-    return;
+  if (!canBePipelined(convertOp)) return;
 
   // We hoist over any operation that can be done without data movement between
   // threads We do views and elementwise pure ops for now
lib/Dialect/TritonGPU/Transforms/Utility.cpp

+19 -11

@@ -1022,18 +1022,26 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
     } else {
       if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
         return std::nullopt;
-      auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
-          cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
-              .getEncoding());
-      if (!dotOpEnc)
+      auto enc =
+          cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType()).getEncoding();
+      if (isa<ttg::DotOperandEncodingAttr>(enc)) {
+        auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
+        auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
+        auto order = ttg::getOrder(srcTy.getEncoding());
+        unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+        tempAttr = ttg::SwizzledSharedEncodingAttr::get(
+            val.getContext(), cast<ttg::DotOperandEncodingAttr>(enc),
+            srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false);
+      } else if (enc.getAbstractAttribute().getName().str() ==
+                 "triton.gpu.sparse_dot_meta_encoding") {
+        auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
+        tempAttr = ttg::SwizzledSharedEncodingAttr::get(
+            val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1,
+            ttg::getOrder(srcTy.getEncoding()),
+            ttg::getCTALayout(srcTy.getEncoding()));
+      } else {
         return std::nullopt;
-      auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
-      auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
-      auto order = ttg::getOrder(srcTy.getEncoding());
-      unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
-      tempAttr = ttg::SwizzledSharedEncodingAttr::get(
-          val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
-          bitWidth, /*needTrans=*/false);
+      }
     }
     // Check that the shared encodings needed by the users are compatible.
     if (attr != nullptr && attr != tempAttr) {