Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/hcu3.6-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ jobs:
set -x
source ~/env.sh
python3 -m pytest -s third_party/hcu/python/test/unit
python3 -m pytest -s python/test/tle/integration/test_tle_local_store.py
python3 -m pytest -s python/test/tle/integration/test_tle_gemm.py
python3 -m pytest -s python/test/tle/integration/test_tle_pipeline_e2e.py
python3 -m pytest -s python/test/tle/unit/test_tle_gpu_local_ptr.py
python3 -m pytest -s python/test/tle/unit/test_tle_gpu_slot.py
12 changes: 12 additions & 0 deletions python/test/tle/unit/test_tle_gpu_local_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ def _is_enflame_backend():
return target.backend == "gcu"


def _is_hcu_backend():
    """Return True when the active Triton driver's current target reports backend "hip".

    NOTE(review): the helper name implies that HCU devices surface through the
    "hip" backend string; confirm against the HCU driver.
    """
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def _require_cuda():
try:
if _is_enflame_backend():
Expand Down Expand Up @@ -380,6 +385,13 @@ def test_local_pointer_none_generates_full_view_1d(self):
line_lhs = line.split(":", 1)[0]
assert "tle.local_pointers" in line_lhs
assert "," not in line_lhs
elif _is_hcu_backend():
ttgir = compiled.asm["ttgir"]
line = next((line for line in ttgir.splitlines() if "tle.local_pointers" in line), None)
if line is not None:
line_lhs = line.split(":", 1)[0]
assert "tle.local_pointers" in line_lhs
assert "," not in line_lhs
else:
ttgir = compiled.asm["ttgir"]
assert "ttg.local_store" in ttgir
Expand Down
7 changes: 7 additions & 0 deletions third_party/hcu/backend/compiler_hcu.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,16 @@ def make_ttgir(mod, metadata, options):
pm.enable_debug()
emuTF32 = False
passes.ttgpuir.add_coalesce(pm)
passes.ttgpuir.add_process_shared_memory_hint(pm) # flagtree hints
passes.ttgpuir.add_f32_dot_tc(pm, emuTF32)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
# begin flagtree tle
tle.passes.add_select_encodings(pm)
tle.passes.add_insert_local_pointer_barriers(pm)
tle.passes.add_optimize_local_pointer_loads(pm)
tle.passes.add_optimize_local_pointer_stores(pm)
# end flagtree tle
hcu.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack,
options.mmac_layout_force)
passes.ttgpuir.add_remove_layout_conversions(pm)
Expand Down
13 changes: 12 additions & 1 deletion third_party/hcu/lib/TritonHCUGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Optional TLE link dependencies: _TLE_LIBS stays empty unless FLAGTREE_TLE is
# enabled, in which case the TLE libraries below are collected into it.
set(_TLE_LIBS "")
if(FLAGTREE_TLE)
  list(APPEND _TLE_LIBS
    TritonTLEAnalysis
    TritonNVIDIAGPUToLLVM
    TleToLLVM # flagtree tle raw
    TritonTLETransforms
  )
endif()

add_triton_library(TritonHCUGPUToLLVM
AsyncUtility.cpp
AtomicRMWOpsEmitter.cpp
Expand Down Expand Up @@ -43,7 +54,7 @@ add_triton_library(TritonHCUGPUToLLVM
LINK_LIBS PUBLIC
TritonGPUToLLVM
TritonHCUGPUIR
TritonTLETransforms
${_TLE_LIBS}
LLVMCore
LLVMPasses
LLVMSupport
Expand Down
4 changes: 4 additions & 0 deletions third_party/hcu/lib/TritonHCUGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#ifdef __TLE__
#include "tle/dialect/include/Conversion/TleToLLVM/LocalPointersOpToLLVM.h"
#include "tle/dialect/include/IR/Dialect.h"
#include "tle/dialect/include/Transforms/PatternTleToLLVM.h"
#endif
Expand Down Expand Up @@ -212,6 +213,9 @@ struct ConvertTritonHCUGPUToLLVM
mlir::triton::tle::populateInsertTileOpToLLVMPatterns(
typeConverter, tlePatterns, targetInfo,
patternBenefitPrioritizeOverLLVMConversions);
mlir::triton::tle::populateLocalPointersOpToLLVMPatterns(
typeConverter, targetInfo, tlePatterns,
patternBenefitPrioritizeOverLLVMConversions);
if (failed(
applyPartialConversion(mod, tleTarget, std::move(tlePatterns))))
return signalPassFailure();
Expand Down
Loading