@@ -405,14 +405,29 @@ def gemm_a8w8_bpreshuffle(
     return gemm_a8w8_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y)
 
 
-def gemm_a8w8_blockscale(
+def gemm_a8w8_blockscale_fake(
     XQ: Tensor,
     WQ: Tensor,
     x_scale: Tensor,
     w_scale: Tensor,
     dtype=dtypes.bf16,
     isBpreshuffled=False,
-):
+) -> torch.Tensor:
+    m = XQ.shape[0]
+    n = WQ.shape[0]
+    Y = torch.empty(m, n, dtype=dtype, device=XQ.device)
+    return Y
+
+
+@torch_compile_guard(gen_fake=gemm_a8w8_blockscale_fake)
+def gemm_a8w8_blockscale(
+    XQ: Tensor,
+    WQ: Tensor,
+    x_scale: Tensor,
+    w_scale: Tensor,
+    dtype: torch.dtype = dtypes.bf16,
+    isBpreshuffled: bool = False,
+) -> torch.Tensor:
     assert dtype in [
         dtypes.bf16,
         dtypes.fp16,
@@ -427,7 +442,7 @@ def gemm_a8w8_blockscale(
         if get_gfx() in ["gfx950"] and m >= 16 and k >= 512 and dtype == dtypes.bf16:
             return mi350_a8w8_blockscale_ASM(XQ, WQ, x_scale, w_scale, Y)
         else:
-            assert 0, f"asm kernel only support B preshuffle and m >= 16"
+            assert 0, "asm kernel only support B preshuffle and m >= 16"
     else:
         get_CKGEMM_config(m, n, k, AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE)
         return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y)
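
The added gemm_a8w8_blockscale_fake returns an empty tensor of the correct shape and dtype so that torch.compile can trace the op without dispatching the real CK/ASM kernel. Below is a minimal sketch of the same fake-kernel pattern using stock PyTorch custom-op APIs (torch.library.custom_op plus register_fake) rather than aiter's torch_compile_guard; the "demo::gemm_bf16" op name and the plain-matmul body are hypothetical stand-ins, not aiter code.

    import torch
    from torch import Tensor


    @torch.library.custom_op("demo::gemm_bf16", mutates_args=())
    def demo_gemm(XQ: Tensor, WQ: Tensor) -> Tensor:
        # Eager path: an ordinary matmul standing in for the real device kernel.
        return (XQ.float() @ WQ.float().t()).to(torch.bfloat16)


    @demo_gemm.register_fake
    def _(XQ: Tensor, WQ: Tensor) -> Tensor:
        # Fake path: shape/dtype only, mirroring gemm_a8w8_blockscale_fake above,
        # so torch.compile can trace the graph without running the kernel.
        return XQ.new_empty((XQ.shape[0], WQ.shape[0]), dtype=torch.bfloat16)


    def f(a, b):
        return demo_gemm(a, b)


    out = torch.compile(f)(torch.randn(16, 64), torch.randn(32, 64))  # (16, 32) bf16

During tracing only the fake is consulted; the eager implementation runs when the compiled graph actually executes the op.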