Skip to content

Commit f60f36a

Browse files
committed
test
1 parent a2bb673 commit f60f36a

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

test/test_matmul.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,54 @@ def test_matmul_static_shapes3(self):
199199
torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2)
200200
self.assertExpectedJournal(code)
201201

202+
def test_matmul_packed_int4_block_size_constexpr(self):
    """Smoke-test a matmul kernel that unpacks int4 pairs from int8 storage.

    The kernel registers its N and K block sizes up front and then reuses
    them in plain integer arithmetic (``block_k // 2`` as a slice bound) and
    as a ``reshape`` target.  Judging by the test name, the point is that
    registered block sizes are usable as compile-time constants in those
    positions -- TODO confirm against the change this test accompanies.

    Only two properties of the output are checked: every element is finite,
    and the output is not all zeros.  There is no comparison against a
    reference result.
    """
    torch.manual_seed(0)
    M = N = K = 32

    @helion.kernel(use_default_config=True, static_shapes=True)
    def matmul_bf16_packed_int4(
        A: torch.Tensor, B_packed: torch.Tensor, C: torch.Tensor
    ) -> torch.Tensor:
        # A is (M0, K0).  B_packed has K0 // 2 rows as built by the caller
        # below, so only its column count is read here.
        M0, K0 = A.shape
        _, N0 = B_packed.shape

        # Register the block sizes explicitly so the same values can be
        # used as tile sizes and in the index arithmetic / reshape below.
        block_n = hl.register_block_size(N0)
        block_k = hl.register_block_size(K0)

        for tile_m in hl.tile(M0):
            for tile_n in hl.tile(N0, block_size=block_n):
                acc = hl.zeros((tile_m, tile_n), dtype=torch.float32)

                for tile_k in hl.tile(K0, block_size=block_k):
                    tile_k_begin = tile_k.begin
                    # Each packed row holds two logical K rows, so the
                    # packed slice starts at half the K offset and spans
                    # half a K block.
                    b_tile = B_packed[
                        tile_k_begin // 2 : tile_k_begin // 2 + block_k // 2,
                        tile_n,
                    ]
                    shift = hl.full((1,), 4, dtype=torch.int8)
                    # Low nibble: shift left then right by 4.  With int8
                    # data this presumably sign-extends bit 3 (arithmetic
                    # right shift) -- TODO confirm for the target backend.
                    b_lo = (b_tile << shift) >> shift
                    # High nibble: right shift by 4.
                    b_hi = b_tile >> shift
                    # Interleave lo/hi along K: stack on a new last axis,
                    # move that axis next to the packed-row axis, then
                    # flatten the two into block_k rows.  Row 2*i comes
                    # from b_lo[i] and row 2*i + 1 from b_hi[i].
                    stacked = torch.stack(
                        (b_lo.to(torch.float16), b_hi.to(torch.float16)), dim=2
                    )
                    stacked = stacked.permute(0, 2, 1)
                    b_block = stacked.reshape([block_k, block_n])
                    # NOTE(review): A is created as bfloat16 below while
                    # b_block is float16, so hl.dot receives mixed input
                    # dtypes -- confirm this is intended.
                    acc = hl.dot(A[tile_m, tile_k], b_block, acc=acc)

                C[tile_m, tile_n] = acc

        return C

    # Keep the randn / randint order: both draw from the generator seeded
    # above, so swapping them would change the test data.
    A = torch.randn((M, K), dtype=torch.bfloat16, device=DEVICE)
    # NOTE(review): values are drawn from [0, 16), so the upper four bits of
    # every packed byte are zero and b_hi in the kernel is always 0.  Only
    # the low-nibble path contributes non-zero products.
    B_packed = torch.randint(0, 16, (K // 2, N), dtype=torch.int8, device=DEVICE)
    C = torch.zeros((M, N), dtype=torch.float32, device=DEVICE)

    # The kernel writes its result into C in place; the return value is
    # not used.
    matmul_bf16_packed_int4(A, B_packed, C)
    # NOTE(review): hard-codes the CUDA backend; DEVICE is defined outside
    # this block -- confirm it is always a CUDA device.
    torch.cuda.synchronize()

    # Sanity checks only: finite everywhere and not identically zero.
    self.assertTrue(torch.isfinite(C).all())
    self.assertFalse(torch.allclose(C, torch.zeros_like(C)))
249+
202250
def test_matmul_split_k(self):
203251
@helion.kernel(dot_precision="ieee")
204252
def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)