
Commit 192f187

Add best-effort triton-cpu support

Fixes #163
stack-info: PR: #1037, branch: oulgen/stack/163
1 parent 32eeadf

24 files changed: +122 / -9 lines

.github/matrix.json

Lines changed: 10 additions & 0 deletions
@@ -71,6 +71,16 @@
       "container-options": "--device=/dev/kfd --device=/dev/dri",
       "pytorch-version": "pytorch-nightly",
       "alias": "mi325x"
+    },
+    {
+      "runner": "linux.g5.4xlarge.nvidia.gpu",
+      "python-version": "3.12",
+      "ref-eager": false,
+      "image": "nvidia/cuda:12.8.1-devel-ubuntu24.04",
+      "runtime-version": "cpu",
+      "container-options": "--gpus all",
+      "pytorch-version": "pytorch-nightly",
+      "alias": "cpu"
     }
   ]
 }

.github/workflows/test.yml

Lines changed: 7 additions & 2 deletions
@@ -97,7 +97,7 @@ jobs:
          fi

      - name: Install Triton
-        if: steps.cache.outputs.cache-hit != 'true' && matrix.pytorch-version != 'pytorch-2.9'
+        if: steps.cache.outputs.cache-hit != 'true' && (matrix.pytorch-version != 'pytorch-2.9' || contains(matrix.alias, 'cpu'))
         run: |
           set -x
           source .venv/bin/activate
@@ -110,7 +110,11 @@
           cd /tmp/$USER
           uv pip uninstall triton pytorch-triton || true
           rm -rf triton/ || true
-          git clone https://github.com/triton-lang/triton.git
+          if [[ "${{ matrix.alias }}" == *cpu* ]]; then
+            git clone --recursive -b main-merged https://github.com/triton-lang/triton-cpu.git triton
+          else
+            git clone https://github.com/triton-lang/triton.git triton
+          fi
           cd triton/
           uv pip install -r python/requirements.txt
           MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install .
@@ -131,6 +135,7 @@
           if [[ "${{ matrix.dtype-asserts }}" == "true" ]]; then export HELION_DEBUG_DTYPE_ASSERTS=1; fi
           if [[ "${{ matrix.expecttest-accept }}" == "true" ]]; then export EXPECTTEST_ACCEPT=1; fi
           if [[ "${{ matrix.ref-eager }}" == "true" ]]; then export HELION_INTERPRET=1; fi
+          if [[ "${{ contains(matrix.alias, 'cpu') }}" == "true" ]]; then export TRITON_CPU_BACKEND=1; fi
           # -rf: print failed tests
           # --timeout: max allowed time for each test
           pytest -rf --timeout=60
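For local reproduction, the same backend selection can be sanity-checked from Python before running the suite. The snippet below is an illustrative sketch only: it assumes a triton-cpu build is installed as the `triton` package and that the fork honors `TRITON_CPU_BACKEND=1` (as the workflow change implies); it simply mirrors the driver check that `helion/_testing.py` performs, it is not a Helion API.

    import os

    # Opt in to the CPU backend the way the CI job does (assumption: the
    # triton-cpu fork reads TRITON_CPU_BACKEND=1 when selecting a driver).
    os.environ["TRITON_CPU_BACKEND"] = "1"

    import triton

    # Same query the new _get_triton_backend() helper uses: ask the active
    # driver which backend it targets. On a triton-cpu build this should
    # report "cpu"; on a regular CUDA build it reports "cuda".
    backend = triton.runtime.driver.active.get_current_target().backend
    print(f"active Triton backend: {backend}")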

helion/_testing.py

Lines changed: 30 additions & 7 deletions
@@ -34,19 +34,37 @@
     from .runtime.kernel import Kernel


-DEVICE = torch.device("xpu") if torch.xpu.is_available() else torch.device("cuda")
-PROJECT_ROOT: Path = Path(__file__).parent.parent
-EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
+def _get_triton_backend() -> str | None:
+    try:
+        return triton.runtime.driver.active.get_current_target().backend  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    except Exception:
+        return None


-def is_cuda() -> bool:
-    """Return True if running on CUDA (NVIDIA GPU)."""
+def is_cpu() -> bool:
+    """Return True if running on Triton CPU backend."""
     return (
-        triton.runtime.driver.active.get_current_target().backend == "cuda"  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
-        and DEVICE.type == "cuda"
+        os.environ.get("TRITON_CPU_BACKEND", "0") == "1"
+        or _get_triton_backend() == "cpu"
     )


+def is_cuda() -> bool:
+    """Return True if running on CUDA (NVIDIA GPU)."""
+    return _get_triton_backend() == "cuda" and torch.cuda.is_available()
+
+
+PROJECT_ROOT: Path = Path(__file__).parent.parent
+EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
+
+if is_cpu():
+    DEVICE = torch.device("cpu")
+elif torch.xpu.is_available():
+    DEVICE = torch.device("xpu")
+else:
+    DEVICE = torch.device("cuda")
+
+
 def get_nvidia_gpu_model() -> str:
     """
     Retrieves the model of the NVIDIA GPU being used.
@@ -80,6 +98,11 @@ def skipIfXPU(reason: str) -> Callable[[Callable], Callable]:
     return unittest.skipIf(torch.xpu.is_available(), reason)  # pyright: ignore[reportAttributeAccessIssue]


+def skipIfCpu(reason: str) -> Callable[[Callable], Callable]:
+    """Skip test if running on Triton CPU backend."""
+    return unittest.skipIf(is_cpu(), reason)
+
+
 def skipIfA10G(reason: str) -> Callable[[Callable], Callable]:
     """Skip test if running on A10G GPU"""
     gpu_model = get_nvidia_gpu_model()
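Taken together, these helpers let test modules stay device-agnostic: `DEVICE` resolves to cpu, xpu, or cuda at import time, and `skipIfCpu` opts individual tests out of the CPU leg. A minimal, hypothetical test module using the new helpers might look like the sketch below (the test class and its contents are illustrative, not taken from the repository):

    import unittest

    import torch

    from helion._testing import DEVICE  # cpu / xpu / cuda, chosen at import time
    from helion._testing import TestCase
    from helion._testing import is_cpu
    from helion._testing import skipIfCpu


    class TestDeviceAgnostic(TestCase):
        def test_add_runs_everywhere(self):
            # Allocates on whichever device the harness selected.
            a = torch.randn([128], device=DEVICE)
            b = torch.randn([128], device=DEVICE)
            torch.testing.assert_close(a + b, b + a)

        @skipIfCpu("illustrative: exercises a GPU-only code path")
        def test_gpu_only_path(self):
            self.assertFalse(is_cpu())


    if __name__ == "__main__":
        unittest.main()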

test/test_autotuner.py

Lines changed: 11 additions & 0 deletions
@@ -28,6 +28,7 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import import_path
+from helion._testing import skipIfCpu
 from helion._testing import skipIfRocm
 from helion.autotuner import DifferentialEvolutionSearch
 from helion.autotuner import PatternSearch
@@ -316,6 +317,7 @@ def add(a, b):
         )
         torch.testing.assert_close(add(*args), sum(args))

+    @skipIfCpu("fails on Triton CPU backend")
     def test_run_finite_search(self):
         @helion.kernel(
             configs=[
@@ -347,6 +349,7 @@ def add(a, b):
         torch.testing.assert_close(add(*args), sum(args))

     @skipIfRocm("too slow on rocm")
+    @skipIfCpu("TritonError: Error from Triton code")
     def test_random_search(self):
         args = (
             torch.randn([512, 512], device=DEVICE),
@@ -436,6 +439,7 @@ def diff_count(flat):
         ]
         self.assertEqual(sorted(pair_neighbors), sorted(expected))

+    @skipIfCpu("fails on Triton CPU backend")
     def test_accuracy_check_filters_bad_config_wrong_output(self) -> None:
         bad_config = helion.Config(block_sizes=[1], num_warps=8)
         good_config = helion.Config(block_sizes=[1], num_warps=4)
@@ -509,6 +513,7 @@ def make_bad_config_produce_wrong_output(
         run_mode("fork", expect_error=False)
         run_mode("spawn", expect_error=True)

+    @skipIfCpu("fails on Triton CPU backend")
     def test_accuracy_check_filters_bad_config_wrong_arg_mutation(self) -> None:
         bad_config = helion.Config(block_sizes=[1], num_warps=8)
         good_config = helion.Config(block_sizes=[1], num_warps=4)
@@ -591,6 +596,7 @@ def wrong_fn(*fn_args, **fn_kwargs):
         run_mode("fork", expect_error=False)
         run_mode("spawn", expect_error=True)

+    @skipIfCpu("fails on Triton CPU backend")
     def test_autotune_baseline_fn(self) -> None:
         """Test that custom baseline function is used for accuracy checking."""
         config1 = helion.Config(block_sizes=[32], num_warps=4)
@@ -631,6 +637,7 @@ def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         # Verify the result is correct
         torch.testing.assert_close(result, args[0] + args[1])

+    @skipIfCpu("fails on Triton CPU backend")
     def test_autotune_baseline_fn_filters_bad_config(self) -> None:
         """Test that custom baseline function correctly filters incorrect configs."""
         bad_config = helion.Config(block_sizes=[1], num_warps=8)
@@ -729,6 +736,7 @@ def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         ):
             add(*args)

+    @skipIfCpu("fails on Triton CPU backend")
     def test_max_generations(self):
         """Autotuner max generation respects explicit kwargs then setting override."""

@@ -772,6 +780,7 @@ def add(a, b):
         result = add(*args)
         torch.testing.assert_close(result, sum(args))

+    @skipIfCpu("fails on Triton CPU backend")
     def test_autotune_effort_quick(self):
         """Test that quick effort profile uses correct default values."""
         # Get the quick profile defaults
@@ -907,6 +916,7 @@ def add(a, b):
         return search.samples[0]

     @skipIfRocm("accuracy difference")
+    @skipIfCpu("fails on Triton CPU backend")
     def test_autotune_random_seed_from_env_var(self) -> None:
         # same env var value -> same random sample
         with patch.dict(
@@ -931,6 +941,7 @@ def test_autotune_random_seed_from_env_var(self) -> None:
         self.assertNotEqual(first, second)

     @skipIfRocm("accuracy difference")
+    @skipIfCpu("fails on Triton CPU backend")
     def test_autotune_random_seed_from_settings(self) -> None:
         # same autotune_random_seed setting -> same random sample
         first = self._autotune_and_record(autotune_random_seed=4242)

test/test_cache.py

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,7 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import import_path
+from helion._testing import skipIfCpu
 from helion._utils import counters
 from helion.autotuner import StrictLocalAutotuneCache
 from helion.autotuner.base_search import BaseSearch
@@ -73,6 +74,7 @@ def get_welford_kernel():

 class TestCache(RefEagerTestDisabled, TestCase):
     @parametrize("name", ("add", "matmul", "welford"))
+    @skipIfCpu("fails on Triton CPU backend")
     def test_kernel(self, name):
         kernel, args_a, result_a, args_b, result_b = KERNELS[name]()

@@ -105,6 +107,7 @@ def test_kernel(self, name):
         self.assertEqual(counters["autotune"]["cache_hit"], 1)
         self.assertEqual(counters["autotune"]["cache_put"], 2)

+    @skipIfCpu("fails on Triton CPU backend")
     def test_key_affects_cache_specialization(self):
         counters["autotune"].clear()
         self.addCleanup(counters["autotune"].clear)
@@ -150,6 +153,7 @@ def add_one(x: torch.Tensor):
         self.assertEqual(counters["autotune"]["cache_hit"], 1)
         self.assertEqual(counters["autotune"]["cache_put"], 2)

+    @skipIfCpu("fails on Triton CPU backend")
     def test_assert_cache_hit(self):
         counters["autotune"].clear()
         self.addCleanup(counters["autotune"].clear)

test/test_dot.py

Lines changed: 12 additions & 0 deletions
@@ -14,6 +14,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import is_cuda
+from helion._testing import skipIfCpu
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
 from helion._testing import skipIfXPU
@@ -293,6 +294,7 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:

     @skipIfRefEager("Debug dtype codegen checks rely on compiled code")
     @skipIfXPU("Failed on XPU - https://github.com/pytorch/helion/issues/772")
+    @skipIfCpu("Failed: Timeout (>10.0s) from pytest-timeout.")
     def test_baddbmm_pipeline_debug_dtype_asserts(self):
         # Reproduces scripts/repro512.py within the test suite and asserts
         # the kernel compiles and runs with debug dtype asserts enabled.
@@ -981,6 +983,16 @@ def test_matmul_reshape_n_2(self):
             "float16 accumulator not supported for bf16/f32 in ref eager mode"
         )(_test_func)

+    # CPU backend skip for specific failing dynamic-shape case
+    if test_name == "test_input_float16_acc_float16_dynamic_shape":
+        _test_func = skipIfCpu("AssertionError: Tensor-likes are not close!")(
+            _test_func
+        )
+    if test_name == "test_input_float16_acc_float16_static_shape":
+        _test_func = skipIfCpu("AssertionError: Tensor-likes are not close!")(
+            _test_func
+        )
+
     setattr(TestDot, test_name, _test_func)

test/test_errors.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfCpu
 from helion.autotuner.base_search import PopulationBasedSearch
 from helion.autotuner.base_search import PopulationMember
 from helion.autotuner.differential_evolution import DifferentialEvolutionSearch
@@ -33,6 +34,7 @@ def _test_outer_kernel_calling_inner(x: torch.Tensor) -> torch.Tensor:


 class TestErrors(RefEagerTestDisabled, TestCase):
+    @skipIfCpu("fails on Triton CPU backend")
     def test_autotune_no_valid_configs(self):
         class FakeKernel:
             def __init__(self) -> None:

test/test_examples.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 from helion._testing import check_example
 from helion._testing import import_path
 from helion._testing import skipIfA10G
+from helion._testing import skipIfCpu
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
 from helion._testing import skipIfXPU
@@ -24,6 +25,7 @@
 torch.backends.cudnn.conv.fp32_precision = "tf32"


+@skipIfCpu("needs to be debugged")
 class TestExamples(RefEagerTestBase, TestCase):
     def test_add(self):
         args = (

test/test_generate_ast.py

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import import_path
+from helion._testing import skipIfCpu
 from helion._testing import skipIfRefEager
 import helion.language as hl

@@ -35,6 +36,7 @@ def test_add1d(self):
         torch.testing.assert_close(result, args[0] + args[1])
         self.assertExpectedJournal(code)

+    @skipIfCpu("fails on Triton CPU backend")
     def test_add2d(self):
         args = (
             torch.randn([100, 500], device=DEVICE),
@@ -46,6 +48,7 @@ def test_add2d(self):
         torch.testing.assert_close(result, args[0] + args[1])
         self.assertExpectedJournal(code)

+    @skipIfCpu("fails on Triton CPU backend")
     def test_add2d_loop_order(self):
         args = (
             torch.randn([100, 500], device=DEVICE),
@@ -61,6 +64,7 @@ def test_add2d_loop_order(self):
         torch.testing.assert_close(result, args[0] + args[1])
         self.assertExpectedJournal(code)

+    @skipIfCpu("fails on Triton CPU backend")
     def test_add3d(self):
         args = (
             torch.randn([100, 500, 10], device=DEVICE),
@@ -83,6 +87,7 @@ def test_add3d_xy_grid(self):
         torch.testing.assert_close(result, args[0] + args[1])
         self.assertExpectedJournal(code)

+    @skipIfCpu("fails on Triton CPU backend")
     def test_add3d_reorder(self):
         args = (
             torch.randn([100, 500, 10], device=DEVICE),
@@ -213,6 +218,7 @@ def test_final_cast_enforced_for_to_dtype(self):
         # Ensure codegen emits a final tl.cast(..., tl.bfloat16)
         assert "tl.cast" in code and "tl.bfloat16" in code

+    @skipIfCpu("Failed: Timeout (>10.0s) from pytest-timeout.")
     def test_sigmoid_scalar_autocast(self):
         @helion.kernel(
             config=helion.Config(

test/test_indexing.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 from helion._testing import RefEagerTestBase
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfCpu
 from helion._testing import skipIfLowVRAM
 from helion._testing import skipIfNormalMode
 from helion._testing import skipIfRefEager
@@ -396,6 +397,7 @@ def test_block_size_access(x: torch.Tensor) -> torch.Tensor:
         "IndexOffsetOutOfRangeForInt32 error is not raised in ref eager mode"
     )
     @skipIfLowVRAM("Test requires high VRAM")
+    @skipIfCpu("fails on Triton CPU backend")
     def test_int32_offset_out_of_range_error(self):
         repro_config = helion.Config(
             block_sizes=[32, 32],
