Skip to content

Commit c8f89a6

Browse files
committed
OpenXLA-specific changes
1 parent a33e3ac commit c8f89a6

File tree

37 files changed

+2404
-64
lines changed

37 files changed

+2404
-64
lines changed

BUILD

+908
Large diffs are not rendered by default.

lib/Analysis/AxisInfo.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
937937
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
938938
lhsDivisibility = 1;
939939
}
940-
return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
940+
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
941941
}
942942

943943
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
@@ -1084,8 +1084,8 @@ LogicalResult AxisInfoAnalysis::visitOperation(
10841084

10851085
void AxisInfoAnalysis::visitForOpInductionVar(
10861086
scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
1087-
auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue();
1088-
auto step = getLatticeElementFor(op, op.getStep())->getValue();
1087+
auto lb = getLatticeElementFor(getProgramPointAfter(op), op.getLowerBound())->getValue();
1088+
auto step = getLatticeElementFor(getProgramPointAfter(op), op.getStep())->getValue();
10891089

10901090
AxisInfo::DimVectorT knownContiguity(1, 1);
10911091
AxisInfo::DimVectorT knownDivisibility(1, 1);

lib/Analysis/Utility.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -896,15 +896,15 @@ class ConstantAnalysis : public DataFlowAnalysis {
896896

897897
LogicalResult initialize(Operation *top) override {
898898
WalkResult result = top->walk([&](Operation *op) {
899-
if (failed(visit(op)))
899+
if (failed(visit(getProgramPointAfter(op))))
900900
return WalkResult::interrupt();
901901
return WalkResult::advance();
902902
});
903903
return success(!result.wasInterrupted());
904904
}
905905

906-
LogicalResult visit(ProgramPoint point) override {
907-
Operation *op = point.get<Operation *>();
906+
LogicalResult visit(ProgramPoint* point) override {
907+
Operation *op = point->getPrevOp();
908908
Attribute value;
909909
if (matchPattern(op, m_Constant(&value))) {
910910
auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
4040
auto ouEltTy = ouTensorTy.getElementType();
4141
if (inBitWidth == ouBitWidth)
4242
return values;
43-
if (inBitWidth == 16 && ouBitWidth == 32) {
43+
if ((inBitWidth == 16 && ouBitWidth == 32) ||
44+
(inBitWidth == 32 && ouBitWidth == 16)) {
4445
SmallVector<Value> ret;
4546
for (unsigned i = 0; i < values.size(); i += 8) {
4647
ret.push_back(values[i]);
@@ -54,7 +55,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
5455
}
5556
return ret;
5657
}
57-
if (inBitWidth == 8 && ouBitWidth == 16) {
58+
if ((inBitWidth == 8 && ouBitWidth == 16) ||
59+
(inBitWidth == 16 && ouBitWidth == 8)) {
5860
SmallVector<Value> ret;
5961
for (unsigned i = 0; i < values.size(); i += 16) {
6062
ret.push_back(values[i + 0]);

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,20 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
5656
// This will create newArg, and map(origArg, newArg)
5757
addArgumentMaterialization([&](OpBuilder &builder,
5858
RankedTensorType tensorType, ValueRange inputs,
59-
Location loc) -> std::optional<Value> {
59+
Location loc) -> Value {
6060
llvm_unreachable("Argument rematerialization should not happen in Triton "
6161
"-> TritonGPU conversion");
62-
return std::nullopt;
62+
return Value();
6363
});
6464

6565
// If the origValue still has live user(s), use this to
6666
// convert origValue to newValue
6767
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
6868
ValueRange inputs,
69-
Location loc) -> std::optional<Value> {
69+
Location loc) -> Value {
7070
llvm_unreachable("Source rematerialization should not happen in Triton -> "
7171
"TritonGPU Conversion");
72-
return std::nullopt;
72+
return Value();
7373
});
7474

7575
// This will be called when (desiredType != newOperandType)
@@ -79,7 +79,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
7979
ValueRange inputs, Location loc) {
8080
auto cast =
8181
builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
82-
return std::optional<Value>(cast.getResult());
82+
return Value(cast.getResult());
8383
});
8484
}
8585

lib/Dialect/TritonGPU/IR/Dialect.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -2770,6 +2770,11 @@ struct CanonicalizeConvertFromAlloc
27702770
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
27712771
if (!convert)
27722772
return failure();
2773+
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
2774+
// to SharedEncoding, so we want to keep this layout conversion.
2775+
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
2776+
convert.getSrc().getType().getEncoding()))
2777+
return failure();
27732778
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
27742779
op, op->getResult(0).getType(), convert.getSrc());
27752780
return mlir::success();

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
153153
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
154154
newLayout, SharedMemorySpace);
155155
rewriter.setInsertionPointAfterValue(arg);
156+
157+
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
158+
// to SharedEncoding.
159+
if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
160+
argType.getEncoding())) {
161+
// Create a layout conversion from DotOperandEncoding to BlockedEncoding
162+
// then pass it to the LocalAllocOp.
163+
auto newArgType = RankedTensorType::get(
164+
argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
165+
auto dotOperandToBlockedCvt =
166+
rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
167+
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
168+
dotOperandToBlockedCvt);
169+
}
170+
156171
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
157172
}
158173

@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
162177
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
163178

164179
static bool bwdFilter(Operation *op) {
180+
// Dot operand layout assignment to Predicates are not currently supported
181+
// during lowering from TritonGPU to LLVM in Triton for MMA cases. This
182+
// condition limits visibility of the original bit-width so that predicate
183+
// are not considered, hence, kwidth can never be = 32.
184+
if (isa<arith::UIToFPOp>(op)) {
185+
Type srcType = getElementTypeOrSelf(op->getOperand(0));
186+
if (srcType.isInteger(1))
187+
return false;
188+
}
165189
return op->getNumOperands() == 1 &&
166190
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
167191
isPureUnaryInlineAsm(op) ||

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
111111
PatternRewriter &rewriter) const override {
112112
// Only consider conversions to dot operand.
113113
auto cvtTy = cast<RankedTensorType>(cvt.getType());
114-
if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
114+
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
115+
if (!dotOpEnc)
115116
return failure();
116117

117118
auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
126127
[](Type ty) { return isa<RankedTensorType>(ty); }))
127128
return failure();
128129

130+
// Quick handling to fix loading issues when computing the original
131+
// bitwidth is unable to realize that there is a mixed-precision dot
132+
// (hence kWidth = 1) but wants to hoist through the type conversion.
133+
if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
134+
return failure();
135+
129136
// Only consider custom conversions or arith ops.
130137
// TODO(jlebar): Is this too restrictive?
131138
if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
138145
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
139146
return failure();
140147

148+
// Don't hoist through u1 -> fp casts as they aren't supported in
149+
// ElementwiseOpToLLVM::reorderValues().
150+
if (isa<arith::UIToFPOp>(src)) {
151+
Type srcType = getElementTypeOrSelf(src->getOperand(0));
152+
if (srcType.isInteger(1))
153+
return failure();
154+
}
155+
141156
// Check that the conversion is transitively dependent on a load, and all
142157
// operations between the load and the conversion are layout preserving.
143158
//

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

+24-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
116116
// opIdx: 0 => a, 1 => b
117117
auto type = cast<triton::MemDescType>(v.getType());
118118
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
119-
SmallVector<int64_t> offset{0, 0};
119+
SmallVector<int64_t> offset(shape.size(), 0);
120120
Type elementType = type.getElementType();
121121

122122
// k => (prefetchWidth, k - prefetchWidth)
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
140140
type.getMemorySpace()),
141141
v, offsetsVal);
142142

143+
// We need to assign kwidth to zero in the case where the parent layout is
144+
// Blocked, otherwise the verifier emits a failure. The parent layout is
145+
// Blocked only when Tensor Cores are disabled.
146+
int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
147+
? 0
148+
: prefetchWidth / 8;
143149
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
144-
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
150+
builder.getContext(), opIdx, dotEncoding, kwidth);
145151
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
146152
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
147153
newSmem);
@@ -190,6 +196,22 @@ LogicalResult Prefetcher::initialize() {
190196
break;
191197
if (!op->getResult(0).hasOneUse())
192198
break;
199+
// Similar to issues faced in HoistLayoutConversion pattern in
200+
// OptimizeDotOperands.cpp, we can't propagate through type casts from
201+
// predicates as they aren't supported in Triton when encoded with dot_op
202+
// layout.
203+
if (isa<arith::UIToFPOp>(op)) {
204+
Type srcType = getElementTypeOrSelf(op->getOperand(0));
205+
if (srcType.isInteger(1))
206+
break;
207+
}
208+
// Propagation through ExpandDims is currently not supported. This blindly
209+
// replaces the encoding with dot encoding & but ExpandDims requires a
210+
// SliceEncoding. This could be rewritten to support it somehow, but I
211+
// don't think it's trivial & it's currently crashing.
212+
if (isa<ExpandDimsOp>(op)) {
213+
break;
214+
}
193215
rets.push_back(op->getOperand(0));
194216
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
195217
foundConvertFromShared = true;

lib/Target/LLVMIR/LLVMDIScope.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ struct LLVMDIScopePass : public LLVMDIScopeBase<LLVMDIScopePass> {
104104
auto subprogramAttr = LLVM::DISubprogramAttr::get(
105105
context, distinctId, compileUnitAttr, fileAttr, funcNameAttr,
106106
funcNameAttr, fileAttr, /*line=*/line, /*scopeline=*/line,
107-
subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{});
107+
subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{},
108+
/*annotations=*/{});
108109
funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr));
109110
}
110111

python/BUILD

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# NOTE: Do not depend on any targets from this directory,
2+
# but use //third_party/py/triton instead.
3+
4+
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
5+
6+
package(
7+
default_applicable_licenses = ["//:license"],
8+
default_visibility = [
9+
"//third_party/py/triton:__pkg__",
10+
"@triton//python:__subpackages__",
11+
],
12+
)
13+
14+
cc_library(
15+
name = "passes",
16+
hdrs = ["src/passes.h"],
17+
includes = ["src"],
18+
visibility = ["@triton//third_party:__subpackages__"],
19+
)
20+
21+
pybind_extension(
22+
name = "libtriton",
23+
srcs = [
24+
"src/interpreter.cc",
25+
"src/ir.cc",
26+
"src/llvm.cc",
27+
"src/main.cc",
28+
"src/passes.cc",
29+
],
30+
copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"],
31+
deps = [
32+
":passes",
33+
"@llvm-project//llvm:Core",
34+
"@llvm-project//llvm:IPO",
35+
"@llvm-project//llvm:IRReader",
36+
"@llvm-project//llvm:InstCombine",
37+
"@llvm-project//llvm:Linker",
38+
"@llvm-project//llvm:MC",
39+
"@llvm-project//llvm:Passes",
40+
"@llvm-project//llvm:Support",
41+
"@llvm-project//llvm:Target",
42+
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
43+
"@llvm-project//mlir:BytecodeWriter",
44+
"@llvm-project//mlir:ControlFlowDialect",
45+
"@llvm-project//mlir:ConversionPasses",
46+
"@llvm-project//mlir:IR",
47+
"@llvm-project//mlir:IndexDialect",
48+
"@llvm-project//mlir:LLVMDialect",
49+
"@llvm-project//mlir:LLVMIRTransforms",
50+
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
51+
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
52+
"@llvm-project//mlir:Parser",
53+
"@llvm-project//mlir:Pass",
54+
"@llvm-project//mlir:Support",
55+
"@llvm-project//mlir:ToLLVMIRTranslation",
56+
"@llvm-project//mlir:Transforms",
57+
"//:TritonAnalysis",
58+
"//:TritonDialects",
59+
"//:TritonGPUToLLVM",
60+
"//:TritonGPUTransforms",
61+
"//:TritonHSACO",
62+
"//:TritonLLVMIR",
63+
"//:TritonNvidiaGPUTransforms",
64+
"//:TritonPTX",
65+
"//:TritonToTritonGPU",
66+
"//:TritonTools",
67+
"//:TritonTransforms",
68+
"@triton//third_party/nvidia:triton_nvidia",
69+
],
70+
)
71+
72+
filegroup(
73+
name = "files",
74+
srcs = glob(
75+
include = ["triton/**/*.py"],
76+
),
77+
)

python/test/regression/BUILD

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")
2+
3+
package(
4+
default_applicable_licenses = ["//:license"],
5+
)
6+
7+
pytest_multi_tests(
8+
name = "tests",
9+
size = "large",
10+
srcs = ["conftest.py"],
11+
shard_count = 10,
12+
tags = [
13+
"config-cuda-only",
14+
"requires-gpu-sm80",
15+
],
16+
tests = glob(
17+
include = ["test_*.py"],
18+
exclude = [
19+
"test_performance.py", #TODO(b/321005767): fix failing test
20+
],
21+
),
22+
deps = [
23+
"//third_party/py/torch:pytorch",
24+
"//third_party/py/triton",
25+
],
26+
)

python/test/regression/conftest.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# content of conftest.py
2+
3+
import pytest
4+
5+
6+
def pytest_addoption(parser):
7+
parser.addoption("--device", action="store", default='cuda')
8+
9+
10+
@pytest.fixture
11+
def device(request):
12+
return request.config.getoption("--device")

0 commit comments

Comments
 (0)