Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions .github/workflows/tileir3.6-build-and-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
name: TileIR3.6-Build-And-Test

on:
push:
branches: [ "triton_v3.6.x" ]
pull_request:
branches: [ "triton_v3.6.x" ]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
tileir-build-and-test:
runs-on: hopper
if: ${{ github.repository == 'FlagTree/flagtree' || github.repository == 'flagos-ai/flagtree' }}
steps:
- name: Setup environment
shell: bash
run: |
source ~/env.sh
env | grep -E '^(http_proxy|https_proxy|all_proxy|no_proxy)=' >> $GITHUB_ENV || true

- name: Smart Checkout
uses: flagos-ai/FlagTree/.github/actions/smart-checkout@main
with:
checkout_version: 'v6'

- name: Check if backend-relevant files changed
id: check_backend
uses: flagos-ai/FlagTree/.github/actions/check-backend-changed@main
with:
backend: tileir

- name: FlagTree Build on TileIR
if: steps.check_backend.outputs.should_skip != 'true'
shell: bash
run: |
set -x
pip uninstall -y triton || true
source ~/env-3.6.sh
env | grep -E '^(LLVM_SYSPATH)=' >> $GITHUB_ENV || true
/usr/local/cuda/bin/tileiras --version
rm -rf ~/.triton/cache
MAX_JOBS=32 CMAKE_BUILD_PARALLEL_LEVEL=32 \
TRITON_PTXAS_PATH="$PWD/third_party/nvidia/backend/bin/ptxas" \
TRITON_PTXAS_BLACKWELL_PATH="$PWD/third_party/nvidia/backend/bin/ptxas-blackwell" \
python3 -m pip install . --no-build-isolation

- name: FlagTree Test on TileIR
if: steps.check_backend.outputs.should_skip != 'true'
shell: bash
run: |
set -x
source ~/env-3.6.sh
rm -rf /tmp/flagtree_tileir_ci_01_cache /tmp/flagtree_tileir_ci_02_cache
TRITON_CACHE_DIR=/tmp/flagtree_tileir_ci_01_cache \
python3 python/tutorials/tileir/01-load-view-token-ordering.py
TRITON_CACHE_DIR=/tmp/flagtree_tileir_ci_02_cache \
python3 python/tutorials/tileir/02-mixed-kernel-routing.py
python3 python/tutorials/tileir/03-triton-tileir-benchmarks.py \
--case all \
--ci-subset \
--path both \
--warmup 1 \
--rep 1 \
--min-rep 1 \
--initial-rep 1 \
--cache-dir /tmp/flagtree_tileir_ci_03_cache \
--clear-cache
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/tileir/third_party/cuda-tile"]
path = third_party/tileir/third_party/cuda-tile
url = https://github.com/NVIDIA/cuda-tile
5 changes: 4 additions & 1 deletion python/setup_tools/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _get_flagtree_root() -> str:

@dataclass
class FlagtreeConfigs:
default_backends: tuple = ("nvidia", "amd")
default_backends: tuple = ("nvidia", "amd", "tileir")
plugin_backends: tuple = ("cambricon", "ascend", "aipu", "tsingmicro", "enflame", "hcu", "thrive")
use_cuda_toolkit_backends: tuple = ('aipu', )
language_extra_backends: tuple = ('xpu', 'mthreads', "cambricon")
Expand All @@ -46,6 +46,9 @@ def __post_init__(self):
self.flagtree_submodule_dir = os.path.join(self.flagtree_root_dir, "third_party")
self.activated_module = self._activate_device_module()

def non_tileir_default_backends(self):
return tuple(backend for backend in self.default_backends if backend != "tileir")

def _activate_device_module(self, suffix=".py"):
backend = self.flagtree_backend or "default"
module_path = Path(os.path.dirname(__file__)) / backend
Expand Down
12 changes: 10 additions & 2 deletions python/src/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,24 @@

namespace py = pybind11;

// flagtree: extended from FOR_EACH_5 to FOR_EACH_8 to give headroom for the
// expanded backend set (nvidia + amd + tileir + proton + tle = 5 by default;
// can be 6+ with TRITON_PLUGIN_DIRS or FLAGTREE_BACKEND=thrive/aipu/...).
// Stopping at 5 caused the C preprocessor to emit an undefined FOR_EACH_N
// when TRITON_BACKENDS_TUPLE grew past 5 entries.
#define FOR_EACH_1(MACRO, X) MACRO(X)
#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__)
#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__)
#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__)
#define FOR_EACH_5(MACRO, X, ...) MACRO(X) FOR_EACH_4(MACRO, __VA_ARGS__)
#define FOR_EACH_6(MACRO, X, ...) MACRO(X) FOR_EACH_5(MACRO, __VA_ARGS__)
#define FOR_EACH_7(MACRO, X, ...) MACRO(X) FOR_EACH_6(MACRO, __VA_ARGS__)
#define FOR_EACH_8(MACRO, X, ...) MACRO(X) FOR_EACH_7(MACRO, __VA_ARGS__)

#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N())
#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__)
#define FOR_EACH_ARG_N(_1, _2, _3, _4, _5, N, ...) N
#define FOR_EACH_RSEQ_N() 5, 4, 3, 2, 1, 0
#define FOR_EACH_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, N, ...) N
#define FOR_EACH_RSEQ_N() 8, 7, 6, 5, 4, 3, 2, 1, 0

#define CONCATENATE(x, y) CONCATENATE1(x, y)
#define CONCATENATE1(x, y) x##y
Expand Down
2 changes: 1 addition & 1 deletion python/triton/_filecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def run_parser(kernel_fn, args=(), kwargs={}, target=stub_target):

codegen_fns = backend.get_codegen_implementation(options)
module_map = backend.get_module_map()
module = src.make_ir(target, options, codegen_fns, module_map, context)
module = backend.make_ir(src, options, codegen_fns, module_map, context)
return module


Expand Down
65 changes: 65 additions & 0 deletions python/triton/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,68 @@ def _discover_backends() -> dict[str, Backend]:


backends: dict[str, Backend] = _discover_backends()


def get_backend(target) -> Backend:
compatible = [backend for backend in backends.values() if backend.compiler.supports_target(target)]
if len(compatible) != 1:
raise RuntimeError(f"{len(compatible)} compatible backends for target ({target.backend}) ({compatible}). "
"There should only be one.")
return compatible[0]


def route_target(target, jit_fn):
routes = []
for name, backend in backends.items():
candidate = backend.compiler.route_target(target, jit_fn)
if candidate is not None and candidate != target:
routes.append((name, candidate))
if len(routes) > 1:
names = ", ".join(name for name, _ in routes)
raise RuntimeError(f"Multiple backends requested this kernel: {names}")
return routes[0][1] if routes else target


_target_drivers = {}


def get_driver(target, active_driver):
driver_cls = get_backend(target).driver
if isinstance(active_driver, driver_cls):
return active_driver
if driver_cls not in _target_drivers:
_target_drivers[driver_cls] = driver_cls()
return _target_drivers[driver_cls]


class _LanguageExtensions:

def __init__(self):
self._modules = None
self._symbols = {}

def _get_modules(self):
if self._modules is None:
self._modules = []
for name, backend in backends.items():
module = backend.compiler.get_language_extension()
if module is not None:
self._modules.append((name, module))
return self._modules

def __getattr__(self, name):
if name in self._symbols:
return self._symbols[name]
providers = [(backend_name, getattr(module, name))
for backend_name, module in self._get_modules()
if hasattr(module, name)]
if not providers:
raise RuntimeError(f"tl.ext.{name} is not provided by an installed Triton backend")
if len(providers) > 1:
names = ", ".join(backend_name for backend_name, _ in providers)
raise RuntimeError(f"tl.ext.{name} is provided by multiple backends: {names}")
self._symbols[name] = providers[0][1]
return self._symbols[name]


language_extensions = _LanguageExtensions()
14 changes: 14 additions & 0 deletions python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ def __init__(self, target: GPUTarget) -> None:
self.target = target
assert self.supports_target(target)

@classmethod
def route_target(cls, target: GPUTarget, jit_fn):
"""Return an alternate per-kernel target, or ``None`` to keep ``target``."""
return None

@classmethod
def get_language_extension(cls):
"""Return an optional module contributing symbols to ``triton.language.ext``."""
return None

def make_ir(self, src, options, codegen_fns, module_map, context):
"""Create the initial IR module for this backend."""
return src.make_ir(self.target, options, codegen_fns, module_map, context)

@staticmethod
@abstractmethod
def supports_target(target: GPUTarget):
Expand Down
23 changes: 12 additions & 11 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hashlib
import json
from .._C.libtriton import get_cache_invalidating_env_vars, ir
from ..backends import backends
from ..backends import get_backend, get_driver
from ..backends.compiler import Language
from ..backends.compiler import BaseBackend, GPUTarget
from .. import __version__, knobs
Expand Down Expand Up @@ -301,7 +301,7 @@ def compile(src, target=None, options=None, _env_vars=None):
codegen_fns = backend.get_codegen_implementation(options)
module_map = backend.get_module_map()
try:
module = src.make_ir(target, options, codegen_fns, module_map, context)
module = backend.make_ir(src, options, codegen_fns, module_map, context)
except Exception as e:
filter_traceback(e)
raise
Expand Down Expand Up @@ -361,11 +361,7 @@ def compile(src, target=None, options=None, _env_vars=None):


def make_backend(target: GPUTarget) -> BaseBackend:
actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
if len(actives) != 1:
raise RuntimeError(
f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.")
return actives[0](target)
return get_backend(target).compiler(target)


class LazyDict:
Expand Down Expand Up @@ -437,6 +433,10 @@ def __init__(self, src, metadata_group, hash):
self.function = None
self._run = None

def _get_kernel_driver(self):
"""Return the runtime driver paired with this kernel's compile target."""
return get_driver(self.metadata.target, driver.active)

def _init_handles(self):
if self.module is not None:
return
Expand All @@ -452,10 +452,11 @@ def raise_(err):
raise err

device = driver.active.get_current_device()
kernel_driver = self._get_kernel_driver()
# create launcher
self._run = driver.active.launcher_cls(self.src, self.metadata)
# not enough shared memory to run the kernel
self._run = kernel_driver.launcher_cls(self.src, self.metadata)
max_shared = max_shared_mem(device)
# not enough shared memory to run the kernel
if self.metadata.shared > max_shared:
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
Expand All @@ -466,9 +467,9 @@ def raise_(err):
if knobs.runtime.kernel_load_start_hook is not None:
knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
# TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = kernel_driver.utils.load_binary(
self.name, self.kernel, self.metadata.shared, device)
warp_size = driver.active.get_current_target().warp_size
warp_size = self.metadata.target.warp_size
if self.metadata.num_warps * warp_size > self.n_max_threads:
raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads"))
if knobs.runtime.kernel_load_end_hook is not None:
Expand Down
2 changes: 2 additions & 0 deletions python/triton/experimental/tle/language/gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from .types import (layout, shared_layout, swizzled_shared_layout, tensor_memory_layout, nv_mma_shared_layout, scope,
buffered_tensor, buffered_tensor_type, smem, tmem)
from . import tile

# Backward-compat alias expected by existing tests/tutorials.
storage_kind = memory_space
Expand All @@ -31,4 +32,5 @@
"buffered_tensor_type",
"smem",
"tmem",
"tile",
]
9 changes: 9 additions & 0 deletions python/triton/experimental/tle/language/gpu/tile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# flagtree tle
import triton.language as _language


def __getattr__(name):
language_extensions = getattr(_language, "ext", None)
if language_extensions is None:
raise RuntimeError(f"tle.gpu.tile.{name} requires a backend providing tl.ext.{name}")
return getattr(language_extensions, name)
3 changes: 3 additions & 0 deletions python/triton/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from . import math
from . import extra
from ..backends import language_extensions as ext

from .standard import (
argmax,
argmin,
Expand Down Expand Up @@ -187,6 +189,7 @@
"exp2",
"expand_dims",
"extra",
"ext",
"fdiv",
"flip",
"float16",
Expand Down
2 changes: 2 additions & 0 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,9 @@ def create_binder(self):
Precompute as much as possible.
"""
from ..compiler import CompiledKernel, compile, ASTSource, make_backend
from ..backends import route_target
target = driver.active.get_current_target()
target = route_target(target, self)
backend = make_backend(target)
self.CompiledKernel = CompiledKernel
self.compile = compile
Expand Down
Loading
Loading