Commit 6dbf730: "expected"
1 parent 2ad0bf3

File tree

1 file changed: +109, -0 lines

test/test_indexing.expected

Lines changed: 109 additions & 0 deletions
@@ -185,6 +185,115 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
     _launcher(_helion_broadcast_add_3d, (triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),), x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B_flat, val, C, B_flat_stride_0, C_stride_0, C_stride_1, col_stride_0, col_stride_1, val_stride_0, val_stride_1, M, N, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < M
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < N
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_3 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_2):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_3 < K
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        cols_2d = tl.load(col + (indices_0[:, None] * col_stride_0 + indices_3[None, :] * col_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        v_0 = cols_2d * N
+        subscript = v_0[:, :, None]
+        v_1 = tl.cast(indices_1, tl.int64)
+        v_2 = subscript + v_1
+        B_slice = tl.load(B_flat + v_2 * B_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], other=0)
+        vals_2d = tl.load(val + (indices_0[:, None] * val_stride_0 + indices_3[None, :] * val_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        subscript_1 = vals_2d[:, :, None]
+        v_3 = subscript_1 * B_slice
+        contrib_1 = tl.cast(tl.sum(v_3, 1), tl.float32)
+        acc = acc_copy_0 + contrib_1
+    tl.store(C + (indices_0[:, None] * C_stride_0 + indices_1[None, :] * C_stride_1), acc, mask_0[:, None] & mask_1[None, :])
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    M, K = col.shape
+    _, N = B.shape
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+    B_flat = B.reshape(-1)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 4
+    _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1)
+    _launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1),), col, B_flat, val, C, B_flat.stride(0), C.stride(0), C.stride(1), col.stride(0), col.stride(1), val.stride(0), val.stride(1), M, N, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return C
+
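Note: reading the generated kernel above, test_indirect_indexing_2d appears to check the lowering of a gather-and-reduce of the form C[m, n] = sum_k val[m, k] * B[col[m, k], n], where B is addressed through the flattened view B_flat at offset col * N + n. A minimal eager PyTorch sketch of that semantics, for orientation only (indirect_indexing_2d_ref is a hypothetical name, not part of this file or the Helion API):

    import torch

    def indirect_indexing_2d_ref(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        # C[m, n] = sum_k val[m, k] * B[col[m, k], n]
        # col: (M, K) integer row indices into B; val: (M, K); B: (R, N)
        out_dtype = torch.promote_types(val.dtype, B.dtype)
        gathered = B[col]  # advanced indexing gathers rows of B -> (M, K, N)
        return (val.unsqueeze(-1).to(out_dtype) * gathered.to(out_dtype)).sum(dim=1)

    # e.g. col = torch.randint(0, 16, (8, 4)); val = torch.randn(8, 4); B = torch.randn(16, 8)
    # indirect_indexing_2d_ref(col, val, B) should agree with test(col, val, B)
    # up to floating-point accumulation order.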
+--- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B, val, C, B_stride_0, B_stride_1, B_stride_2, C_stride_0, C_stride_1, C_stride_2, C_stride_3, col_stride_0, col_stride_1, col_stride_2, val_stride_0, val_stride_1, val_stride_2, M, N, P, Q, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
+    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(N, _BLOCK_SIZE_1)
+    num_blocks_2 = tl.cdiv(P, _BLOCK_SIZE_2)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
+    pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < M
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < N
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < P
+    offset_3 = pid_3 * _BLOCK_SIZE_3
+    indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
+    mask_3 = indices_3 < Q
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
+    for offset_5 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_4):
+        indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+        mask_4 = indices_5 < K
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        cols_3d = tl.load(col + (indices_0[:, None, None] * col_stride_0 + indices_1[None, :, None] * col_stride_1 + indices_5[None, None, :] * col_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0)
+        B_slice = tl.load(B + (cols_3d[:, :, :, None, None] * B_stride_0 + indices_2[None, None, None, :, None] * B_stride_1 + indices_3[None, None, None, None, :] * B_stride_2), mask_0[:, None, None, None, None] & mask_1[None, :, None, None, None] & mask_4[None, None, :, None, None] & mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
+        vals_3d = tl.load(val + (indices_0[:, None, None] * val_stride_0 + indices_1[None, :, None] * val_stride_1 + indices_5[None, None, :] * val_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0)
+        subscript = vals_3d[:, :, :, None, None]
+        v_0 = subscript * B_slice
+        contrib_1 = tl.cast(tl.sum(v_0, 2), tl.float32)
+        acc = acc_copy_0 + contrib_1
+    tl.store(C + (indices_0[:, None, None, None] * C_stride_0 + indices_1[None, :, None, None] * C_stride_1 + indices_2[None, None, :, None] * C_stride_2 + indices_3[None, None, None, :] * C_stride_3), acc, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None] & mask_3[None, None, None, :])
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    M, N, K = col.shape
+    _, P, Q = B.shape
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    _BLOCK_SIZE_0 = 4
+    _BLOCK_SIZE_1 = 4
+    _BLOCK_SIZE_2 = 4
+    _BLOCK_SIZE_3 = 4
+    _BLOCK_SIZE_4 = 4
+    _RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2)
+    _launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1) * triton.cdiv(P, _BLOCK_SIZE_2) * triton.cdiv(Q, _BLOCK_SIZE_3),), col, B, val, C, B.stride(0), B.stride(1), B.stride(2), C.stride(0), C.stride(1), C.stride(2), C.stride(3), col.stride(0), col.stride(1), col.stride(2), val.stride(0), val.stride(1), val.stride(2), M, N, P, Q, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
+    return C
+
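Note: by the same reading, test_indirect_indexing_3d appears to exercise the rank-3 variant C[m, n, p, q] = sum_k val[m, n, k] * B[col[m, n, k], p, q], this time indexing B directly through its strides rather than a flattened view. A hedged eager sketch under the same assumptions (indirect_indexing_3d_ref is a hypothetical name, not part of this file):

    import torch

    def indirect_indexing_3d_ref(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        # C[m, n, p, q] = sum_k val[m, n, k] * B[col[m, n, k], p, q]
        # col: (M, N, K) integer indices into B's first dim; val: (M, N, K); B: (R, P, Q)
        out_dtype = torch.promote_types(val.dtype, B.dtype)
        gathered = B[col]  # advanced indexing -> (M, N, K, P, Q)
        return (val[..., None, None].to(out_dtype) * gathered.to(out_dtype)).sum(dim=2)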
 --- assertExpectedJournal(TestIndexing.test_mask_load)
 from __future__ import annotations
 