
Commit 20d94dc

Update examples to use run_example
stack-info: PR: #941, branch: jansel/stack/193
1 parent a6656c1 commit 20d94dc

File tree: 5 files changed, +136 -53 lines

- examples/all_reduce.py
- examples/bf16xint16_gemm.py
- examples/grouped_gemm.py
- examples/int4_gemm.py
- helion/autotuner/base_search.py

examples/all_reduce.py

Lines changed: 28 additions & 14 deletions

@@ -23,6 +23,7 @@
 
 import helion
 from helion._testing import DEVICE
+from helion._testing import run_example
 import helion.language as hl
 
 # %%
@@ -212,6 +213,27 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
     )
 
 
+def reference_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
+    """
+    Reference implementation using the symmetric memory one-shot primitive.
+    """
+    dist_group = dist.group.WORLD
+    if dist_group is None:
+        raise RuntimeError("No distributed group available")
+
+    a_shared_clone = symm_mem.empty(
+        a_shared.shape,
+        dtype=a_shared.dtype,
+        device=a_shared.device,
+    )
+    symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
+    a_shared_clone.copy_(a_shared)
+
+    return torch.ops.symm_mem.one_shot_all_reduce(  # pyright: ignore[reportCallIssue]
+        a_shared_clone, "sum", dist_group.group_name
+    )
+
+
 # %%
 # Testing Function
 # ----------------
@@ -232,21 +254,13 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
     world_size = dist.get_world_size()
     a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
 
-    a_shared_clone = symm_mem.empty(
-        a_shared.shape,
-        dtype=a_shared.dtype,
-        device=a_shared.device,
+    run_example(
+        helion_one_shot_all_reduce,
+        reference_one_shot_all_reduce,
+        (a_shared,),
+        rtol=1e-1,
+        atol=1e-1,
     )
-    symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
-    a_shared_clone.copy_(a_shared)
-
-    a_out = helion_one_shot_all_reduce(a_shared)
-
-    gloden_o = torch.ops.symm_mem.one_shot_all_reduce(
-        a_shared_clone, "sum", dist_group.group_name
-    )
-
-    torch.testing.assert_close(a_out, gloden_o, rtol=1e-1, atol=1e-1)
 
 
 def main() -> None:
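For orientation, run_example is the shared test helper in helion._testing. From the call sites in this diff it takes the Helion kernel, an eager reference, a tuple of positional args, and tolerances. A minimal sketch of what such a helper plausibly does (hypothetical; the real implementation lives in helion/_testing.py and likely does more, such as printing and benchmarking):

from typing import Callable

import torch


def run_example_sketch(
    kernel_fn: Callable[..., torch.Tensor],
    reference_fn: Callable[..., torch.Tensor],
    args: tuple,
    rtol: float = 1e-2,
    atol: float = 1e-2,
) -> None:
    # Run the kernel and the reference on identical inputs, then compare
    # outputs within tolerances, as the hand-rolled test code above used
    # to do with torch.testing.assert_close.
    result = kernel_fn(*args)
    expected = reference_fn(*args)
    torch.testing.assert_close(result, expected, rtol=rtol, atol=atol)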

examples/bf16xint16_gemm.py

Lines changed: 15 additions & 8 deletions

@@ -15,6 +15,7 @@
 
 import helion
 from helion._testing import DEVICE
+from helion._testing import run_example
 import helion.language as hl
 
 
@@ -140,19 +141,25 @@ def check(m: int, k: int, n: int) -> None:
     """
     x = torch.randn([m, k], device=DEVICE, dtype=torch.bfloat16)
     w = torch.randint(-(2**15), 2**15 - 1, (k, n), device=DEVICE, dtype=torch.int16)
-
-    result = bf16xint16_gemm(x, w, transpose=False)
-    expected = reference_bf16xint16_pytorch(x, w, transpose=False)
-    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
+    run_example(
+        bf16xint16_gemm,
+        reference_bf16xint16_pytorch,
+        (x, w, False),
+        rtol=1e-2,
+        atol=1e-2,
+    )
 
     x_int16 = torch.randint(
         -(2**15), 2**15 - 1, (m, k), device=DEVICE, dtype=torch.int16
     )
     w_bf16 = torch.randn([k, n], device=DEVICE, dtype=torch.bfloat16)
-
-    result = bf16xint16_gemm(x_int16, w_bf16, transpose=True)
-    expected = reference_bf16xint16_pytorch(x_int16, w_bf16, transpose=True)
-    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
+    run_example(
+        bf16xint16_gemm,
+        reference_bf16xint16_pytorch,
+        (x_int16, w_bf16, True),
+        rtol=1e-2,
+        atol=1e-2,
+    )
 
 
 # %%
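Note that run_example forwards the args tuple positionally, so the former transpose=False keyword becomes a bare False inside (x, w, False). For readers without the full file, a hedged sketch of what reference_bf16xint16_pytorch plausibly computes (an assumption; the actual reference is defined earlier in examples/bf16xint16_gemm.py):

import torch


def reference_bf16xint16_sketch(
    x: torch.Tensor, w: torch.Tensor, transpose: bool = False
) -> torch.Tensor:
    # Plausible eager reference: upcast the int16 operand to bfloat16 and
    # do a plain matmul. transpose=False pairs bf16 activations with int16
    # weights; transpose=True swaps which operand arrives as int16.
    if not transpose:
        return torch.matmul(x, w.to(torch.bfloat16))
    return torch.matmul(x.to(torch.bfloat16), w)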

examples/grouped_gemm.py

Lines changed: 35 additions & 9 deletions

@@ -40,6 +40,7 @@
 
 import helion
 from helion._testing import DEVICE
+from helion._testing import run_example
 import helion.language as hl
 
 # %%
@@ -310,6 +311,26 @@ def _reference_grouped_gemm(
     return torch.cat(outs, dim=0)
 
 
+def grouped_gemm_jagged_example(
+    group_A: list[torch.Tensor], group_B: list[torch.Tensor]
+) -> torch.Tensor:
+    """
+    Wrapper to run grouped_gemm_jagged with unpacked TritonBench inputs.
+    """
+    A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
+    return grouped_gemm_jagged(A_packed, B_shared, group_offsets)
+
+
+def grouped_gemm_jagged_persistent_example(
+    group_A: list[torch.Tensor], group_B: list[torch.Tensor]
+) -> torch.Tensor:
+    """
+    Wrapper to run grouped_gemm_jagged_persistent with unpacked TritonBench inputs.
+    """
+    A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
+    return grouped_gemm_jagged_persistent(A_packed, B_shared, group_offsets)
+
+
 # %%
 # Test Harness and Validation
 # ---------------------------
@@ -330,18 +351,23 @@ def main() -> None:
     # Shared weight matrix B replicated for each group (as per TritonBench convention)
     group_B = [torch.randn(K, N, device=device, dtype=dtype).contiguous()] * G
 
-    ref = _reference_grouped_gemm(group_A, group_B)
-
     print("Testing grouped GEMM kernels...")
-
-    # Test basic jagged kernel correctness
-    out = grouped_gemm_jagged_tritonbench(None, group_A, group_B)()
-    torch.testing.assert_close(out.float(), ref.float(), atol=1e-2, rtol=1e-2)
+    run_example(
+        grouped_gemm_jagged_example,
+        _reference_grouped_gemm,
+        (group_A, group_B),
+        rtol=1e-2,
+        atol=1e-2,
+    )
     print("✓ Non-persistent kernel passed")
 
-    # Test persistent kernel with dynamic tiling
-    out_p = grouped_gemm_jagged_persistent_tritonbench(None, group_A, group_B)()
-    torch.testing.assert_close(out_p.float(), ref.float(), atol=1e-2, rtol=1e-2)
+    run_example(
+        grouped_gemm_jagged_persistent_example,
+        _reference_grouped_gemm,
+        (group_A, group_B),
+        rtol=1e-2,
+        atol=1e-2,
+    )
     print("✓ Persistent kernel passed")
 
     print("\nAll tests passed!")
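The new *_example wrappers depend on _pack_group_inputs, which is defined earlier in examples/grouped_gemm.py and not shown in this diff. A hypothetical sketch of what such a packing helper does for a jagged grouped GEMM (names and layout are assumptions inferred from A_packed, B_shared, and group_offsets):

import itertools

import torch


def pack_group_inputs_sketch(
    group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Stack the variable-height A matrices into one contiguous jagged buffer.
    A_packed = torch.cat(group_A, dim=0)
    # B is replicated across groups in this benchmark, so one copy suffices.
    B_shared = group_B[0]
    # Cumulative row counts mark where each group's rows begin and end.
    row_counts = [a.shape[0] for a in group_A]
    group_offsets = torch.tensor(
        [0, *itertools.accumulate(row_counts)],
        dtype=torch.int32,
        device=A_packed.device,
    )
    return A_packed, B_shared, group_offsets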

examples/int4_gemm.py

Lines changed: 57 additions & 21 deletions

@@ -21,6 +21,7 @@
 
 import helion
 from helion._testing import DEVICE
+from helion._testing import run_example
 import helion.language as hl
 
 # %%
@@ -130,37 +131,72 @@ def run_kernel() -> torch.Tensor:
 
 
 # %%
-def check(m: int, k: int, n: int) -> None:
+def _pack_int4_matrix(unpacked: torch.Tensor) -> torch.Tensor:
     """
-    Test the INT4 GEMM implementation.
+    Pack int4 matrix into int8 container with two values per byte.
 
     Args:
-        m (int): Number of rows in the left input matrix.
-        k (int): Shared dimension (must be even).
-        n (int): Number of columns in the right input matrix.
+        unpacked (torch.Tensor): Tensor of shape [K, N] with values in [-8, 7].
+
+    Returns:
+        torch.Tensor: Packed tensor of shape [K//2, N] in int8 format.
     """
-    # Create test matrices
-    A = torch.randn(m, k, dtype=torch.bfloat16, device=DEVICE)
+    k, n = unpacked.shape
+    assert k % 2 == 0, "K dimension must be even for int4 packing"
+    reshaped = unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
+    return ((reshaped[0] & 0xF) | (reshaped[1] << 4)).to(torch.int8)
 
-    # Create packed int4 matrix B (K//2 x N)
-    # Generate random int4 values in range [-8, 7] and pack them
-    B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device=DEVICE)
 
-    # Pack using the same format as tritonbench
-    B_reshaped = B_unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
-    B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8)
+def _unpack_int4_matrix(packed: torch.Tensor) -> torch.Tensor:
+    """
+    Unpack an int4 matrix stored as two 4-bit values per int8 byte.
+
+    Args:
+        packed (torch.Tensor): Packed tensor of shape [K//2, N] in int8 format.
+
+    Returns:
+        torch.Tensor: Unpacked tensor of shape [K, N] in int8 format.
+    """
+    b_lo = ((packed << 4) >> 4).to(torch.int8)
+    b_hi = (packed >> 4).to(torch.int8)
+    stacked = torch.stack([b_lo, b_hi], dim=1)
+    return stacked.reshape(packed.shape[0] * 2, packed.shape[1])
+
 
-    # Convert unpacked values to bfloat16 for reference
-    B_unpacked_bf16 = B_unpacked.to(torch.bfloat16)
+def reference_matmul_bf16_int4(A: Tensor, B_packed: Tensor) -> Tensor:
+    """
+    Reference implementation that unpacks the int4 weights and performs matmul.
+
+    Args:
+        A (Tensor): Input tensor in bfloat16 format.
+        B_packed (Tensor): Packed int4 tensor.
+
+    Returns:
+        Tensor: Output tensor in bfloat16 format.
+    """
+    B_unpacked = _unpack_int4_matrix(B_packed).to(torch.bfloat16)
+    return torch.matmul(A, B_unpacked)
 
-    # Compute reference result
-    expected = torch.matmul(A, B_unpacked_bf16)
 
-    # Run the kernel
-    result = matmul_bf16_int4(A, B_packed)
+def check(m: int, k: int, n: int) -> None:
+    """
+    Test the INT4 GEMM implementation using the run_example utility.
 
-    # Check accuracy with appropriate tolerance
-    torch.testing.assert_close(result, expected, rtol=2e-1, atol=1.0)
+    Args:
+        m (int): Number of rows in the left input matrix.
+        k (int): Shared dimension (must be even).
+        n (int): Number of columns in the right input matrix.
+    """
+    A = torch.randn(m, k, dtype=torch.bfloat16, device=DEVICE)
+    B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device=DEVICE)
+    B_packed = _pack_int4_matrix(B_unpacked)
+    run_example(
+        matmul_bf16_int4,
+        reference_matmul_bf16_int4,
+        (A, B_packed),
+        rtol=2e-1,
+        atol=1.0,
+    )
     print(f"Test passed for shapes: M={m}, K={k}, N={n}")
 
 
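The pack/unpack pair above relies on arithmetic-shift sign extension: (packed << 4) >> 4 recovers the signed low nibble, and >> 4 on an int8 tensor sign-extends the high nibble. A small standalone round-trip check using the same bit layout as the diff (sketch only, not imported from the example file):

import torch

k, n = 4, 3
unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8)

# Pack: even rows land in the low nibble, odd rows in the high nibble.
pairs = unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
packed = ((pairs[0] & 0xF) | (pairs[1] << 4)).to(torch.int8)

# Unpack: (x << 4) >> 4 sign-extends the low nibble; x >> 4 is an
# arithmetic shift on a signed dtype, so the high nibble is signed too.
lo = ((packed << 4) >> 4).to(torch.int8)
hi = (packed >> 4).to(torch.int8)
roundtrip = torch.stack([lo, hi], dim=1).reshape(k, n)

assert torch.equal(roundtrip, unpacked)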

helion/autotuner/base_search.py

Lines changed: 1 addition & 1 deletion

@@ -409,7 +409,7 @@ def parallel_benchmark(
         iterator = iter_with_progress(
             zip(configs, fns, is_workings, strict=True),
             total=len(configs),
-            description=f"{desc}: exploring neighbors",
+            description=f"{desc} exploring neighbors",
             enabled=self.settings.autotune_progress_bar,
         )
         for config, fn, is_working in iterator:
