@@ -185,6 +185,115 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
     _launcher(_helion_broadcast_add_3d, (triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),), x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B_flat, val, C, B_flat_stride_0, C_stride_0, C_stride_1, col_stride_0, col_stride_1, val_stride_0, val_stride_1, M, N, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
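+    # Indirect-indexing matmul: each program computes one (m, n) tile of
+    # C[m, n] = sum_k val[m, k] * B_flat[col[m, k] * N + n].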
+    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < M
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < N
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
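+    # Accumulate partial products over K in tiles of _BLOCK_SIZE_2.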
+    for offset_3 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_2):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_3 < K
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        cols_2d = tl.load(col + (indices_0[:, None] * col_stride_0 + indices_3[None, :] * col_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
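+        # Turn the gathered column ids into flat offsets into B_flat: row * N + n.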
+        v_0 = cols_2d * N
+        subscript = v_0[:, :, None]
+        v_1 = tl.cast(indices_1, tl.int64)
+        v_2 = subscript + v_1
+        B_slice = tl.load(B_flat + v_2 * B_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], other=0)
+        vals_2d = tl.load(val + (indices_0[:, None] * val_stride_0 + indices_3[None, :] * val_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        subscript_1 = vals_2d[:, :, None]
+        v_3 = subscript_1 * B_slice
+        contrib_1 = tl.cast(tl.sum(v_3, 1), tl.float32)
+        acc = acc_copy_0 + contrib_1
+    tl.store(C + (indices_0[:, None] * C_stride_0 + indices_1[None, :] * C_stride_1), acc, mask_0[:, None] & mask_1[None, :])
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
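+    # Reference semantics (a sketch): C = (val.unsqueeze(-1) * B[col]).sum(dim=1).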
+    M, K = col.shape
+    _, N = B.shape
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+    B_flat = B.reshape(-1)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 4
+    _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1)
+    _launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1),), col, B_flat, val, C, B_flat.stride(0), C.stride(0), C.stride(1), col.stride(0), col.stride(1), val.stride(0), val.stride(1), M, N, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return C
+
+--- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B, val, C, B_stride_0, B_stride_1, B_stride_2, C_stride_0, C_stride_1, C_stride_2, C_stride_3, col_stride_0, col_stride_1, col_stride_2, val_stride_0, val_stride_1, val_stride_2, M, N, P, Q, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
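+    # Indirect-indexing contraction: each program computes one (m, n, p, q) tile of
+    # C[m, n, p, q] = sum_k val[m, n, k] * B[col[m, n, k], p, q].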
+    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(N, _BLOCK_SIZE_1)
+    num_blocks_2 = tl.cdiv(P, _BLOCK_SIZE_2)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
+    pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < M
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < N
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < P
+    offset_3 = pid_3 * _BLOCK_SIZE_3
+    indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
+    mask_3 = indices_3 < Q
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
+    for offset_5 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_4):
+        indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+        mask_4 = indices_5 < K
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        cols_3d = tl.load(col + (indices_0[:, None, None] * col_stride_0 + indices_1[None, :, None] * col_stride_1 + indices_5[None, None, :] * col_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0)
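+        # Gather a 5-D [m, n, k, p, q] block of B rows selected by cols_3d.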
+        B_slice = tl.load(B + (cols_3d[:, :, :, None, None] * B_stride_0 + indices_2[None, None, None, :, None] * B_stride_1 + indices_3[None, None, None, None, :] * B_stride_2), mask_0[:, None, None, None, None] & mask_1[None, :, None, None, None] & mask_4[None, None, :, None, None] & mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
+        vals_3d = tl.load(val + (indices_0[:, None, None] * val_stride_0 + indices_1[None, :, None] * val_stride_1 + indices_5[None, None, :] * val_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0)
+        subscript = vals_3d[:, :, :, None, None]
+        v_0 = subscript * B_slice
+        contrib_1 = tl.cast(tl.sum(v_0, 2), tl.float32)
+        acc = acc_copy_0 + contrib_1
+    tl.store(C + (indices_0[:, None, None, None] * C_stride_0 + indices_1[None, :, None, None] * C_stride_1 + indices_2[None, None, :, None] * C_stride_2 + indices_3[None, None, None, :] * C_stride_3), acc, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None] & mask_3[None, None, None, :])
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
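+    # Reference semantics (a sketch): C = (val[..., None, None] * B[col]).sum(dim=2).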
+    M, N, K = col.shape
+    _, P, Q = B.shape
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    _BLOCK_SIZE_0 = 4
+    _BLOCK_SIZE_1 = 4
+    _BLOCK_SIZE_2 = 4
+    _BLOCK_SIZE_3 = 4
+    _BLOCK_SIZE_4 = 4
+    _RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2)
+    _launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1) * triton.cdiv(P, _BLOCK_SIZE_2) * triton.cdiv(Q, _BLOCK_SIZE_3),), col, B, val, C, B.stride(0), B.stride(1), B.stride(2), C.stride(0), C.stride(1), C.stride(2), C.stride(3), col.stride(0), col.stride(1), col.stride(2), val.stride(0), val.stride(1), val.stride(2), M, N, P, Q, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
+    return C
+
 --- assertExpectedJournal(TestIndexing.test_mask_load)
 from __future__ import annotations
 