@@ -10,38 +10,36 @@ import triton.language as tl
1010from helion.runtime import default_launcher as _default_launcher
1111
1212@triton.jit
13- def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal_stride_0, normal_stride_1, rand1_stride_0, rand1_stride_1, rand2_stride_0, rand2_stride_1, randn_a_stride_0, randn_a_stride_1, randn_b_stride_0, randn_b_stride_1, randn_c_stride_0, randn_c_stride_1, uniform_stride_0, uniform_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer):
13+ def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer):
1414 # src[test_rng.py:N]: for tile_m, tile_n in hl.tile([m, n]):
15- num_blocks_0 = tl.cdiv(m , _BLOCK_SIZE_0)
15+ num_blocks_0 = tl.cdiv(64 , _BLOCK_SIZE_0)
1616 pid_0 = tl.program_id(0) % num_blocks_0
1717 pid_1 = tl.program_id(0) // num_blocks_0
1818 offset_0 = pid_0 * _BLOCK_SIZE_0
1919 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
20- mask_0 = indices_0 < m
2120 offset_1 = pid_1 * _BLOCK_SIZE_1
2221 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
23- mask_1 = indices_1 < n
2422 # src[test_rng.py:N]: rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
25- rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
26- tl.store(rand1 + (indices_0[:, None] * rand1_stride_0 + indices_1[None, :] * rand1_stride_1 ), rand, mask_0[:, None] & mask_1[None, :] )
23+ rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
24+ tl.store(rand1 + (indices_0[:, None] * 64 + indices_1[None, :] * 1 ), rand, None)
2725 # src[test_rng.py:N]: rand2[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
28- rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
29- tl.store(rand2 + (indices_0[:, None] * rand2_stride_0 + indices_1[None, :] * rand2_stride_1 ), rand_1, mask_0[:, None] & mask_1[None, :] )
26+ rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
27+ tl.store(rand2 + (indices_0[:, None] * 64 + indices_1[None, :] * 1 ), rand_1, None)
3028 # src[test_rng.py:N]: uniform[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
31- rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
32- tl.store(uniform + (indices_0[:, None] * uniform_stride_0 + indices_1[None, :] * uniform_stride_1 ), rand_2, mask_0[:, None] & mask_1[None, :] )
29+ rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
30+ tl.store(uniform + (indices_0[:, None] * 64 + indices_1[None, :] * 1 ), rand_2, None)
3331 # src[test_rng.py:N]: normal[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
34- randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
35- tl.store(normal + (indices_0[:, None] * normal_stride_0 + indices_1[None, :] * normal_stride_1 ), randn, mask_0[:, None] & mask_1[None, :] )
32+ randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
33+ tl.store(normal + (indices_0[:, None] * 64 + indices_1[None, :] * 1 ), randn, None)
3634 # src[test_rng.py:N]: randn_a[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
37- randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
38- tl.store(randn_a + (indices_0[:, None] * randn_a_stride_0 + indices_1[None, :] * randn_a_stride_1 ), randn_1, mask_0[:, None] & mask_1[None, :] )
35+ randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
36+ tl.store(randn_a + (indices_0[:, None] * 64 + indices_1[None, :] * 1 ), randn_1, None)
3937 # src[test_rng.py:N]: randn_b[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
40- randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
41- tl.store(randn_b + (indices_0[:, None] * randn_b_stride_0 + indices_1[None, :] * randn_b_stride_1 ), randn_2, mask_0[:, None] & mask_1[None, :] )
38+ randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
39+ tl.store(randn_b + (indices_0[:, None] * 64 + indices_1[None, :] * 1 ), randn_2, None)
4240 # src[test_rng.py:N]: randn_c[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
43- randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
44- tl.store(randn_c + (indices_0[:, None] * randn_c_stride_0 + indices_1[None, :] * randn_c_stride_1 ), randn_3, mask_0[:, None] & mask_1[None, :] )
41+ randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
42+ tl.store(randn_c + (indices_0[:, None] * 64 + indices_1[None, :] * 1 ), randn_3, None)
4543
4644def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
4745 from torch._inductor import inductor_prims
@@ -73,7 +71,7 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
7371 # src[test_rng.py:N]: # Two independent rand operations
7472 # src[test_rng.py:N]: rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
7573 # src[test_rng.py:N-N]: ...
76- _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m , _BLOCK_SIZE_0) * triton.cdiv(n , _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n , _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1)
74+ _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(64 , _BLOCK_SIZE_0) * triton.cdiv(64 , _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1)
7775 # src[test_rng.py:N]: randn_sum = randn_a + randn_b + randn_c
7876 randn_sum = randn_a + randn_b + randn_c
7977 # src[test_rng.py:N]: return rand1, rand2, uniform, normal, randn_sum
0 commit comments