
Commit 44d624f

test

1 parent 6a98dc9 commit 44d624f

3 files changed: +293 -1 lines changed


test/test_errors.py

Lines changed: 1 addition & 1 deletion

@@ -197,7 +197,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
             out = x.new_empty(batch)
             for tile_batch in hl.tile([batch]):
                 scalar_val = x[tile_batch].sum()  # 1d index for 2d tensor
-                out = scalar_val
+                out[tile_batch] = scalar_val
             return out
 
         with self.assertRaisesRegex(

test/test_indexing.expected

Lines changed: 168 additions & 0 deletions

@@ -285,6 +285,174 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
     # src[test_indexing.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B_flat, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    num_blocks_0 = tl.cdiv(32, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indexing.py:N]: # [tile_m, tile_k]
+    # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
+    # src[test_indexing.py:N-N]: ...
+    for offset_3 in tl.range(0, 16, _BLOCK_SIZE_2):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
+        cols_2d = tl.load(col + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
+        # src[test_indexing.py:N]: [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]]
+        v_0 = tl.full([], 24, tl.int64)
+        v_1 = tl.cast(cols_2d * v_0, tl.int64)
+        subscript = v_1[:, :, None]
+        v_2 = tl.cast(indices_1, tl.int64)
+        v_3 = subscript + v_2
+        # src[test_indexing.py:N]: B_slice = hl.load(
+        # src[test_indexing.py:N]: B_flat,
+        # src[test_indexing.py:N]: [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]]
+        # src[test_indexing.py:N-N]: ...
+        B_slice = tl.load(B_flat + v_3 * 1, None)
+        # src[test_indexing.py:N]: vals_2d = val[tile_m, tile_k]
+        vals_2d = tl.load(val + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
+        # src[test_indexing.py:N]: contrib = vals_2d[:, :, None] * B_slice
+        subscript_1 = vals_2d[:, :, None]
+        v_4 = subscript_1 * B_slice
+        # src[test_indexing.py:N]: contrib = contrib.sum(dim=1)
+        contrib_1 = tl.cast(tl.sum(v_4, 1), tl.float32)
+        # src[test_indexing.py:N]: acc = acc + contrib
+        acc = acc_copy_0 + contrib_1
+    # src[test_indexing.py:N]: C[tile_m, tile_n] = acc.to(out_dtype)
+    tl.store(C + (indices_0[:, None] * 24 + indices_1[None, :] * 1), acc, None)
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_indexing.py:N]: M, K = col.shape
+    M, K = col.shape
+    # src[test_indexing.py:N]: _, N = B.shape
+    _, N = B.shape
+    # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    # src[test_indexing.py:N]: C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+    # src[test_indexing.py:N]: B_flat = B.reshape(-1)  # [K*N]
+    B_flat = B.reshape(-1)
+    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indexing.py:N]: # [tile_m, tile_k]
+    # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
+    # src[test_indexing.py:N-N]: ...
+    _BLOCK_SIZE_2 = 4
+    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    # src[test_indexing.py:N]: # [tile_m, tile_n]
+    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # src[test_indexing.py:N-N]: ...
+    _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1)
+    _launcher(_helion_test, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1),), col, B_flat, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_indexing.py:N]: return C
+    return C
+
+--- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
+    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(12, _BLOCK_SIZE_1)
+    num_blocks_2 = tl.cdiv(10, _BLOCK_SIZE_2)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
+    pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < 10
+    offset_3 = pid_3 * _BLOCK_SIZE_3
+    indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
+    mask_3 = indices_3 < 14
+    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
+    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indexing.py:N]: # [tile_m, tile_n, tile_k]
+    # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+    # src[test_indexing.py:N-N]: ...
+    for offset_5 in tl.range(0, 8, _BLOCK_SIZE_4):
+        indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+        cols_3d = tl.load(col + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
+        # src[test_indexing.py:N]: cols_3d[:, :, :, None, None],
+        subscript = cols_3d[:, :, :, None, None]
+        # src[test_indexing.py:N]: B_slice = B[
+        # src[test_indexing.py:N]: cols_3d[:, :, :, None, None],
+        # src[test_indexing.py:N]: tile_p.index[None, None, :, None],
+        # src[test_indexing.py:N-N]: ...
+        B_slice = tl.load(B + (subscript * 140 + indices_2[None, None, None, :, None] * 14 + indices_3[None, None, None, None, :] * 1), mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
+        # src[test_indexing.py:N]: vals_3d = val[tile_m, tile_n, tile_k]
+        vals_3d = tl.load(val + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
+        # src[test_indexing.py:N]: contrib = vals_3d[:, :, :, None, None] * B_slice
+        subscript_1 = vals_3d[:, :, :, None, None]
+        v_0 = subscript_1 * B_slice
+        # src[test_indexing.py:N]: contrib = contrib.sum(dim=2)
+        contrib_1 = tl.cast(tl.sum(v_0, 2), tl.float32)
+        # src[test_indexing.py:N]: acc = acc + contrib
+        acc = acc_copy_0 + contrib_1
+    # src[test_indexing.py:N]: C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+    tl.store(C + (indices_0[:, None, None, None] * 1680 + indices_1[None, :, None, None] * 140 + indices_2[None, None, :, None] * 14 + indices_3[None, None, None, :] * 1), acc, mask_2[None, None, :, None] & mask_3[None, None, None, :])
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_indexing.py:N]: M, N, K = col.shape
+    M, N, K = col.shape
+    # src[test_indexing.py:N]: _, P, Q = B.shape
+    _, P, Q = B.shape
+    # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    # src[test_indexing.py:N]: C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    _BLOCK_SIZE_0 = 4
+    _BLOCK_SIZE_1 = 4
+    _BLOCK_SIZE_2 = 4
+    _BLOCK_SIZE_3 = 4
+    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indexing.py:N]: # [tile_m, tile_n, tile_k]
+    # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+    # src[test_indexing.py:N-N]: ...
+    _BLOCK_SIZE_4 = 4
+    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    # src[test_indexing.py:N]: # [tile_m, tile_n, tile_p, tile_q]
+    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+    # src[test_indexing.py:N-N]: ...
+    _RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2)
+    _launcher(_helion_test, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(12, _BLOCK_SIZE_1) * triton.cdiv(10, _BLOCK_SIZE_2) * triton.cdiv(14, _BLOCK_SIZE_3),), col, B, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1)
+    # src[test_indexing.py:N]: return C
+    return C
+
 --- assertExpectedJournal(TestIndexing.test_mask_load)
 from __future__ import annotations
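
Note on the generated 2D kernel above: the hl.load over B_flat lowers the indirect index to flat offsets (v_3 = (cols_2d * N)[:, :, None] + the tile_n indices), so tl.load(B_flat + v_3) gathers B[col[m, k], n]. Below is a minimal PyTorch-only sketch, not part of this commit, using the test's shapes (M, K, N = 32, 16, 24), showing that the flattened gather matches direct row indexing:

    import torch

    M, K, N = 32, 16, 24
    col = torch.randint(0, K, (M, K), dtype=torch.int64)
    B = torch.rand(K, N)
    B_flat = B.reshape(-1)  # [K*N]

    # Row col[m, k] of B starts at flat offset col[m, k] * N.
    flat_idx = (col * N)[:, :, None] + torch.arange(N)[None, None, :]  # [M, K, N]
    gathered = B_flat[flat_idx]  # [M, K, N]

    assert torch.equal(gathered, B[col])  # same values as the direct row gather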

test/test_indexing.py

Lines changed: 124 additions & 0 deletions

@@ -1604,6 +1604,130 @@ def load_store_kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         self.assertEqual(code3, code4)
         self.assertExpectedJournal(code4)
 
+    def test_indirect_indexing_2d(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, K] int64
+            val: torch.Tensor,  # [M, K] fp32
+            B: torch.Tensor,  # [K, N] fp32
+        ) -> torch.Tensor:  # [M, N] fp32
+            M, K = col.shape
+            _, N = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+            B_flat = B.reshape(-1)  # [K*N]
+
+            for tile_m, tile_n in hl.tile([M, N]):
+                # [tile_m, tile_n]
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    # [tile_m, tile_k]
+                    cols_2d = col[tile_m, tile_k]
+                    # [tile_m, tile_k, tile_n]
+                    B_slice = hl.load(
+                        B_flat,
+                        [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]]
+                    )
+                    # [tile_m, tile_k]
+                    vals_2d = val[tile_m, tile_k]
+                    # [tile_m, tile_k, tile_n]
+                    contrib = vals_2d[:, :, None] * B_slice
+                    # [tile_m, tile_n]
+                    contrib = contrib.sum(dim=1)
+                    # [tile_m, tile_n]
+                    acc = acc + contrib
+
+                C[tile_m, tile_n] = acc.to(out_dtype)
+
+            return C
+
+        M, K, N = 32, 16, 24
+        col = torch.randint(0, K, (M, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, N), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[8, 8, 4],
+        )
+
+        # For each output position (i,j), compute sum over k: val[i,k] * B[col[i,k], j]
+        expected = torch.zeros((M, N), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for k in range(K):
+                    expected[i, j] += val[i, k] * B[col[i, k], j]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+    def test_indirect_indexing_3d(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, N, K] int64 - indices for first dimension of B
+            val: torch.Tensor,  # [M, N, K] fp32 - values to multiply
+            B: torch.Tensor,  # [K, P, Q] fp32 - tensor to index into
+        ) -> torch.Tensor:  # [M, N, P, Q] fp32
+            M, N, K = col.shape
+            _, P, Q = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+
+            for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+                # [tile_m, tile_n, tile_p, tile_q]
+                acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    # [tile_m, tile_n, tile_k]
+                    cols_3d = col[tile_m, tile_n, tile_k]
+
+                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
+                    # Direct indexing into B using gather
+                    B_slice = B[
+                        cols_3d[:, :, :, None, None],
+                        tile_p.index[None, None, :, None],
+                        tile_q.index[None, None, None, :],
+                    ]
+
+                    # [tile_m, tile_n, tile_k]
+                    vals_3d = val[tile_m, tile_n, tile_k]
+
+                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
+                    contrib = vals_3d[:, :, :, None, None] * B_slice
+
+                    # [tile_m, tile_n, tile_p, tile_q] - sum over k dimension
+                    contrib = contrib.sum(dim=2)
+
+                    # [tile_m, tile_n, tile_p, tile_q]
+                    acc = acc + contrib
+
+                C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+            return C
+
+        M, N, K, P, Q = 16, 12, 8, 10, 14
+        col = torch.randint(0, K, (M, N, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, N, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, P, Q), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[4, 4, 4, 4, 4],  # 5D tiling for M, N, P, Q, K
+        )
+
+        # For each output position (i,j,p,q), compute sum over k: val[i,j,k] * B[col[i,j,k], p, q]
+        expected = torch.zeros((M, N, P, Q), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for p in range(P):
+                    for q in range(Q):
+                        for k in range(K):
+                            expected[i, j, p, q] += val[i, j, k] * B[col[i, j, k], p, q]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
 
 if __name__ == "__main__":
     unittest.main()
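
For reference, the 2D test's nested-loop expected value (expected[i, j] = sum over k of val[i, k] * B[col[i, k], j]) has a compact vectorized equivalent; the following is only an illustrative sketch under the test's shapes, not code from this commit:

    import torch

    M, K, N = 32, 16, 24
    col = torch.randint(0, K, (M, K), dtype=torch.int64)
    val = torch.rand(M, K)
    B = torch.rand(K, N)

    # B[col] gathers rows of B -> [M, K, N]; weight by val and reduce over k.
    expected = (val[:, :, None] * B[col]).sum(dim=1)  # [M, N]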

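Likewise for the 3D test, expected[i, j, p, q] = sum over k of val[i, j, k] * B[col[i, j, k], p, q] can be computed without Python loops (again just a sketch under the test's shapes):

    import torch

    M, N, K, P, Q = 16, 12, 8, 10, 14
    col = torch.randint(0, K, (M, N, K), dtype=torch.int64)
    val = torch.rand(M, N, K)
    B = torch.rand(K, P, Q)

    # B[col] -> [M, N, K, P, Q]; broadcast val over (P, Q) and sum over k.
    expected = (val[..., None, None] * B[col]).sum(dim=2)  # [M, N, P, Q]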