Merged
42 changes: 28 additions & 14 deletions examples/all_reduce.py
@@ -23,6 +23,7 @@

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl

# %%
@@ -212,6 +213,27 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
)


def reference_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
"""
Reference implementation using the symmetric memory one-shot primitive.
"""
dist_group = dist.group.WORLD
if dist_group is None:
raise RuntimeError("No distributed group available")

a_shared_clone = symm_mem.empty(
a_shared.shape,
dtype=a_shared.dtype,
device=a_shared.device,
)
symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
a_shared_clone.copy_(a_shared)

return torch.ops.symm_mem.one_shot_all_reduce( # pyright: ignore[reportCallIssue]
a_shared_clone, "sum", dist_group.group_name
)


# %%
# Testing Function
# ----------------
@@ -232,21 +254,13 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
world_size = dist.get_world_size()
a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

a_shared_clone = symm_mem.empty(
a_shared.shape,
dtype=a_shared.dtype,
device=a_shared.device,
run_example(
helion_one_shot_all_reduce,
reference_one_shot_all_reduce,
(a_shared,),
rtol=1e-1,
atol=1e-1,
)
symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
a_shared_clone.copy_(a_shared)

a_out = helion_one_shot_all_reduce(a_shared)

gloden_o = torch.ops.symm_mem.one_shot_all_reduce(
a_shared_clone, "sum", dist_group.group_name
)

torch.testing.assert_close(a_out, gloden_o, rtol=1e-1, atol=1e-1)


def main() -> None:
23 changes: 15 additions & 8 deletions examples/bf16xint16_gemm.py
@@ -15,6 +15,7 @@

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl


@@ -140,19 +141,25 @@ def check(m: int, k: int, n: int) -> None:
"""
x = torch.randn([m, k], device=DEVICE, dtype=torch.bfloat16)
w = torch.randint(-(2**15), 2**15 - 1, (k, n), device=DEVICE, dtype=torch.int16)

result = bf16xint16_gemm(x, w, transpose=False)
expected = reference_bf16xint16_pytorch(x, w, transpose=False)
torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
run_example(
bf16xint16_gemm,
reference_bf16xint16_pytorch,
(x, w, False),
rtol=1e-2,
atol=1e-2,
)

x_int16 = torch.randint(
-(2**15), 2**15 - 1, (m, k), device=DEVICE, dtype=torch.int16
)
w_bf16 = torch.randn([k, n], device=DEVICE, dtype=torch.bfloat16)

result = bf16xint16_gemm(x_int16, w_bf16, transpose=True)
expected = reference_bf16xint16_pytorch(x_int16, w_bf16, transpose=True)
torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
run_example(
bf16xint16_gemm,
reference_bf16xint16_pytorch,
(x_int16, w_bf16, True),
rtol=1e-2,
atol=1e-2,
)


# %%
44 changes: 35 additions & 9 deletions examples/grouped_gemm.py
@@ -40,6 +40,7 @@

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl

# %%
@@ -310,6 +311,26 @@ def _reference_grouped_gemm(
return torch.cat(outs, dim=0)


def grouped_gemm_jagged_example(
group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> torch.Tensor:
"""
Wrapper to run grouped_gemm_jagged with unpacked TritonBench inputs.
"""
A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
return grouped_gemm_jagged(A_packed, B_shared, group_offsets)


def grouped_gemm_jagged_persistent_example(
group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> torch.Tensor:
"""
Wrapper to run grouped_gemm_jagged_persistent with unpacked TritonBench inputs.
"""
A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B)
return grouped_gemm_jagged_persistent(A_packed, B_shared, group_offsets)


# %%
# Test Harness and Validation
# ---------------------------
@@ -330,18 +351,23 @@ def main() -> None:
# Shared weight matrix B replicated for each group (as per TritonBench convention)
group_B = [torch.randn(K, N, device=device, dtype=dtype).contiguous()] * G

ref = _reference_grouped_gemm(group_A, group_B)

print("Testing grouped GEMM kernels...")

# Test basic jagged kernel correctness
out = grouped_gemm_jagged_tritonbench(None, group_A, group_B)()
torch.testing.assert_close(out.float(), ref.float(), atol=1e-2, rtol=1e-2)
run_example(
grouped_gemm_jagged_example,
_reference_grouped_gemm,
(group_A, group_B),
rtol=1e-2,
atol=1e-2,
)
print("✓ Non-persistent kernel passed")

# Test persistent kernel with dynamic tiling
out_p = grouped_gemm_jagged_persistent_tritonbench(None, group_A, group_B)()
torch.testing.assert_close(out_p.float(), ref.float(), atol=1e-2, rtol=1e-2)
run_example(
grouped_gemm_jagged_persistent_example,
_reference_grouped_gemm,
(group_A, group_B),
rtol=1e-2,
atol=1e-2,
)
print("✓ Persistent kernel passed")

print("\nAll tests passed!")
78 changes: 57 additions & 21 deletions examples/int4_gemm.py
@@ -21,6 +21,7 @@

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl

# %%
@@ -130,37 +131,72 @@ def run_kernel() -> torch.Tensor:


# %%
def check(m: int, k: int, n: int) -> None:
def _pack_int4_matrix(unpacked: torch.Tensor) -> torch.Tensor:
"""
Test the INT4 GEMM implementation.
Pack int4 matrix into int8 container with two values per byte.

Args:
m (int): Number of rows in the left input matrix.
k (int): Shared dimension (must be even).
n (int): Number of columns in the right input matrix.
unpacked (torch.Tensor): Tensor of shape [K, N] with values in [-8, 7].

Returns:
torch.Tensor: Packed tensor of shape [K//2, N] in int8 format.
"""
# Create test matrices
A = torch.randn(m, k, dtype=torch.bfloat16, device=DEVICE)
k, n = unpacked.shape
assert k % 2 == 0, "K dimension must be even for int4 packing"
reshaped = unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
return ((reshaped[0] & 0xF) | (reshaped[1] << 4)).to(torch.int8)

# Create packed int4 matrix B (K//2 x N)
# Generate random int4 values in range [-8, 7] and pack them
B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device=DEVICE)

# Pack using the same format as tritonbench
B_reshaped = B_unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8)
def _unpack_int4_matrix(packed: torch.Tensor) -> torch.Tensor:
"""
Unpack an int4 matrix stored as two 4-bit values per int8 byte.

Args:
packed (torch.Tensor): Packed tensor of shape [K//2, N] in int8 format.

Returns:
torch.Tensor: Unpacked tensor of shape [K, N] in int8 format.
"""
b_lo = ((packed << 4) >> 4).to(torch.int8)
b_hi = (packed >> 4).to(torch.int8)
stacked = torch.stack([b_lo, b_hi], dim=1)
return stacked.reshape(packed.shape[0] * 2, packed.shape[1])


# Convert unpacked values to bfloat16 for reference
B_unpacked_bf16 = B_unpacked.to(torch.bfloat16)
def reference_matmul_bf16_int4(A: Tensor, B_packed: Tensor) -> Tensor:
"""
Reference implementation that unpacks the int4 weights and performs matmul.

Args:
A (Tensor): Input tensor in bfloat16 format.
B_packed (Tensor): Packed int4 tensor.

Returns:
Tensor: Output tensor in bfloat16 format.
"""
B_unpacked = _unpack_int4_matrix(B_packed).to(torch.bfloat16)
return torch.matmul(A, B_unpacked)

# Compute reference result
expected = torch.matmul(A, B_unpacked_bf16)

# Run the kernel
result = matmul_bf16_int4(A, B_packed)
def check(m: int, k: int, n: int) -> None:
"""
Test the INT4 GEMM implementation using the run_example utility.

# Check accuracy with appropriate tolerance
torch.testing.assert_close(result, expected, rtol=2e-1, atol=1.0)
Args:
m (int): Number of rows in the left input matrix.
k (int): Shared dimension (must be even).
n (int): Number of columns in the right input matrix.
"""
A = torch.randn(m, k, dtype=torch.bfloat16, device=DEVICE)
B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device=DEVICE)
B_packed = _pack_int4_matrix(B_unpacked)
run_example(
matmul_bf16_int4,
reference_matmul_bf16_int4,
(A, B_packed),
rtol=2e-1,
atol=1.0,
)
print(f"Test passed for shapes: M={m}, K={k}, N={n}")


2 changes: 1 addition & 1 deletion helion/autotuner/base_search.py
@@ -409,7 +409,7 @@ def parallel_benchmark(
iterator = iter_with_progress(
zip(configs, fns, is_workings, strict=True),
total=len(configs),
description=f"{desc}: exploring neighbors",
description=f"{desc} exploring neighbors",
Contributor:
nit: curious if we intentionally want to remove the ":"?

Contributor (Author):
Yes, it was printing "::" since desc already ended with a ":".
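A minimal sketch of the duplication (the desc value below is hypothetical; the real one is built elsewhere in the autotuner):

    # Hypothetical desc value -- in the autotuner it already ends with a colon.
    desc = "Pattern search [phase 1]:"
    print(f"{desc}: exploring neighbors")  # old format -> "Pattern search [phase 1]:: exploring neighbors"
    print(f"{desc} exploring neighbors")   # new format -> "Pattern search [phase 1]: exploring neighbors"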

enabled=self.settings.autotune_progress_bar,
)
for config, fn, is_working in iterator: