Skip to content

Commit ec7dc0a

Browse files
committed
OpenXLA-specific changes
1 parent 48b4b60 commit ec7dc0a

File tree

43 files changed

+3539
-931
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+3539
-931
lines changed

BUILD

+854
Large diffs are not rendered by default.

lib/Analysis/Allocation.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
123123

124124
std::tie(scratchConfig.inVec, scratchConfig.outVec) =
125125
getScratchCvtInOutVecLengths(srcTy, dstTy);
126+
// We can't write a longer vector than the shape of shared memory.
127+
// This shape might be smaller than the tensor shape in case we decided to
128+
// do the conversion in multiple iterations.
129+
unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
130+
scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
131+
scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
126132

127133
// No padding is required if the tensor is 1-D, or if all dimensions except
128134
// the first accessed dimension have a size of 1.

lib/Analysis/AxisInfo.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
935935
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
936936
lhsDivisibility = 1;
937937
}
938-
return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
938+
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
939939
}
940940

941941
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
5757
addArgumentMaterialization([&](OpBuilder &builder,
5858
RankedTensorType tensorType, ValueRange inputs,
5959
Location loc) -> Value {
60+
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
61+
// remaining arguments that have been converted to a new type.
62+
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
63+
// 'convert-triton-to-tritongpu'.
64+
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
65+
inputs);
6066
llvm_unreachable("Argument rematerialization should not happen in Triton "
6167
"-> TritonGPU conversion");
6268
return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
6672
// convert origValue to newValue
6773
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
6874
ValueRange inputs, Location loc) -> Value {
75+
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
76+
// remaining uses of values that have been converted to a new type.
77+
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
78+
// 'convert-triton-to-tritongpu'.
79+
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
80+
inputs);
6981
llvm_unreachable("Source rematerialization should not happen in Triton -> "
7082
"TritonGPU Conversion");
7183
return {};

lib/Dialect/TritonGPU/IR/Ops.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ struct CanonicalizeConvertFromAlloc
150150
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
151151
if (!convert)
152152
return failure();
153+
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
154+
// to SharedEncoding, so we want to keep this layout conversion.
155+
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
156+
convert.getSrc().getType().getEncoding()))
157+
return failure();
153158
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
154159
op, op->getResult(0).getType(), convert.getSrc());
155160
return mlir::success();
@@ -212,13 +217,13 @@ struct CanonicalizeConvertFromConvert
212217
// heuristic to accommodate fused attention.
213218
auto srcType = op.getSrc().getType();
214219
auto dstType = op.getType();
215-
if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
216-
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
220+
if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
221+
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
217222
return failure();
218223

219224
// for hopper MMAv3
220-
if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
221-
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
225+
if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
226+
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
222227
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
223228
return dot->hasTrait<OpTrait::DotLike>();
224229
})) {

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

+44-8
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,13 @@ namespace mlir {
2121
namespace triton {
2222
namespace gpu {
2323

24-
namespace {
25-
2624
// Get the highest version supported for the hardware and the dot.
2725
static int getMMAVersionSafe(int computeCapability, DotOp op) {
2826
// List supported mma version in order of preference.
2927
SmallVector<int> versionsSupported;
3028
if (computeCapability < 75) {
3129
versionsSupported = {1};
32-
} else if (computeCapability < 90) {
30+
} else if (computeCapability < 90 || computeCapability >= 100) {
3331
versionsSupported = {2};
3432
} else if (computeCapability < 100) {
3533
versionsSupported = {3, 2};
@@ -45,8 +43,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
4543
return 0;
4644
}
4745

48-
SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
49-
int numWarps) {
46+
SmallVector<unsigned>
47+
warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
5048
auto rank = shape.size();
5149
// Early exit for batched matmul
5250
if (rank == 3)
@@ -110,10 +108,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
110108
}
111109

112110
SmallVector<unsigned, 2>
113-
warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
111+
warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
114112
const SmallVector<unsigned, 3> &instrShape) {
115113
SetVector<Operation *> slices;
116-
mlir::getForwardSlice(dotOp.getResult(), &slices);
114+
mlir::getForwardSlice(dotOp->getResult(0), &slices);
117115
// Contains a chained dot. We prefer to assign warps to one axis
118116
// to facilitate use cases like flash attention, allowing reductions within
119117
// the same warp.
@@ -168,11 +166,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
168166
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
169167
newLayout, SharedMemorySpace);
170168
rewriter.setInsertionPointAfterValue(arg);
169+
170+
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
171+
// to SharedEncoding.
172+
if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
173+
argType.getEncoding())) {
174+
// Create a layout conversion from DotOperandEncoding to BlockedEncoding
175+
// then pass it to the LocalAllocOp.
176+
auto newArgType = RankedTensorType::get(
177+
argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
178+
auto dotOperandToBlockedCvt =
179+
rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
180+
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
181+
dotOperandToBlockedCvt);
182+
}
183+
171184
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
172185
}
173186

174187
SmallVector<unsigned, 3>
175-
getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
188+
getWarpsPerTile(Operation* dotOp, const ArrayRef<int64_t> shape, int version,
176189
int numWarps, const SmallVector<unsigned, 3> &instrShape) {
177190
switch (version) {
178191
case 2:
@@ -185,18 +198,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
185198
}
186199
}
187200

201+
// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity
202+
// extension.
203+
namespace {
204+
188205
class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
189206
int computeCapability;
190207
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
191208

192209
static bool bwdFilter(Operation *op) {
210+
// Dot operand layout assignment to Predicates are not currently supported
211+
// during lowering from TritonGPU to LLVM in Triton for MMA cases. This
212+
// condition limits visibility of the original bit-width so that predicate
213+
// are not considered, hence, kwidth can never be = 32.
214+
if (isa<arith::UIToFPOp>(op)) {
215+
Type srcType = getElementTypeOrSelf(op->getOperand(0));
216+
if (srcType.isInteger(1))
217+
return false;
218+
}
193219
return op->getNumOperands() == 1 &&
194220
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
195221
isPureUnaryInlineAsm(op) ||
196222
op->getDialect()->getTypeID() ==
197223
mlir::TypeID::get<arith::ArithDialect>());
198224
}
199225

226+
public:
200227
// Finds the first different bitwidth in the chain of shape-preserving
201228
// unary ops that x depends on.
202229
// There are two primary scenarios:
@@ -720,6 +747,15 @@ class TritonGPUAccelerateMatmulPass
720747
}
721748
};
722749

750+
// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
751+
int computeOrigBitWidth(Value x) {
752+
return BlockedToMMA::computeOrigBitWidth(x);
753+
}
754+
Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
755+
int opIdx, bool allowTranspose) {
756+
return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
757+
}
758+
723759
} // namespace gpu
724760
} // namespace triton
725761
} // namespace mlir

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
284284
if (!foundInitializer)
285285
return failure();
286286

287+
rewriter.setInsertionPointAfter(src);
287288
SmallVector<ConvertLayoutOp> newOperands;
288289
for (auto operand : src->getOperands()) {
289290
// We checked earlier that all operands are ranked tensors.

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
131131

132132
Value zero = builder.createWithStage<arith::ConstantIntOp>(
133133
forOp.getLoc(), stage, clusterId, 0, 32);
134+
134135
// Replace the load with insert/extract slice.
135136
builder.setInsertionPoint(loadOp);
136137
Location loc = loadOp.getLoc();
@@ -491,7 +492,8 @@ assignMemoryLayouts(scf::ForOp &forOp,
491492
});
492493

493494
bool isTMALoad = isa<tt::ExperimentalDescriptorLoadOp>(op);
494-
loadsToPipeline.insert(&op);
495+
// TODO: b/381421713 - Uncomment this once pipelining is fixed.
496+
// loadsToPipeline.insert(&op);
495497
LoadInfo loadInfo;
496498
for (auto use : users) {
497499
if (use->hasTrait<OpTrait::DotLike>()) {
@@ -527,6 +529,11 @@ assignMemoryLayouts(scf::ForOp &forOp,
527529
getBlockedEncoding(loadOp, axisInfoAnalysis);
528530
}
529531
}
532+
533+
// TODO: b/381421713 - Remove this once pipelining is fixed.
534+
if (!loadInfo.sharedEncoding) continue;
535+
loadsToPipeline.insert(&op);
536+
530537
loadToInfo[&op] = loadInfo;
531538
}
532539
// Make sure all loads in loadsToPipeline are in loadToInfo.

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

+24-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
116116
// opIdx: 0 => a, 1 => b
117117
auto type = cast<triton::gpu::MemDescType>(v.getType());
118118
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
119-
SmallVector<int64_t> offset{0, 0};
119+
SmallVector<int64_t> offset(shape.size(), 0);
120120
Type elementType = type.getElementType();
121121

122122
// k => (prefetchWidth, k - prefetchWidth)
@@ -141,8 +141,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
141141
type.getMutableMemory(), type.getAllocShape()),
142142
v, offsetsVal);
143143

144+
// We need to assign kwidth to zero in the case where the parent layout is
145+
// Blocked, otherwise the verifier emits a failure. The parent layout is
146+
// Blocked only when Tensor Cores are disabled.
147+
int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
148+
? 0
149+
: prefetchWidth / 8;
144150
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
145-
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
151+
builder.getContext(), opIdx, dotEncoding, kwidth);
146152
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
147153
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
148154
newSmem);
@@ -191,6 +197,22 @@ LogicalResult Prefetcher::initialize() {
191197
break;
192198
if (!op->getResult(0).hasOneUse())
193199
break;
200+
// Similar to issues faced in HoistLayoutConversion pattern in
201+
// OptimizeDotOperands.cpp, we can't propagate through type casts from
202+
// predicates as they aren't supported in Triton when encoded with dot_op
203+
// layout.
204+
if (isa<arith::UIToFPOp>(op)) {
205+
Type srcType = getElementTypeOrSelf(op->getOperand(0));
206+
if (srcType.isInteger(1))
207+
break;
208+
}
209+
// Propagation through ExpandDims is currently not supported. This blindly
210+
// replaces the encoding with dot encoding & but ExpandDims requires a
211+
// SliceEncoding. This could be rewritten to support it somehow, but I
212+
// don't think it's trivial & it's currently crashing.
213+
if (isa<ExpandDimsOp>(op)) {
214+
break;
215+
}
194216
rets.push_back(op->getOperand(0));
195217
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
196218
foundConvertFromShared = true;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

+19-11
Original file line numberDiff line numberDiff line change
@@ -994,18 +994,26 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
994994
} else {
995995
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
996996
return std::nullopt;
997-
auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
998-
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
999-
.getEncoding());
1000-
if (!dotOpEnc)
997+
auto enc =
998+
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType()).getEncoding();
999+
if (isa<ttg::DotOperandEncodingAttr>(enc)) {
1000+
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
1001+
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
1002+
auto order = ttg::getOrder(srcTy.getEncoding());
1003+
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
1004+
tempAttr = ttg::SharedEncodingAttr::get(
1005+
val.getContext(), cast<ttg::DotOperandEncodingAttr>(enc),
1006+
srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false);
1007+
} else if (enc.getAbstractAttribute().getName().str() ==
1008+
"triton.gpu.sparse_dot_meta_encoding") {
1009+
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
1010+
tempAttr = ttg::SharedEncodingAttr::get(
1011+
val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1,
1012+
ttg::getOrder(srcTy.getEncoding()),
1013+
ttg::getCTALayout(srcTy.getEncoding()));
1014+
} else {
10011015
return std::nullopt;
1002-
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
1003-
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
1004-
auto order = ttg::getOrder(srcTy.getEncoding());
1005-
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
1006-
tempAttr = ttg::SharedEncodingAttr::get(
1007-
val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
1008-
bitWidth, /*needTrans=*/false);
1016+
}
10091017
}
10101018
// Check that the shared encodings needed by the users are compatible.
10111019
if (attr != nullptr && attr != tempAttr) {

lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ struct FenceInsertionPass
4444
return;
4545
ModuleOp mod = getOperation();
4646
mod.walk([&](Operation *op) {
47-
if (!isa<ttng::WarpGroupDotOp>(op))
47+
if (!isa<ttng::WarpGroupDotOp>(op) &&
48+
op->getName().getStringRef() != "triton_xla.sparse_dot")
4849
return WalkResult::advance();
4950
OpBuilder builder(op);
5051
auto a = op->getOperand(0);

0 commit comments

Comments
 (0)