Skip to content

Commit 97af1f9

Browse files
Wrapper gemm to fix get_config lru cache break (ROCm#1249)
* wrapper gemm to fix get_config lru cache
* use dtype
1 parent f700bb1 commit 97af1f9

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

aiter/ops/gemm_op_a8w8.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -405,14 +405,29 @@ def gemm_a8w8_bpreshuffle(
405405
return gemm_a8w8_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Y)
406406

407407

408-
def gemm_a8w8_blockscale(
408+
def gemm_a8w8_blockscale_fake(
409409
XQ: Tensor,
410410
WQ: Tensor,
411411
x_scale: Tensor,
412412
w_scale: Tensor,
413413
dtype=dtypes.bf16,
414414
isBpreshuffled=False,
415-
):
415+
) -> torch.Tensor:
416+
m = XQ.shape[0]
417+
n = WQ.shape[0]
418+
Y = torch.empty(m, n, dtype=dtype, device=XQ.device)
419+
return Y
420+
421+
422+
@torch_compile_guard(gen_fake=gemm_a8w8_blockscale_fake)
423+
def gemm_a8w8_blockscale(
424+
XQ: Tensor,
425+
WQ: Tensor,
426+
x_scale: Tensor,
427+
w_scale: Tensor,
428+
dtype: torch.dtype = dtypes.bf16,
429+
isBpreshuffled: bool = False,
430+
) -> torch.Tensor:
416431
assert dtype in [
417432
dtypes.bf16,
418433
dtypes.fp16,
@@ -427,7 +442,7 @@ def gemm_a8w8_blockscale(
427442
if get_gfx() in ["gfx950"] and m >= 16 and k >= 512 and dtype == dtypes.bf16:
428443
return mi350_a8w8_blockscale_ASM(XQ, WQ, x_scale, w_scale, Y)
429444
else:
430-
assert 0, f"asm kernel only support B preshuffle and m >= 16"
445+
assert 0, "asm kernel only support B preshuffle and m >= 16"
431446
else:
432447
get_CKGEMM_config(m, n, k, AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE)
433448
return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y)

0 commit comments

Comments
 (0)