2828@helion .kernel (
2929 # static_shapes=True gives a performance boost for matmuls
3030 static_shapes = True ,
31+ config = helion .Config (
32+ block_sizes = [64 , 64 , 64 ],
33+ loop_orders = [[0 , 1 ]],
34+ l2_groupings = [4 ],
35+ range_unroll_factors = [0 , 1 ],
36+ range_num_stages = [0 , 3 ],
37+ range_multi_buffers = [None , False ],
38+ range_flattens = [None , None ],
39+ num_warps = 8 ,
40+ num_stages = 6 ,
41+ indexing = 'tensor_descriptor' ,
42+ pid_type = 'flat'
43+ )
3144)
3245def matmul (
3346 x : Tensor ,
@@ -44,17 +57,22 @@ def matmul(
4457 Returns:
4558 Tensor: Resulting matrix of shape [m, n].
4659 """
60+
4761 m , k = x .size ()
4862 k2 , n = y .size ()
4963 assert k == k2 , f"size mismatch { k } != { k2 } "
5064 out = torch .empty (
5165 [m , n ], dtype = torch .promote_types (x .dtype , y .dtype ), device = x .device
5266 )
53- for tile_m , tile_n in hl .tile ([m , n ]):
67+ block_m = hl .register_block_size (m )
68+ block_n = hl .register_block_size (n )
69+ for tile_m , tile_n in hl .tile ([m , n ], block_size = [block_m , block_n ]):
5470 acc = hl .zeros ([tile_m , tile_n ], dtype = torch .float32 )
5571 for tile_k in hl .tile (k ):
5672 acc = torch .addmm (acc , x [tile_m , tile_k ], y [tile_k , tile_n ])
57- out [tile_m , tile_n ] = epilogue (acc , (tile_m , tile_n ))
73+
74+ acc = epilogue (acc , (tile_m , tile_n ))
75+ out [tile_m , tile_n ] = acc
5876 return out
5977
6078
@@ -298,97 +316,97 @@ def check(m: int, k: int, n: int) -> None:
298316 # Test without bias
299317 run_example (matmul , torch .matmul , (x , y ))
300318
301- # Test for addmm with scalar bias
302- def addmm (bias : Tensor , mat1 : Tensor , mat2 : Tensor ) -> Tensor :
303- m , k = mat1 .size ()
304- k2 , n = mat2 .size ()
305- bias = torch .broadcast_to (bias , [m , n ])
306- return matmul (mat1 , mat2 , lambda acc , tile : acc + bias [tile [0 ], tile [1 ]])
307-
308- run_example (addmm , torch .addmm , (bias_scalar , x , y ))
309-
310- # Test with bias
311- def helion_linear (x : Tensor , y : Tensor , bias : Tensor ) -> Tensor :
312- return matmul (x , y , lambda acc , tile : acc + bias [tile [1 ]])
313-
314- def baseline_linear (x : Tensor , y : Tensor , bias : Tensor ) -> Tensor :
315- return torch .nn .functional .linear (x , y .T , bias )
316-
317- run_example (helion_linear , baseline_linear , (x , y , bias ))
318-
319- # Test more complex epilogue
320- def epilogue (acc : Tensor , tile : tuple [Tensor , ...]) -> Tensor :
321- # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
322- return torch .relu (acc + bias [tile [1 ]])
323-
324- def kernel_wrapper (x : Tensor , y : Tensor ) -> Tensor :
325- return matmul (x , y , epilogue )
326-
327- def baseline_wrapper (x : Tensor , y : Tensor ) -> Tensor :
328- return torch .relu (x @ y + bias )
329-
330- run_example (
331- kernel_wrapper ,
332- baseline_wrapper ,
333- (x , y ),
334- )
335-
336- # Test matmul forward + backward pass
337- print ("\n \n === MatMul Forward + Backward Pass Test ===" )
338- x_grad = torch .randn ([m , k ], device = DEVICE , dtype = torch .float16 , requires_grad = True )
339- y_grad = torch .randn ([k , n ], device = DEVICE , dtype = torch .float16 , requires_grad = True )
340-
341- run_example (
342- matmul_autograd ,
343- torch .matmul ,
344- (x_grad , y_grad ),
345- kernel_name = "helion_matmul_autograd" ,
346- baseline_name = "torch" ,
347- rtol = 1e-2 ,
348- atol = 1e-2 ,
349- bwd = True ,
350- )
351-
352- # Test addmm forward + backward pass
353- print ("\n \n === AddMM Forward + Backward Pass Test ===" )
354- input_grad = torch .randn (
355- [m , n ], device = DEVICE , dtype = torch .float16 , requires_grad = True
356- )
357- mat1_grad = torch .randn (
358- [m , k ], device = DEVICE , dtype = torch .float16 , requires_grad = True
359- )
360- mat2_grad = torch .randn (
361- [k , n ], device = DEVICE , dtype = torch .float16 , requires_grad = True
362- )
363-
364- # Use lambda to handle the keyword argument format for torch.addmm
365- run_example (
366- addmm_autograd ,
367- lambda bias , mat1 , mat2 , alpha , beta : torch .addmm (
368- bias , mat1 , mat2 , alpha = alpha , beta = beta
369- ),
370- (input_grad , mat1_grad , mat2_grad , 1.0 , 1.0 ),
371- kernel_name = "helion_addmm_autograd" ,
372- baseline_name = "torch" ,
373- rtol = 1e-2 ,
374- atol = 1e-2 ,
375- bwd = True ,
376- )
377-
378- # Test addmm forward + backward with different alpha/beta values
379- print ("\n \n === AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===" )
380- run_example (
381- addmm_autograd ,
382- lambda bias , mat1 , mat2 , alpha , beta : torch .addmm (
383- bias , mat1 , mat2 , alpha = alpha , beta = beta
384- ),
385- (input_grad , mat1_grad , mat2_grad , 2.0 , 0.5 ),
386- kernel_name = "helion_addmm_autograd_scaled" ,
387- baseline_name = "torch" ,
388- rtol = 1e-2 ,
389- atol = 1e-2 ,
390- bwd = True ,
391- )
319+ # # Test for addmm with scalar bias
320+ # def addmm(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
321+ # m, k = mat1.size()
322+ # k2, n = mat2.size()
323+ # bias = torch.broadcast_to(bias, [m, n])
324+ # return matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
325+
326+ # run_example(addmm, torch.addmm, (bias_scalar, x, y))
327+
328+ # # Test with bias
329+ # def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
330+ # return matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
331+
332+ # def baseline_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
333+ # return torch.nn.functional.linear(x, y.T, bias)
334+
335+ # run_example(helion_linear, baseline_linear, (x, y, bias))
336+
337+ # # Test more complex epilogue
338+ # def epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
339+ # # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
340+ # return torch.relu(acc + bias[tile[1]])
341+
342+ # def kernel_wrapper(x: Tensor, y: Tensor) -> Tensor:
343+ # return matmul(x, y, epilogue)
344+
345+ # def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
346+ # return torch.relu(x @ y + bias)
347+
348+ # run_example(
349+ # kernel_wrapper,
350+ # baseline_wrapper,
351+ # (x, y),
352+ # )
353+
354+ # # Test matmul forward + backward pass
355+ # print("\n\n=== MatMul Forward + Backward Pass Test ===")
356+ # x_grad = torch.randn([m, k], device=DEVICE, dtype=torch.float16, requires_grad=True)
357+ # y_grad = torch.randn([k, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
358+
359+ # run_example(
360+ # matmul_autograd,
361+ # torch.matmul,
362+ # (x_grad, y_grad),
363+ # kernel_name="helion_matmul_autograd",
364+ # baseline_name="torch",
365+ # rtol=1e-2,
366+ # atol=1e-2,
367+ # bwd=True,
368+ # )
369+
370+ # # Test addmm forward + backward pass
371+ # print("\n\n=== AddMM Forward + Backward Pass Test ===")
372+ # input_grad = torch.randn(
373+ # [m, n], device=DEVICE, dtype=torch.float16, requires_grad=True
374+ # )
375+ # mat1_grad = torch.randn(
376+ # [m, k], device=DEVICE, dtype=torch.float16, requires_grad=True
377+ # )
378+ # mat2_grad = torch.randn(
379+ # [k, n], device=DEVICE, dtype=torch.float16, requires_grad=True
380+ # )
381+
382+ # # Use lambda to handle the keyword argument format for torch.addmm
383+ # run_example(
384+ # addmm_autograd,
385+ # lambda bias, mat1, mat2, alpha, beta: torch.addmm(
386+ # bias, mat1, mat2, alpha=alpha, beta=beta
387+ # ),
388+ # (input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
389+ # kernel_name="helion_addmm_autograd",
390+ # baseline_name="torch",
391+ # rtol=1e-2,
392+ # atol=1e-2,
393+ # bwd=True,
394+ # )
395+
396+ # # Test addmm forward + backward with different alpha/beta values
397+ # print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
398+ # run_example(
399+ # addmm_autograd,
400+ # lambda bias, mat1, mat2, alpha, beta: torch.addmm(
401+ # bias, mat1, mat2, alpha=alpha, beta=beta
402+ # ),
403+ # (input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
404+ # kernel_name="helion_addmm_autograd_scaled",
405+ # baseline_name="torch",
406+ # rtol=1e-2,
407+ # atol=1e-2,
408+ # bwd=True,
409+ # )
392410
393411
394412# %%
0 commit comments