[sharktank] restore custom matmul kernel #896

Merged · 35 commits · Feb 21, 2025
Changes from 7 commits

Commits
8300bc8
restore custom matmul kernel
dan-garvey Feb 2, 2025
a98a332
not mergeable as-is
dan-garvey Feb 2, 2025
80fee98
Make batch_matmul_transpose_b accept accumulation dtype
sogartar Feb 4, 2025
9ff020e
Merge batch_matmul_transpose_b export tests into 1
sogartar Feb 4, 2025
7b93a6a
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
dan-garvey Feb 7, 2025
781c8e8
Add exception to qlinear to not use the kernel when unsigned ints
sogartar Feb 8, 2025
9f1c3d4
Small fix
sogartar Feb 8, 2025
82b032a
Add eager execution to circumvent failure to compile for llvm-cpu
sogartar Feb 10, 2025
aa5c7b0
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
sogartar Feb 11, 2025
de70094
Convert dtype when writing into the cache
sogartar Feb 11, 2025
ae89b55
Fix attention_dtype flag for paged_llm_v1
aviator19941 Feb 13, 2025
53f8cd1
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
sogartar Feb 13, 2025
5bf4636
KV cache workaround for Torch not supporting torch.Tensor.index_copy_…
sogartar Feb 14, 2025
fe5c881
Fix kv_cache index_put_ issue
archana-ramalingam Feb 14, 2025
b4be2a8
Revert "Fix kv_cache index_put_ issue"
archana-ramalingam Feb 14, 2025
338fe67
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
archana-ramalingam Feb 14, 2025
462ddc4
Fix KV cache index_copy_ f8 workaround
sogartar Feb 14, 2025
4dc2ac2
In linear for (Tensor, QuantizedTensor) raise if accum_dtype is given
sogartar Feb 14, 2025
13bfc68
Fix KV cache f8
sogartar Feb 20, 2025
8b20445
Remove unused HF dataset
sogartar Feb 20, 2025
9f160f1
Add KV cache dtype different from attention dtype
sogartar Feb 20, 2025
77a8443
Add more KV cache tests for various dtypes
sogartar Feb 20, 2025
1ea608a
Remove some unwanted corner case handling in linear layer
sogartar Feb 20, 2025
6f0c98b
Add more linear layer tests
sogartar Feb 20, 2025
664a847
Refactor quark parity test to use tmp dir
sogartar Feb 20, 2025
9816c35
Fix KV cache dtype CLI arg parsing
sogartar Feb 20, 2025
b8ff8cc
Merge remote-tracking branch 'origin/main' into users/dan-garvey/enab…
sogartar Feb 20, 2025
c17629e
Change doc example to not use the removed Llama dataset
sogartar Feb 20, 2025
9b7dfdf
Add KV cache dtype to benchmark
sogartar Feb 20, 2025
55c8701
Change testBenchmark8B_fp8_Non_Decomposed xfail reason to compilation…
sogartar Feb 21, 2025
fea5204
Merge branch 'main' into users/dan-garvey/enable_custom_fp8_matmul
archana-ramalingam Feb 21, 2025
740bb80
Put back in the llama3_8B_fp16 HF dataset
sogartar Feb 21, 2025
02215d5
Remove left behind comment
sogartar Feb 21, 2025
40f993a
Make quark parity test use f8 KV cache
sogartar Feb 21, 2025
85053ef
Add more bf16 qlinear tests and make ref dtype be f64
sogartar Feb 21, 2025
38 changes: 32 additions & 6 deletions sharktank/sharktank/kernels/batch_matmul_transpose_b.py
@@ -7,16 +7,36 @@
from sharktank.kernels.base import *

import torch
from typing import cast, Optional

from iree.compiler.ir import IntegerType
from iree.compiler.ir import IntegerType, Type
from iree.turbine.support.conversions import (
TORCH_DTYPE_TO_IREE_TYPE_ASM,
IREE_TYPE_ASM_TO_TORCH_DTYPE,
)
from iree.turbine.runtime.op_reg import AttrArg

__all__ = [
"batch_matmul_transpose_b",
]


def batch_matmul_transpose_b(
lhs: torch.Tensor,
rhs: torch.Tensor,
/,
*,
accum_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if accum_dtype is None:
accum_dtype = lhs.dtype
return _batch_matmul_transpose_b(
lhs, rhs, accum_dtype=TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype]
)


@CustomOp.register(library=LIBRARY)
class batch_matmul_transpose_b(CustomOp):
class _batch_matmul_transpose_b(CustomOp):
"""Generic block scaled matmul with transposed RHS.

The LHS is expected to be a 3d tensor of shape [B, M, K]. RHS must be
@@ -25,11 +45,14 @@ class batch_matmul_transpose_b(CustomOp):
The kernel will be specialized for all values of N, K and LHS dtype.
"""

signature = "batch_matmul_transpose_b(Tensor lhs, Tensor rhs) -> (Tensor)"
signature = (
"batch_matmul_transpose_b(Tensor lhs, Tensor rhs, str accum_dtype) -> (Tensor)"
)

def select(self, ksel: KernelSelection):
lhs_desc = ksel.arg_tensor(0) # Shape [B, M, K]
rhs_desc = ksel.arg_tensor(1) # Shape [B, N, K]
accum_type_attr = ksel.attr_str(2)

# Rank check.
torch._check(
@@ -60,7 +83,8 @@ def select(self, ksel: KernelSelection):
)
# Shape batch, m, n
c_desc = ksel.return_new_tensor(
[lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype
[lhs_batch, lhs_m, rhs_n],
dtype=IREE_TYPE_ASM_TO_TORCH_DTYPE[accum_type_attr.v],
)
specialize_all_known_dims(lhs_desc)
specialize_all_known_dims(rhs_desc)
@@ -74,12 +98,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
def generate(self, ksel: KernelSelection, kb: KernelBuilder):
lhs = kb.arg_value(0)
rhs = kb.arg_value(1)
accum_type_str = cast(AttrArg, ksel.arg_descs[2]).v
result_desc = ksel.result_descs[0]

# Generate specialization signature and types.
a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type)
a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type)
b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type)
spec_sig = f"L{a_ident}_R{b_ident}"
accum_type = Type.parse(accum_type_str)
spec_sig = f"L{a_ident}_R{b_ident}_{accum_type_str}"
template_file = "batch_matmul_transpose_b.mlir"
target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"
cst_zero = "0" if IntegerType.isinstance(accum_type) else "0."
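For reference, a minimal eager usage sketch of the new accum_dtype keyword, mirroring the int8/int32 test added below. Variable names are illustrative, and it assumes a sharktank checkout with iree-turbine and the IREE compiler/runtime installed, as in the test suite. Shapes follow the docstring: LHS [B, M, K], RHS [B, N, K], result [B, M, N].

import torch
from sharktank import kernels

# LHS is [B, M, K]; RHS is [B, N, K] and is transposed inside the kernel.
a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=torch.int8)
b = ((torch.rand([2, 4, 5]) - 0.5) * 255).to(dtype=torch.int8)
# accum_dtype selects the accumulation/result dtype; when omitted it defaults to lhs.dtype.
y = kernels.batch_matmul_transpose_b(a, b, accum_dtype=torch.int32)  # -> [2, 3, 4], int32
ref = torch.matmul(a.to(torch.int32), b.transpose(1, 2).to(torch.int32))
torch.testing.assert_close(y, ref, atol=0, rtol=0)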
4 changes: 2 additions & 2 deletions sharktank/sharktank/layers/linear.py
@@ -85,8 +85,8 @@ def forward(self, x):
# We can truncate to fp16 in iree, so we do a cast here
# to account for this in the IR. This may not be the right
# level to do this, but for now it's here.
if not isinstance(y, QuantizedTensor):
if y.dtype == torch.float8_e4m3fnuz:
if not isinstance(y, QuantizedTensor) and isinstance(x, QuantizedTensor):
if x.unpack().qs.dtype == torch.float8_e4m3fnuz:
y = ops.to(y, torch.bfloat16)
return y
if qdq_output is not None:
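A standalone sketch of the updated guard, for readers skimming the hunk: the bf16 cast is now keyed off the quantized input's storage dtype rather than the output dtype (presumably because the output can now carry a wider accumulation dtype). The helper name is hypothetical, and the import paths are assumptions based on how linear.py references these symbols.

import torch
from sharktank import ops
from sharktank.types import QuantizedTensor

def _cast_f8_linear_output(x, y):
    # Hypothetical standalone form of the branch above, not part of the PR.
    # Cast the matmul output to bf16 only when the input was f8-quantized.
    if not isinstance(y, QuantizedTensor) and isinstance(x, QuantizedTensor):
        if x.unpack().qs.dtype == torch.float8_e4m3fnuz:
            return ops.to(y, torch.bfloat16)
    return y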
30 changes: 22 additions & 8 deletions sharktank/sharktank/ops/qlinear_impls.py
@@ -50,10 +50,10 @@ def qlinear_tensor_scaled(

# Handle only integer and fp8 quantizations.
if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
if x_layout.qs.dtype == torch.float8_e4m3fnuz:
# assume quark
return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True)
else:
if (
x_layout.qs.dtype != torch.float8_e4m3fnuz
or weight_layout.qs.dtype != torch.float8_e4m3fnuz
):
return NotImplemented

# Bias.
@@ -170,7 +170,13 @@ def linear_quantized_weight(
linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight)


def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
def _is_dtype_unsigned_integer(dtype: torch.dtype):
return not dtype.is_complex and not dtype.is_floating_point and not dtype.is_signed


def _invoke_mmt_kernel(
lhs: torch.Tensor, rhs: torch.Tensor, *, accum_dtype: torch.dtype
):
if debugging.flags.use_custom_iree_kernels:
# The custom kernel requires that the lhs and rhs be the same
# rank. Broadcast the rhs to match.
@@ -187,9 +193,17 @@ def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
rhs_size = [lhs.shape[0]] + list(rhs.shape)
rhs = rhs.unsqueeze(0).expand(rhs_size)
rhs_rank = len(rhs.shape)
y_qs = kernels.batch_matmul_transpose_b(
lhs.to(accum_dtype), rhs.to(accum_dtype)
)
if (
_is_dtype_unsigned_integer(lhs.dtype)
or _is_dtype_unsigned_integer(rhs.dtype)
or _is_dtype_unsigned_integer(accum_dtype)
):
# TODO: make the kernel work with unsigned types.
y_qs = kernels.batch_matmul_transpose_b(
lhs.to(dtype=accum_dtype), rhs.to(dtype=accum_dtype)
)
else:
y_qs = kernels.batch_matmul_transpose_b(lhs, rhs, accum_dtype=accum_dtype)
# Squeeze the batch dimension to maintain shape parity with other
# layers.
if len(y_qs.shape) > 2:
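As a quick illustration of the new dispatch in _invoke_mmt_kernel: unsigned-integer arguments or accumulators take the pre-cast fallback (both operands converted to accum_dtype before the kernel call), while signed-integer and f8 arguments pass through unchanged with the accum_dtype attribute. A short sketch, relying only on standard torch.dtype attributes; the assertions are illustrative, not part of the PR.

import torch

def _is_dtype_unsigned_integer(dtype: torch.dtype) -> bool:
    # Unsigned integer means: not complex, not floating point, and not signed.
    return not dtype.is_complex and not dtype.is_floating_point and not dtype.is_signed

assert _is_dtype_unsigned_integer(torch.uint8)                # -> pre-cast fallback path
assert not _is_dtype_unsigned_integer(torch.int8)             # -> custom kernel with accum_dtype
assert not _is_dtype_unsigned_integer(torch.float8_e4m3fnuz)  # -> custom kernel with accum_dtype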
109 changes: 102 additions & 7 deletions sharktank/tests/kernels/batch_matmul_transpose_b_test.py
@@ -10,11 +10,13 @@

import unittest
from parameterized import parameterized

import pytest
import torch

from iree.turbine import aot
from iree.turbine.support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM
from sharktank import kernels
from sharktank.utils.testing import skip


class batch_matmul_transpose_b_test(unittest.TestCase):
@@ -40,24 +42,117 @@ def testBS32(self, atol, rtol):
ref = torch.matmul(a, bT)
torch.testing.assert_close(result, ref, atol=atol, rtol=rtol)

def testExportStaticDims(self):
@pytest.mark.xfail(
reason="""Does not compile for llvm-cpu with
<unknown>:0: error: 'llvm.fpext' op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'vector<4xi8>'
<unknown>:0: note: see current operation: %120 = "llvm.fpext"(%109) : (vector<4xi8>) -> vector<4xf32>
"""
)
def testArgF8AccumF32(self):
arg_dtype = torch.float8_e4m3fnuz
a = torch.rand([3, 4, 6]).to(arg_dtype)
b = torch.rand([3, 5, 6]).to(arg_dtype)
accum_dtype = torch.float32
result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype)

# Dequantize and test with normal matmul.
# Tolerances are empirical and results are not expected to match exactly.
bT = torch.transpose(b, 1, 2)
ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype))
torch.testing.assert_close(result, ref, atol=1e-3, rtol=0)

@pytest.mark.xfail(
reason="Does not work with unsigned types. The kernel needs to be adapted."
)
def testArgUi8AccumI32(self):
arg_dtype = torch.uint8
a = ((torch.rand([2, 3, 5]) * 255) + 0.5).to(dtype=arg_dtype)
b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=arg_dtype)
accum_dtype = torch.int32
result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype)

bT = torch.transpose(b, 1, 2)
ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype))
torch.testing.assert_close(result, ref, atol=0, rtol=0)

@pytest.mark.xfail(
reason="Does not work with unsigned types. The kernel needs to be adapted."
)
def testArgLhsI8RhsUi8AccumI32(self):
a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=torch.int8)
b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=torch.uint8)
accum_dtype = torch.int32
result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype)

bT = torch.transpose(b, 1, 2)
ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype))
torch.testing.assert_close(result, ref, atol=0, rtol=0)

def testArgI8AccumI32(self):
arg_dtype = torch.int8
a = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=arg_dtype)
b = ((torch.rand([2, 3, 5]) - 0.5) * 255).to(dtype=arg_dtype)
accum_dtype = torch.int32
result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype)

bT = torch.transpose(b, 1, 2)
ref = torch.matmul(a.to(dtype=accum_dtype), bT.to(dtype=accum_dtype))
torch.testing.assert_close(result, ref, atol=0, rtol=0)

@pytest.mark.xfail(
reason="""No uint32 dtype conversions in IREE Turbine.
Does not work with unsigned types. The kernel needs to be adapted.
The problem is that we reinterpret cast to signless integer types.
Maybe linalg.batch_matmul_transpose_b when promoting from i8 to i32 assumes a
signed type even though i8 is signless."""
)
def testArgUi8AccumUi32(self):
arg_dtype = torch.uint8
a = ((torch.rand([2, 3, 5]) * 255) + 0.5).to(dtype=arg_dtype)
b = ((torch.rand([2, 4, 5]) * 255) + 0.5).to(dtype=arg_dtype)
accum_dtype = torch.uint32
result = kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype)

bT = torch.transpose(b, 1, 2)
ref = torch.matmul(a.to(dtype=torch.int32), bT.to(dtype=torch.int32))
ref = ref.to(dtype=accum_dtype)
torch.testing.assert_close(result, ref, atol=0, rtol=0)

@parameterized.expand(
[
(torch.int32, None),
(torch.float8_e4m3fnuz, torch.float32),
]
)
def testExportStaticDims(
self, arg_dtype: torch.dtype, accum_dtype: torch.dtype | None
):
class MyModule(torch.nn.Module):
def forward(self, a, b):
return kernels.batch_matmul_transpose_b(a, b)
return kernels.batch_matmul_transpose_b(a, b, accum_dtype=accum_dtype)

mod = MyModule()
dtype = torch.int32
ep = torch.export.export(
mod,
args=(
(torch.rand([4, 16, 2]) * 64).to(dtype),
(torch.rand([4, 8, 2]) * 64).to(dtype),
(torch.rand([4, 16, 2]) * 64).to(arg_dtype),
(torch.rand([4, 8, 2]) * 64).to(arg_dtype),
),
)
output = aot.export(ep)
output.verify()
asm = str(output.mlir_module)
self.assertIn("@sharktank_batch_matmul_transpose_b_L4x16x2xi32_R4x8x2xi32", asm)
arg_dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[arg_dtype]
accum_dtype_asm = arg_dtype_asm
if accum_dtype is not None:
accum_dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[accum_dtype]
self.assertIn(
(
"@sharktank_batch_matmul_transpose_b_"
f"L4x16x2x{arg_dtype_asm}_R4x8x2x{arg_dtype_asm}_{accum_dtype_asm}"
),
asm,
)


if __name__ == "__main__":