@@ -285,6 +285,174 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
     # src[test_indexing.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B_flat, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    num_blocks_0 = tl.cdiv(32, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indexing.py:N]: # [tile_m, tile_k]
+    # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
+    # src[test_indexing.py:N-N]: ...
+    for offset_3 in tl.range(0, 16, _BLOCK_SIZE_2):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
+        cols_2d = tl.load(col + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
+        # src[test_indexing.py:N]: [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]]
+        v_0 = tl.full([], 24, tl.int64)
+        v_1 = tl.cast(cols_2d * v_0, tl.int64)
+        subscript = v_1[:, :, None]
+        v_2 = tl.cast(indices_1, tl.int64)
+        v_3 = subscript + v_2
+        # src[test_indexing.py:N]: B_slice = hl.load(
+        # src[test_indexing.py:N]: B_flat,
+        # src[test_indexing.py:N]: [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]]
+        # src[test_indexing.py:N-N]: ...
+        B_slice = tl.load(B_flat + v_3 * 1, None)
+        # src[test_indexing.py:N]: vals_2d = val[tile_m, tile_k]
+        vals_2d = tl.load(val + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
+        # src[test_indexing.py:N]: contrib = vals_2d[:, :, None] * B_slice
+        subscript_1 = vals_2d[:, :, None]
+        v_4 = subscript_1 * B_slice
+        # src[test_indexing.py:N]: contrib = contrib.sum(dim=1)
+        contrib_1 = tl.cast(tl.sum(v_4, 1), tl.float32)
+        # src[test_indexing.py:N]: acc = acc + contrib
+        acc = acc_copy_0 + contrib_1
+    # src[test_indexing.py:N]: C[tile_m, tile_n] = acc.to(out_dtype)
+    tl.store(C + (indices_0[:, None] * 24 + indices_1[None, :] * 1), acc, None)
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_indexing.py:N]: M, K = col.shape
+    M, K = col.shape
+    # src[test_indexing.py:N]: _, N = B.shape
+    _, N = B.shape
+    # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    # src[test_indexing.py:N]: C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+    # src[test_indexing.py:N]: B_flat = B.reshape(-1) # [K*N]
+    B_flat = B.reshape(-1)
+    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indexing.py:N]: # [tile_m, tile_k]
+    # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
+    # src[test_indexing.py:N-N]: ...
+    _BLOCK_SIZE_2 = 4
+    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    # src[test_indexing.py:N]: # [tile_m, tile_n]
+    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # src[test_indexing.py:N-N]: ...
+    _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1)
+    _launcher(_helion_test, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1),), col, B_flat, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_indexing.py:N]: return C
+    return C
+
+--- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
+    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(12, _BLOCK_SIZE_1)
+    num_blocks_2 = tl.cdiv(10, _BLOCK_SIZE_2)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
+    pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < 10
+    offset_3 = pid_3 * _BLOCK_SIZE_3
+    indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
+    mask_3 = indices_3 < 14
+    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
+    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indexing.py:N]: # [tile_m, tile_n, tile_k]
+    # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+    # src[test_indexing.py:N-N]: ...
+    for offset_5 in tl.range(0, 8, _BLOCK_SIZE_4):
+        indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+        cols_3d = tl.load(col + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
+        # src[test_indexing.py:N]: cols_3d[:, :, :, None, None],
+        subscript = cols_3d[:, :, :, None, None]
+        # src[test_indexing.py:N]: B_slice = B[
+        # src[test_indexing.py:N]: cols_3d[:, :, :, None, None],
+        # src[test_indexing.py:N]: tile_p.index[None, None, :, None],
+        # src[test_indexing.py:N-N]: ...
+        B_slice = tl.load(B + (subscript * 140 + indices_2[None, None, None, :, None] * 14 + indices_3[None, None, None, None, :] * 1), mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
+        # src[test_indexing.py:N]: vals_3d = val[tile_m, tile_n, tile_k]
+        vals_3d = tl.load(val + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
+        # src[test_indexing.py:N]: contrib = vals_3d[:, :, :, None, None] * B_slice
+        subscript_1 = vals_3d[:, :, :, None, None]
+        v_0 = subscript_1 * B_slice
+        # src[test_indexing.py:N]: contrib = contrib.sum(dim=2)
+        contrib_1 = tl.cast(tl.sum(v_0, 2), tl.float32)
+        # src[test_indexing.py:N]: acc = acc + contrib
+        acc = acc_copy_0 + contrib_1
+    # src[test_indexing.py:N]: C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+    tl.store(C + (indices_0[:, None, None, None] * 1680 + indices_1[None, :, None, None] * 140 + indices_2[None, None, :, None] * 14 + indices_3[None, None, None, :] * 1), acc, mask_2[None, None, :, None] & mask_3[None, None, None, :])
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_indexing.py:N]: M, N, K = col.shape
+    M, N, K = col.shape
+    # src[test_indexing.py:N]: _, P, Q = B.shape
+    _, P, Q = B.shape
+    # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    # src[test_indexing.py:N]: C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    _BLOCK_SIZE_0 = 4
+    _BLOCK_SIZE_1 = 4
+    _BLOCK_SIZE_2 = 4
+    _BLOCK_SIZE_3 = 4
+    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indexing.py:N]: # [tile_m, tile_n, tile_k]
+    # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+    # src[test_indexing.py:N-N]: ...
+    _BLOCK_SIZE_4 = 4
+    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    # src[test_indexing.py:N]: # [tile_m, tile_n, tile_p, tile_q]
+    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+    # src[test_indexing.py:N-N]: ...
+    _RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2)
+    _launcher(_helion_test, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(12, _BLOCK_SIZE_1) * triton.cdiv(10, _BLOCK_SIZE_2) * triton.cdiv(14, _BLOCK_SIZE_3),), col, B, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1)
+    # src[test_indexing.py:N]: return C
+    return C
+
 --- assertExpectedJournal(TestIndexing.test_mask_load)
 from __future__ import annotations
 