[triton-raise-block-pointer]: Introduce env. variable to ignore masked load/stores #3416

Draft
wants to merge 8 commits into base: main
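A rough usage sketch, inferred from the changes below (the tutorial script path is illustrative and not part of this PR): TRITON_INTEL_RAISE_BLOCK_POINTER=1 enables the raise-block-pointer pass as before, while the new value ignore-masks also enables the pass and additionally drops masks from tt.load/tt.store before raising.

import os
import subprocess

# Run a tutorial with the new mode enabled. The environment variable is
# read in third_party/intel/backend/compiler.py when the kernel compiles.
env = os.environ.copy()
env["TRITON_INTEL_RAISE_BLOCK_POINTER"] = "ignore-masks"

# The script path is an assumption for illustration; any Triton kernel
# launch in this environment picks up the variable.
subprocess.run(["python", "python/tutorials/03-matrix-multiplication.py"],
               env=env, check=True)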
9 changes: 0 additions & 9 deletions scripts/pytest-utils.sh
@@ -56,11 +56,6 @@ pytest() {
run_tutorial_test() {
echo
echo "****** Running $1 test ******"
if [[ $TRITON_INTEL_RAISE_BLOCK_POINTER = true ]]; then
echo "****** With: INTEL_RAISE_BLOCK_POINTER ******"
PREV_TRITON_TEST_REPORTS=$TRITON_TEST_REPORTS
TRITON_TEST_REPORTS=false
fi
echo

TUTORIAL_RESULT=TODO
@@ -90,10 +85,6 @@ run_tutorial_test() {
echo $TUTORIAL_RESULT > "$TRITON_TEST_REPORTS_DIR/tutorial-$1.txt"
fi

if [[ $TRITON_INTEL_RAISE_BLOCK_POINTER = true ]]; then
TRITON_TEST_REPORTS=$PREV_TRITON_TEST_REPORTS
fi

if [[ $TUTORIAL_RESULT = FAIL && $TRITON_TEST_IGNORE_ERRORS = false ]]; then
exit 1
fi
9 changes: 7 additions & 2 deletions scripts/test-triton.sh
@@ -244,15 +244,20 @@ run_tutorial_tests() {
run_tutorial_test "01-vector-add"
run_tutorial_test "02-fused-softmax"
run_tutorial_test "03-matrix-multiplication"
TRITON_INTEL_RAISE_BLOCK_POINTER=true \
run_tutorial_test "03-matrix-multiplication" "TRITON_INTEL_RAISE_BLOCK_POINTER"
run_tutorial_test "04-low-memory-dropout"
run_tutorial_test "05-layer-norm"
run_tutorial_test "06-fused-attention"
run_tutorial_test "07-extern-functions"
run_tutorial_test "08-grouped-gemm"
run_tutorial_test "10-experimental-block-pointer"
run_tutorial_test "10i-experimental-block-pointer"

echo "\n***************************************************"
echo "Running with TRITON_INTEL_RAISE_BLOCK_POINTER=ignore-masks"
echo "***************************************************"

TRITON_TEST_REPORTS=false TRITON_INTEL_RAISE_BLOCK_POINTER=ignore-masks \
run_tutorial_test "03-matrix-multiplication"
}

run_microbench_tests() {
10 changes: 3 additions & 7 deletions test/Triton/Intel/RaiseToBlockPointers/addptr_cmpge.mlir
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -triton-raise-block-pointer --split-input-file -canonicalize | FileCheck %s
// RUN: triton-opt %s -triton-raise-block-pointer=ignore-masks=true --split-input-file -canonicalize | FileCheck %s

// These tests check that loads/stores that exhibit a cmp ge against 0 work
// correctly with the pointer analysis pass
@@ -37,9 +37,7 @@ tt.func public @test_masked_load(%arg0: !tt.ptr<f16>) -> tensor<16x16xf16> {
%13 = tt.expand_dims %12 {axis = 0 : i32} : tensor<16xi64> -> tensor<1x16xi64>
%14 = arith.cmpi sge, %13, %cst : tensor<1x16xi64>
%15 = tt.broadcast %14 : tensor<1x16xi1> -> tensor<16x16xi1>
%16 = tt.load %8 evictionPolicy = evict_last : tensor<16x16x!tt.ptr<f16>>
// TODO: Replace above with below once support for masked loads is complete.
// %16 = tt.load %8, %15 evictionPolicy = evict_last : tensor<16x16x!tt.ptr<f16>>
%16 = tt.load %8, %15 evictionPolicy = evict_last : tensor<16x16x!tt.ptr<f16>>
tt.return %16 : tensor<16x16xf16>
}

@@ -65,9 +63,7 @@ tt.func public @test_masked_store(%arg0: !tt.ptr<f16>) {
%5 = tt.addptr %0, %4 : tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi64>
%6 = arith.cmpi sge, %3, %cst : tensor<16x1xi64>
%7 = tt.broadcast %6 : tensor<16x1xi1> -> tensor<16x16xi1>
// TODO: Replace above with below once support for masked stores is complete.
// tt.store %5, %cst_0, %7 : tensor<16x16x!tt.ptr<f16>>
tt.store %5, %cst_0 : tensor<16x16x!tt.ptr<f16>>
tt.store %5, %cst_0, %7 : tensor<16x16x!tt.ptr<f16>>
tt.return
}

@@ -1,6 +1,5 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s

// TODO: expand this example to 3D
Contributor Author: This TODO is meaningless.

module {
tt.func @kernel(
%arg0 : !tt.ptr<bf16>,
14 changes: 4 additions & 10 deletions test/Triton/Intel/RaiseToBlockPointers/kernel-01-vector-add.mlir
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
// RUN: triton-opt %s -triton-raise-block-pointer=ignore-masks=true -canonicalize | FileCheck %s

module {
tt.func public @add_kernel_01234(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) {
@@ -12,20 +12,14 @@ module {
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// TODO: add back once masked loads are supported
// %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
%9 = tt.load %8 : tensor<1024x!tt.ptr<f32>>
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// TODO: add back once masked loads are supported
// %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
%12 = tt.load %11 : tensor<1024x!tt.ptr<f32>>
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
%13 = arith.addf %9, %12 : tensor<1024xf32>
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// TODO: add back once masked stores are supported
// tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
tt.store %15, %13 : tensor<1024x!tt.ptr<f32>>
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
tt.return
}
}
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
// RUN: triton-opt %s -triton-raise-block-pointer=ignore-masks=true -canonicalize | FileCheck %s

module {
tt.func public @softmax_kernel_012345(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32) {
@@ -12,9 +12,7 @@ module {
%6 = tt.splat %arg4 : i32 -> tensor<128xi32>
%7 = arith.cmpi slt, %3, %6 : tensor<128xi32>
%8 = tt.splat %cst : f32 -> tensor<128xf32>
// TODO: add back once masked loads are supported
// %9 = tt.load %5, %7, %8 : tensor<128x!tt.ptr<f32>>
%9 = tt.load %5 : tensor<128x!tt.ptr<f32>>
%9 = tt.load %5, %7, %8 : tensor<128x!tt.ptr<f32>>
%10 = "tt.reduce"(%9) ({
^bb0(%arg5: f32, %arg6: f32):
%21 = arith.cmpf ogt, %arg5, %arg6 : f32
@@ -35,9 +33,7 @@ module {
%18 = tt.addptr %arg0, %17 : !tt.ptr<f32>, i32
%19 = tt.splat %18 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>>
%20 = tt.addptr %19, %3 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
// TODO: add back once masked stores are supported
// tt.store %20, %16, %7 : tensor<128x!tt.ptr<f32>>
tt.store %20, %16 : tensor<128x!tt.ptr<f32>>
tt.store %20, %16, %7 : tensor<128x!tt.ptr<f32>>
tt.return
}
}
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
// RUN: triton-opt %s -triton-raise-block-pointer=ignore-masks=true -canonicalize | FileCheck %s

module {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
@@ -97,9 +97,7 @@ module {
%64 = tt.broadcast %61 : tensor<64x1xi1> -> tensor<64x128xi1>
%65 = tt.broadcast %63 : tensor<1x128xi1> -> tensor<64x128xi1>
%66 = arith.andi %64, %65 : tensor<64x128xi1>
// TODO: add back once masked stores are supported
// tt.store %59, %50, %66 : tensor<64x128x!tt.ptr<f16>>
tt.store %59, %50 : tensor<64x128x!tt.ptr<f16>>
tt.store %59, %50, %66 : tensor<64x128x!tt.ptr<f16>>
tt.return
}
}
@@ -119,8 +117,8 @@ module {
// CHECK: [[VAR_24_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_23_]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[VAR_15_]]] {{.*}} : <tensor<32x128xf16>>
// CHECK: [[VAR_27_:%.+]] = arith.muli [[PARAM_7_]], [[CST_32_i32]] : i32
// CHECK: [[VAR_28_:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg10_:%.+]] = [[VAR_cst_]], [[VAR_arg11_:%.+]] = [[VAR_21_]], [[VAR_arg12_:%.+]] = [[VAR_24_]]) -> (tensor<64x128xf32>, !tt.ptr<tensor<64x32xf16>>, !tt.ptr<tensor<32x128xf16>>) : i32 {
// CHECK: [[VAR_39_:%.+]] = tt.load [[VAR_arg11_]], {{.*}}, {{.*}} : !tt.ptr<tensor<64x32xf16>>
// CHECK: [[VAR_43_:%.+]] = tt.load [[VAR_arg12_]], {{.*}}, {{.*}} : !tt.ptr<tensor<32x128xf16>>
// CHECK-DAG: [[VAR_39_:%.+]] = tt.load [[VAR_arg11_]] : !tt.ptr<tensor<64x32xf16>>
// CHECK-DAG: [[VAR_43_:%.+]] = tt.load [[VAR_arg12_]] : !tt.ptr<tensor<32x128xf16>>
// CHECK: [[VAR_44_:%.+]] = tt.dot [[VAR_39_]], [[VAR_43_]], [[VAR_arg10_]], inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x128xf16> -> tensor<64x128xf32>
// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg11_]], {{\[}}[[CST_0_i32]], [[CST_32_i32]]] : <tensor<64x32xf16>>
// CHECK-DAG: [[VAR_46_:%.+]] = tt.advance [[VAR_arg12_]], {{\[}}[[CST_0_i32]], [[VAR_27_]]] : <tensor<32x128xf16>>
28 changes: 16 additions & 12 deletions test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir
@@ -1,4 +1,5 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck --check-prefixes=CHECK,CHECK-DEFAULT %s
// RUN: triton-opt %s -triton-raise-block-pointer=ignore-masks=true -canonicalize | FileCheck --check-prefixes=CHECK,CHECK-IGNORE-MASKS %s

// COM: 1D PTR + LOAD
// CHECK-LABEL: tt.func @test_addptr_splat_make_range(
@@ -22,9 +23,12 @@ tt.func @test_addptr_splat_make_range(%arg0 : !tt.ptr<f32>) -> tensor<128xf32> {
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: i32,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<128xi1>) -> tensor<128xf32> {
// CHECK: %[[VAL_3:.*]] = tt.addptr
// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]], %[[VAL_2]] cacheModifier = ca evictionPolicy = evict_first {isVolatile = true} : tensor<128x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_4]] : tensor<128xf32>
// CHECK-DEFAULT: %[[VAL_3:.*]] = tt.addptr
// CHECK-DEFAULT: %[[VAL_4:.*]] = tt.load %[[VAL_3]], %[[VAL_2]] cacheModifier = ca evictionPolicy = evict_first {isVolatile = true} : tensor<128x!tt.ptr<f32>>
// CHECK-IGNORE-MASKS: %[[CST_0_i64:.*]] = arith.constant 0 : i64
// CHECK-IGNORE-MASKS: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[CST_0_i64]]], {{\[}}%[[CST_0_i64]]], {{\[}}%[[VAL_1]]] {order = array<i32>} : <tensor<128xf32>>
// CHECK-IGNORE-MASKS: %[[VAL_4:.*]] = tt.load %[[VAL_3]] cacheModifier = ca evictionPolicy = evict_first {isVolatile = true} : !tt.ptr<tensor<128xf32>>
// CHECK: tt.return %[[VAL_4]] : tensor<128xf32>
tt.func @test_addptr_load_with_mask(%arg0 : !tt.ptr<f32>, %arg1: i32, %arg2: tensor<128xi1>) -> tensor<128xf32> {
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>>
%1 = tt.splat %arg1 : i32 -> tensor<128xi32>
@@ -615,15 +619,15 @@ module {
}

// CHECK: tt.func @matmul_kernel
// CHECK: tt.make_tensor_ptr %arg0
// CHECK: tt.make_tensor_ptr %arg1
// CHECK-IGNORE-MASKS: tt.make_tensor_ptr %arg0
// CHECK-IGNORE-MASKS: tt.make_tensor_ptr %arg1
// CHECK: scf.for
// CHECK: [[LOAD1:%.*]] = tt.load [[ARG10:%.*]], {{.*}}, {{.*}} : !tt.ptr<tensor<64x32xf16>>
// CHECK: [[LOAD2:%.*]] = tt.load [[ARG11:%.*]], {{.*}}, {{.*}} : !tt.ptr<tensor<32x128xf16>>
// CHECK: [[DOT:%.*]] = tt.dot [[LOAD1]], [[LOAD2]]
// CHECK: [[ADV1:%.*]] = tt.advance [[ARG10]], {{.*}} : <tensor<64x32xf16>>
// CHECK: [[ADV2:%.*]] = tt.advance [[ARG11]], {{.*}} : <tensor<32x128xf16>>
// CHECK: scf.yield [[DOT]], [[ADV1]], [[ADV2]]
// CHECK-IGNORE-MASKS: [[LOAD1:%.*]] = tt.load [[ARG10:%.*]] : !tt.ptr<tensor<64x32xf16>>
// CHECK-IGNORE-MASKS: [[LOAD2:%.*]] = tt.load [[ARG11:%.*]] : !tt.ptr<tensor<32x128xf16>>
// CHECK-IGNORE-MASKS: [[DOT:%.*]] = tt.dot [[LOAD1]], [[LOAD2]]
// CHECK-IGNORE-MASKS: [[ADV1:%.*]] = tt.advance [[ARG10]], {{.*}} : <tensor<64x32xf16>>
// CHECK-IGNORE-MASKS: [[ADV2:%.*]] = tt.advance [[ARG11]], {{.*}} : <tensor<32x128xf16>>
// CHECK-IGNORE-MASKS: scf.yield [[DOT]], [[ADV1]], [[ADV2]]
module {
tt.func @matmul_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16> , %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> tensor<64x128xf16> {
%cst = arith.constant dense<0.000000e+00> : tensor<64x128xf32>
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
// RUN: triton-opt %s -triton-raise-block-pointer=ignore-masks=true -canonicalize | FileCheck %s

// IR from python/examples/sign_extend.py
module {
@@ -15,9 +15,7 @@ module {
%8 = arith.cmpi slt, %5, %7 : tensor<4xi64>
%9 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
%10 = tt.addptr %9, %5 : tensor<4x!tt.ptr<f32>>, tensor<4xi64>
%11 = tt.load %10 : tensor<4x!tt.ptr<f32>>
// TODO: uncomment once masked loads are supported
// %11 = tt.load %10, %8, %cst : tensor<4x!tt.ptr<f32>>
%11 = tt.load %10, %8, %cst : tensor<4x!tt.ptr<f32>>
%12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
%13 = tt.addptr %12, %2 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
tt.store %13, %11 : tensor<4x!tt.ptr<f32>>
21 changes: 19 additions & 2 deletions third_party/intel/backend/compiler.py
@@ -202,14 +202,31 @@ def get_module_map(self) -> Dict[str, ModuleType]:
def load_dialects(self, ctx):
intel.load_dialects(ctx)

@staticmethod
def parse_raise_block_pointer_flags() -> dict:
flags_str = os.getenv("TRITON_INTEL_RAISE_BLOCK_POINTER", "0")
raise_block_ptr_flags = {}
raise_block_ptr_flags['enabled'] = False
raise_block_ptr_flags['ignore-masks'] = False
for flag in flags_str.split(':'):
if (flag == "1"):
raise_block_ptr_flags['enabled'] = True
if (flag == "ignore-masks"):
raise_block_ptr_flags['enabled'] = True
raise_block_ptr_flags['ignore-masks'] = True
return raise_block_ptr_flags

@staticmethod
def make_ttir(mod, metadata, opt):
raise_block_ptr_flags = XPUBackend.parse_raise_block_pointer_flags()

pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.common.add_inliner(pm)
passes.ttir.add_combine(pm)
if os.getenv("TRITON_INTEL_RAISE_BLOCK_POINTER", "0") == "1":
intel.passes.ttir.add_raise_block_pointer(pm)
if raise_block_ptr_flags['enabled']:
ignore_masks = raise_block_ptr_flags['ignore-masks']
intel.passes.ttir.add_raise_block_pointer(pm, ignore_masks)
passes.common.add_canonicalizer(pm)
passes.ttir.add_reorder_broadcast(pm)
passes.common.add_cse(pm)
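A minimal sketch of how the new flag parsing is expected to behave for representative values of TRITON_INTEL_RAISE_BLOCK_POINTER; the helper below simply mirrors parse_raise_block_pointer_flags from the hunk above and exists only for illustration.

def parse_flags(value: str) -> dict:
    # "1" enables the raise-block-pointer pass; "ignore-masks" enables it
    # and additionally drops masks; multiple values are ':'-separated.
    flags = {'enabled': False, 'ignore-masks': False}
    for flag in value.split(':'):
        if flag == "1":
            flags['enabled'] = True
        if flag == "ignore-masks":
            flags['enabled'] = True
            flags['ignore-masks'] = True
    return flags

assert parse_flags("0") == {'enabled': False, 'ignore-masks': False}
assert parse_flags("1") == {'enabled': True, 'ignore-masks': False}
assert parse_flags("ignore-masks") == {'enabled': True, 'ignore-masks': True}
assert parse_flags("1:ignore-masks") == {'enabled': True, 'ignore-masks': True}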
Original file line number Diff line number Diff line change
@@ -22,8 +22,14 @@ def TritonRaiseBlockPointer
}];

let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::triton::TritonDialect",
"mlir::arith::ArithDialect",
"mlir::triton::TritonDialect",
];

let options = [
Option<"IgnoreMasks", "ignore-masks",
"bool", /*default*/"false",
"Drop masks from loads and stores">,
];
}

@@ -497,6 +497,10 @@ struct TritonRaiseBlockPointer

void runOnOperation() final {
ModuleOp moduleOp = getOperation();

if (IgnoreMasks)
dropMasks(moduleOp);

if (failed(rewriteOp(moduleOp)))
moduleOp->emitWarning("TritonRaiseToBlockPointer failed");

@@ -508,6 +512,7 @@ struct TritonRaiseBlockPointer
assert(succeeded(verify(moduleOp)) && "Module verification failed");
}

private:
LogicalResult rewriteOp(Operation *rootOp, bool isNested = false) {
assert(rootOp && "Expected a valid operation");

@@ -572,8 +577,12 @@ struct TritonRaiseBlockPointer

auto canBeRewrittenUsingBlockPtr = [&](Operation *op) {
return TypeSwitch<Operation *, bool>(op)
.Case<tt::AddPtrOp, tt::LoadOp, tt::StoreOp>(
[](auto) { return true; })
.Case<tt::AddPtrOp>([](auto) { return true; })
.Case<tt::LoadOp>(
[this](auto loadOp) { return IgnoreMasks || !loadOp.getMask(); })
.Case<tt::StoreOp>([this](auto storeOp) {
return IgnoreMasks || !storeOp.getMask();
})
.Default([](auto) { return false; });
};

@@ -1117,6 +1126,45 @@ struct TritonRaiseBlockPointer
return success();
}

void dropMasks(ModuleOp moduleOp) const {
assert(IgnoreMasks && "Expecting 'IgnoreMask' flag to be set");

moduleOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<tt::LoadOp>([&](auto loadOp) {
if (loadOp.getMask()) {
loadOp->emitWarning("TritonRaiseBlockPointer: ignoring mask");
OpBuilder builder(loadOp);
auto newLoadOp = builder.create<tt::LoadOp>(
loadOp.getLoc(), loadOp.getPtr(), loadOp.getBoundaryCheck(),
loadOp.getPadding(), loadOp.getCache(), loadOp.getEvict(),
loadOp.getIsVolatile());
loadOp->replaceAllUsesWith(newLoadOp);
loadOp->erase();
}
return WalkResult::advance();
})
.Case<tt::StoreOp>([&](auto storeOp) {
if (storeOp.getMask()) {
storeOp->emitWarning("TritonRaiseBlockPointer: ignoring mask");
OpBuilder builder(storeOp);
Value mask = storeOp.getMask();
builder.createOrFold<tt::StoreOp>(
storeOp.getLoc(), storeOp.getPtr(), storeOp.getValue(),
storeOp.getBoundaryCheck(), storeOp.getCache(),
storeOp.getEvict());

// Capture the mask before erasing the store so it is not read through
// the erased op; only erase its defining op if it exists and is dead.
storeOp->erase();
if (mask.use_empty() && mask.getDefiningOp())
mask.getDefiningOp()->erase();
}
return WalkResult::advance();
})
.Default([&](auto) { return WalkResult::advance(); });
});

moduleOp.dump();
}

static void dump(const IRMapping &map) {
for (auto [key, val] : map.getValueMap()) {
llvm::dbgs() << "key: " << key << "(0x" << &key << "), value: " << val
4 changes: 2 additions & 2 deletions third_party/intel/triton_xpu.cc
@@ -65,8 +65,8 @@ static uint32_t findKernels(llvm::Module &M,
}

void init_triton_intel_passes_ttir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_raise_block_pointer",
intel::createTritonRaiseBlockPointer);
ADD_PASS_WRAPPER_OPT_1("add_raise_block_pointer",
intel::createTritonRaiseBlockPointer, bool);
ADD_PASS_WRAPPER_OPT_1("add_convert_to_ttgpuir_warp",
intel::createConvertTritonToTritonGPUWarp, unsigned);
}