
Commit 0c49062

Support torch.rand / torch.rand_like with dynamic tile sizes (#1057)
1 parent da9c440 commit 0c49062

File tree

helion/_compiler/inductor_lowering.py
helion/language/ref_tile.py
test/test_rng.expected
test/test_rng.py

4 files changed: +157 −36 lines changed

helion/_compiler/inductor_lowering.py

Lines changed: 12 additions & 11 deletions
@@ -52,7 +52,6 @@
 from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
 from .compile_environment import FixedBlockSizeSource
-from .device_function import SymbolArgument
 from .device_function import VarInfo
 from .device_function import contains_only_block_size_symbols
 from .dtype_utils import cast_ast
@@ -1559,27 +1558,29 @@ def _codegen_rng_op(
         node: The FX node for this operation
         rng_function: Either "rand" or "randn"
     """
+    from .generate_ast import GenerateAST
+
     assert rng_function in ["rand", "randn"]
+    assert isinstance(ctx.cg, GenerateAST)

     # Get unique seed index for this RNG operation
     device_fn = ctx.cg.device_function
     seed_index = device_fn.allocate_rng_seed()

     # Get dimensionality and dtype
     assert hasattr(node, "meta") and "val" in node.meta
-    ndim = node.meta["val"].ndim
+    fake_value = node.meta["val"]
+    ndim = fake_value.ndim
     dtype = node.kwargs.get("dtype", None)

-    # Get the dimension variable names from the device function's symbol arguments
-    device_fn = ctx.cg.device_function
-    symbol_args = [
-        arg for arg in device_fn.arguments if isinstance(arg, SymbolArgument)
-    ]
-
-    # Extract dimension names - they should be the last ndim symbol arguments
+    # Get dimension names for offset calculation
+    env = CompileEnvironment.current()
     dim_names = []
-    assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions"
-    dim_names = [arg.name for arg in symbol_args[-ndim:]]
+    for size in fake_value.size():
+        block_id = env.get_block_id(size)
+        assert block_id is not None
+        block_size = env.block_sizes[block_id].size
+        dim_names.append(device_fn.literal_expr(block_size))

     offset_parts = []
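
Why this change matters: with dynamic tile sizes, the RNG offset must be built from each dimension's registered block size rather than from the kernel's trailing symbol arguments, which no longer exist once shapes are baked in. Below is a minimal, runnable sketch of the size -> block_id -> block_size lookup the new loop performs; Env and BlockSizeInfo are hypothetical stand-ins for Helion's CompileEnvironment and block-size table, not the real API.

# Hypothetical stand-ins for CompileEnvironment / BlockSizeInfo,
# for illustration only; the real Helion classes are richer than this.
class BlockSizeInfo:
    def __init__(self, size: str) -> None:
        self.size = size  # symbolic block-size expression, e.g. "_BLOCK_SIZE_0"


class Env:
    def __init__(self) -> None:
        self.block_sizes: list[BlockSizeInfo] = []
        self._ids: dict[str, int] = {}

    def register(self, sym: str) -> None:
        # Record the block size and remember which id it got.
        self._ids[sym] = len(self.block_sizes)
        self.block_sizes.append(BlockSizeInfo(sym))

    def get_block_id(self, size: str) -> int | None:
        return self._ids.get(size)


env = Env()
env.register("_BLOCK_SIZE_0")
env.register("_BLOCK_SIZE_1")

# Mirrors the new loop: map each fake-tensor dimension back to the block
# size it was tiled with, instead of trusting argument order.
fake_sizes = ["_BLOCK_SIZE_0", "_BLOCK_SIZE_1"]  # stands in for fake_value.size()
dim_names = []
for size in fake_sizes:
    block_id = env.get_block_id(size)
    assert block_id is not None
    dim_names.append(env.block_sizes[block_id].size)
print(dim_names)  # ['_BLOCK_SIZE_0', '_BLOCK_SIZE_1']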
helion/language/ref_tile.py

Lines changed: 20 additions & 0 deletions
@@ -1,8 +1,10 @@
 from __future__ import annotations

 from typing import TYPE_CHECKING
+from typing import TypeVar

 import torch
+from torch.utils._pytree import tree_map_only

 from .. import exc
 from .._utils import convert_tile_indices_to_slices
@@ -12,6 +14,8 @@
 if TYPE_CHECKING:
     from collections.abc import Callable

+_T = TypeVar("_T")
+

 _ADD_OPS: set[object] = {
     torch.add,
@@ -71,8 +75,24 @@ def __torch_function__(
         if func in _SUB_OPS:
             return cls._handle_sub(args)

+        # For any other torch.* function or torch.Tensor.* method, convert tiles to sizes
+        is_torch_func = getattr(func, "__module__", "") == "torch"
+        is_tensor_method = hasattr(torch.Tensor, getattr(func, "__name__", ""))
+        if is_torch_func or is_tensor_method:
+            new_args = cls._tiles_to_sizes(args)
+            new_kwargs = cls._tiles_to_sizes(kwargs) if kwargs else {}
+            return func(*new_args, **new_kwargs)
+
         raise exc.IncorrectTileUsage(func)

+    @classmethod
+    def _tiles_to_sizes(cls, it: _T) -> _T:
+        return tree_map_only(RefTile, cls._tile_to_size, it)
+
+    @staticmethod
+    def _tile_to_size(tile: RefTile) -> int:
+        return tile.block_size
+
     @classmethod
     def _handle_add(cls, args: tuple[object, ...]) -> torch.Tensor:
         tile, offset, flipped = cls._extract_tile_and_offset(args, torch.add)
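
The fallback path added to __torch_function__ leans on torch.utils._pytree.tree_map_only, which rewrites only the leaves of a matching type and passes everything else through unchanged, however deeply nested. A runnable illustration of that pattern; FakeTile is a hypothetical stand-in for RefTile:

from dataclasses import dataclass

from torch.utils._pytree import tree_map_only


@dataclass
class FakeTile:
    # Hypothetical stand-in for RefTile; only block_size matters here.
    block_size: int


# Tiles nested inside tuples, dicts, and lists, mixed with other leaves.
args = (FakeTile(32), {"shape": [FakeTile(32), FakeTile(64)], "dim": 1})

# Replace every FakeTile leaf with its block size; other leaves pass through.
sizes = tree_map_only(FakeTile, lambda t: t.block_size, args)
print(sizes)  # (32, {'shape': [32, 64], 'dim': 1})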

test/test_rng.expected

Lines changed: 17 additions & 19 deletions
@@ -10,38 +10,36 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher

 @triton.jit
-def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal_stride_0, normal_stride_1, rand1_stride_0, rand1_stride_1, rand2_stride_0, rand2_stride_1, randn_a_stride_0, randn_a_stride_1, randn_b_stride_0, randn_b_stride_1, randn_c_stride_0, randn_c_stride_1, uniform_stride_0, uniform_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer):
+def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer):
     # src[test_rng.py:N]: for tile_m, tile_n in hl.tile([m, n]):
-    num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0)
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
-    mask_0 = indices_0 < m
     offset_1 = pid_1 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
-    mask_1 = indices_1 < n
     # src[test_rng.py:N]: rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
-    rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
-    tl.store(rand1 + (indices_0[:, None] * rand1_stride_0 + indices_1[None, :] * rand1_stride_1), rand, mask_0[:, None] & mask_1[None, :])
+    rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
+    tl.store(rand1 + (indices_0[:, None] * 64 + indices_1[None, :] * 1), rand, None)
     # src[test_rng.py:N]: rand2[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
-    rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
-    tl.store(rand2 + (indices_0[:, None] * rand2_stride_0 + indices_1[None, :] * rand2_stride_1), rand_1, mask_0[:, None] & mask_1[None, :])
+    rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
+    tl.store(rand2 + (indices_0[:, None] * 64 + indices_1[None, :] * 1), rand_1, None)
     # src[test_rng.py:N]: uniform[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
-    rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
-    tl.store(uniform + (indices_0[:, None] * uniform_stride_0 + indices_1[None, :] * uniform_stride_1), rand_2, mask_0[:, None] & mask_1[None, :])
+    rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
+    tl.store(uniform + (indices_0[:, None] * 64 + indices_1[None, :] * 1), rand_2, None)
     # src[test_rng.py:N]: normal[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
-    randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
-    tl.store(normal + (indices_0[:, None] * normal_stride_0 + indices_1[None, :] * normal_stride_1), randn, mask_0[:, None] & mask_1[None, :])
+    randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
+    tl.store(normal + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn, None)
     # src[test_rng.py:N]: randn_a[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
-    randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
-    tl.store(randn_a + (indices_0[:, None] * randn_a_stride_0 + indices_1[None, :] * randn_a_stride_1), randn_1, mask_0[:, None] & mask_1[None, :])
+    randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
+    tl.store(randn_a + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn_1, None)
     # src[test_rng.py:N]: randn_b[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
-    randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
-    tl.store(randn_b + (indices_0[:, None] * randn_b_stride_0 + indices_1[None, :] * randn_b_stride_1), randn_2, mask_0[:, None] & mask_1[None, :])
+    randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
+    tl.store(randn_b + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn_2, None)
     # src[test_rng.py:N]: randn_c[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
-    randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
-    tl.store(randn_c + (indices_0[:, None] * randn_c_stride_0 + indices_1[None, :] * randn_c_stride_1), randn_3, mask_0[:, None] & mask_1[None, :])
+    randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
+    tl.store(randn_c + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn_3, None)

 def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
     from torch._inductor import inductor_prims
@@ -73,7 +71,7 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_rng.py:N]: # Two independent rand operations
     # src[test_rng.py:N]: rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
     # src[test_rng.py:N-N]: ...
-    _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1)
+    _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1)
     # src[test_rng.py:N]: randn_sum = randn_a + randn_b + randn_c
     randn_sum = randn_a + randn_b + randn_c
     # src[test_rng.py:N]: return rand1, rand2, uniform, normal, randn_sum
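
The expected kernel shrinks because static_shapes=True bakes the 64x64 extents and contiguous strides into the generated code, so the m, n, and *_stride_* arguments and the bounds masks all disappear. A standalone sketch of the same tl.rand offset pattern, written by hand rather than generated by Helion; it assumes a CUDA device and extents divisible by the block sizes:

import torch
import triton
import triton.language as tl


@triton.jit
def _rand_fill(out_ptr, seed, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Each program fills one BLOCK_M x BLOCK_N tile of a contiguous 64x64 tensor.
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Philox counter = flattened element index with the row stride (64)
    # baked in, so every element gets a distinct value for a given seed.
    offsets = rows[:, None] * 64 + cols[None, :]
    tl.store(out_ptr + offsets, tl.rand(seed, offsets))


out = torch.empty(64, 64, device="cuda")
_rand_fill[(64 // 32, 64 // 32)](out, 42, BLOCK_M=32, BLOCK_N=32)
assert (out >= 0).all() and (out < 1).all()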

test/test_rng.py

Lines changed: 108 additions & 6 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from typing import Callable
 import unittest

 import torch
@@ -16,7 +17,7 @@ class TestRNG(RefEagerTestBase, TestCase):
     def test_rand(self):
         """Test RNG seeding behavior, reproducibility, output range, and distribution."""

-        @helion.kernel(static_shapes=False)
+        @helion.kernel(static_shapes=True, autotune_effort="none")
         def rand_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
             output = torch.zeros_like(x)
             m, n = x.shape
@@ -87,7 +88,7 @@ def rand_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
     def test_rand_3d_tensor(self):
         """Test 3D RNG with tiled operations."""

-        @helion.kernel(static_shapes=False)
+        @helion.kernel(static_shapes=True, autotune_effort="none")
         def rand_kernel_3d(x: torch.Tensor) -> torch.Tensor:
             output = torch.zeros_like(x)
             b, m, n = x.shape
@@ -135,7 +136,7 @@ def rand_kernel_3d(x: torch.Tensor) -> torch.Tensor:
     def test_multiple_rng_ops(self):
         """Test multiple RNG operations: independence, reproducibility, mixed rand/randn."""

-        @helion.kernel(static_shapes=False)
+        @helion.kernel(static_shapes=True, autotune_effort="none")
         def multiple_rng_ops_kernel(
             x: torch.Tensor,
         ) -> tuple[
@@ -258,7 +259,7 @@ def multiple_rng_ops_kernel(
     def test_randn_different_seeds_tiled(self):
         """Test that different torch.manual_seed values produce different outputs for randn."""

-        @helion.kernel(static_shapes=False)
+        @helion.kernel(static_shapes=True, autotune_effort="none")
         def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
             output = torch.zeros_like(x)
             m, n = x.shape
@@ -280,7 +281,7 @@ def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
     def test_randn_normal_distribution(self):
         """Test that torch.randn_like produces normal distribution (mean≈0, std≈1)."""

-        @helion.kernel(static_shapes=False)
+        @helion.kernel(static_shapes=True, autotune_effort="none")
         def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
             output = torch.zeros_like(x)
             m, n = x.shape
@@ -315,7 +316,7 @@ def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
     def test_randn_3d_tensor(self):
         """Test 3D randn with tiled operations."""

-        @helion.kernel(static_shapes=False)
+        @helion.kernel(static_shapes=True, autotune_effort="none")
         def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor:
             output = torch.zeros_like(x)
             b, m, n = x.shape
@@ -348,6 +349,107 @@ def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor:
                 f"Slice {b_idx} std {slice_std} is not well distributed",
             )

+    def _test_rng_with_dynamic_tile_sizes(self, rng_func, is_uniform, rng_name):
+        """Common test logic for RNG operations with dynamic tile sizes."""
+
+        # Single kernel that takes an RNG callable as a parameter
+        @helion.kernel(static_shapes=True, autotune_effort="none")
+        def rng_kernel(
+            x: torch.Tensor,
+            rng_func: Callable[[int, int, torch.dtype], torch.Tensor],
+        ) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n = x.shape
+            for tile_m, tile_n in hl.tile([m, n]):
+                output[tile_m, tile_n] = rng_func(tile_m, tile_n, x.dtype)
+            return output
+
+        x = torch.ones(64, 64, device=DEVICE)
+        torch.manual_seed(42)
+        _code, output = code_and_output(rng_kernel, (x, rng_func))
+
+        # Check distribution properties based on RNG type
+        if is_uniform:
+            # For rand: values in [0, 1), mean ~0.5
+            self.assertTrue(
+                torch.all(output >= 0.0), f"{rng_name}: All values should be >= 0"
+            )
+            self.assertTrue(
+                torch.all(output < 1.0), f"{rng_name}: All values should be < 1"
+            )
+            mean_val = output.mean().item()
+            self.assertTrue(
+                0.4 < mean_val < 0.6,
+                f"{rng_name}: Mean {mean_val:.3f} should be ~0.5",
+            )
+        else:
+            # For randn: mean ~0, std ~1
+            mean_val = output.mean().item()
+            std_val = output.std().item()
+            self.assertTrue(
+                -0.15 < mean_val < 0.15, f"{rng_name}: Mean {mean_val:.3f} should be ~0"
+            )
+            self.assertTrue(
+                0.9 < std_val < 1.1, f"{rng_name}: Std {std_val:.3f} should be ~1"
+            )
+
+        # Test reproducibility with same seed
+        torch.manual_seed(42)
+        _code2, output2 = code_and_output(rng_kernel, (x, rng_func))
+        torch.testing.assert_close(
+            output,
+            output2,
+            msg=f"{rng_name}: Same seed should produce identical outputs",
+        )
+
+        # Test that different seeds produce different outputs
+        torch.manual_seed(99)
+        _code3, output3 = code_and_output(rng_kernel, (x, rng_func))
+        self.assertFalse(
+            torch.allclose(output, output3),
+            f"{rng_name}: Different seeds should produce different outputs",
+        )
+
+    def test_rand_with_dynamic_tile_sizes(self):
+        """Test torch.rand with dynamic tile dimensions."""
+        self._test_rng_with_dynamic_tile_sizes(
+            rng_func=lambda tile_m, tile_n, dtype: torch.rand(
+                (tile_m, tile_n), dtype=dtype, device=DEVICE
+            ),
+            is_uniform=True,
+            rng_name="rand",
+        )
+
+    def test_rand_like_with_dynamic_tile_sizes(self):
+        """Test torch.rand_like with dynamic tile dimensions."""
+        self._test_rng_with_dynamic_tile_sizes(
+            rng_func=lambda tile_m, tile_n, dtype: torch.rand_like(
+                torch.ones((tile_m, tile_n), dtype=dtype, device=DEVICE)
+            ),
+            is_uniform=True,
+            rng_name="rand_like",
+        )
+
+    def test_randn_with_dynamic_tile_sizes(self):
+        """Test torch.randn with dynamic tile dimensions."""
+        self._test_rng_with_dynamic_tile_sizes(
+            rng_func=lambda tile_m, tile_n, dtype: torch.randn(
+                (tile_m, tile_n), dtype=dtype, device=DEVICE
+            ),
+            is_uniform=False,
+            rng_name="randn",
+        )
+
+    def test_randn_like_with_dynamic_tile_sizes(self):
+        """Test torch.randn_like with dynamic tile dimensions."""
+        self._test_rng_with_dynamic_tile_sizes(
+            rng_func=lambda tile_m, tile_n, dtype: torch.randn_like(
+                torch.ones((tile_m, tile_n), dtype=dtype, device=DEVICE)
+            ),
+            is_uniform=False,
+            rng_name="randn_like",
+        )
+

 if __name__ == "__main__":
     unittest.main()
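
End to end, the new tests pin down the user-visible behavior this commit enables: tile objects can be passed directly as sizes to torch.rand / torch.randn inside a Helion kernel. A minimal usage sketch mirroring the test kernels (assumes a CUDA device):

import torch

import helion
import helion.language as hl


@helion.kernel(static_shapes=True, autotune_effort="none")
def rand_kernel(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    m, n = x.shape
    for tile_m, tile_n in hl.tile([m, n]):
        # tile_m / tile_n act as dynamic sizes for torch.rand, producing
        # one uniform sample per element of the tile.
        out[tile_m, tile_n] = torch.rand(
            (tile_m, tile_n), dtype=x.dtype, device=x.device
        )
    return out


x = torch.ones(64, 64, device="cuda")
torch.manual_seed(42)
result = rand_kernel(x)
assert (result >= 0).all() and (result < 1).all()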
