diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
index 6d95a8d91..ff576bf74 100644
--- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -52,6 +52,12 @@ def get_layer_impl_builder(cls, layer_type: OpType):
         elif layer_type == OpType.RotaryEmbedding:
             from .rotary_embedding import DlinferRotaryEmbeddingBuilder
             return DlinferRotaryEmbeddingBuilder
+        elif layer_type == OpType.LinearW8A8:
+            from .qmodules import DlinferLinearW8A8Builder
+            return DlinferLinearW8A8Builder
+        elif layer_type == OpType.RMSNormW8A8:
+            from .qmodules import DlinferRMSNormW8A8Builder
+            return DlinferRMSNormW8A8Builder
         else:
             logger.debug(f'Op {layer_type} fallback to default implementation.')
             return super().get_layer_impl_builder(layer_type)
diff --git a/lmdeploy/pytorch/backends/dlinfer/qmodules.py b/lmdeploy/pytorch/backends/dlinfer/qmodules.py
new file mode 100644
index 000000000..52fd4918e
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/qmodules.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch
+
+from lmdeploy.pytorch.kernels.dlinfer.rms_norm import rms_norm
+from lmdeploy.pytorch.kernels.dlinfer.w8a8_kernels import per_token_quant_int8, smooth_quant_matmul
+from lmdeploy.pytorch.models.q_modules import QTensor
+
+from ..qmodules import LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder, RMSNormW8A8Impl
+
+
+class DlinferRMSNormW8A8Impl(RMSNormW8A8Impl):
+    """Dlinfer RMS norm w8a8 implementation api."""
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.quant_dtype = quant_dtype
+
+    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):
+        """forward."""
+        if residual is None:
+            (x, rms_scale) = rms_norm(x, weight, self.eps, None, self.quant_dtype)
+            x = QTensor(x, rms_scale)
+            return x
+        else:
+            (x, rms_scale, residual) = rms_norm(x, weight, self.eps, residual, self.quant_dtype)
+            x = QTensor(x, rms_scale)
+            return x, residual
+
+
+class DlinferRMSNormW8A8Builder(RMSNormW8A8Builder):
+    """Dlinfer RMS norm w8a8 implementation builder."""
+
+    @staticmethod
+    def build(hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):
+        """build."""
+        return DlinferRMSNormW8A8Impl(hidden_size, eps, quant_dtype)
+
+
+class DlinferLinearW8A8Impl(LinearW8A8Impl):
+    """Dlinfer linear w8a8 implementation."""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        out_dtype: torch.dtype = torch.float16,
+    ):
+        self.in_features = in_features
+        self.out_features = out_features
+        self.out_dtype = out_dtype
+
+    def forward(self,
+                x,
+                weight: torch.Tensor,
+                scale: torch.Tensor,
+                bias: Optional[torch.Tensor] = None,
+                all_reduce: bool = False):
+        """forward."""
+
+        if isinstance(x, torch.Tensor):
+            input_quant, input_scale = per_token_quant_int8(x)
+        else:
+            assert isinstance(x, QTensor)
+            input_quant, input_scale = x.tensor, x.scale
+
+        out = smooth_quant_matmul(input_quant, input_scale, weight, scale, self.out_dtype, bias, all_reduce)
+        return out
+
+
+class DlinferLinearW8A8Builder(LinearW8A8Builder):
+    """Dlinfer linear w8a8 implementation builder."""
+
+    @staticmethod
+    def build(in_features: int,
+              out_features: int,
+              bias: bool = True,
+              dtype: torch.dtype = None,
+              quant_dtype: torch.dtype = torch.int8):
+        """build."""
+        return DlinferLinearW8A8Impl(in_features, out_features, dtype)
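For reviewers, a minimal sketch of how the two new builders compose at runtime. This is illustrative only and assumes a dlinfer-capable device; the int8 weight and its scale (`w_int8`, `w_scale`) are hypothetical stand-ins for tensors produced by the existing w8a8 quantization pipeline, not part of this patch:

    import torch

    from lmdeploy.pytorch.backends.dlinfer.qmodules import (DlinferLinearW8A8Builder, DlinferRMSNormW8A8Builder)

    hidden_size = 4096
    norm = DlinferRMSNormW8A8Builder.build(hidden_size, eps=1e-6)
    linear = DlinferLinearW8A8Builder.build(hidden_size, hidden_size, bias=False, dtype=torch.float16)

    x = torch.randn(1, 16, hidden_size, dtype=torch.float16)
    norm_weight = torch.ones(hidden_size, dtype=torch.float16)

    # The fused rms_norm kernel quantizes its output and wraps it in a QTensor,
    # so DlinferLinearW8A8Impl.forward reuses that scale instead of calling
    # per_token_quant_int8 a second time.
    qx = norm.forward(x, norm_weight)
    out = linear.forward(qx, w_int8, w_scale)  # w_int8/w_scale: hypothetical quantized weight

The QTensor pass-through is the reason `forward` accepts either a plain tensor or a QTensor: a plain tensor takes the dynamic-quantization path, a QTensor skips it.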
diff --git a/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py b/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py
index 0a0f06491..8c2b89744 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py
@@ -1,15 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
 import dlinfer.ops as ext_ops
-from torch import Tensor
+from torch import Tensor, dtype
 
 
-def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6, residual: Tensor = None, out: Tensor = None):
+def rms_norm(hidden_states: Tensor,
+             weight: Tensor,
+             epsilon: float = 1e-6,
+             residual: Tensor = None,
+             quant_dtype: Optional[dtype] = None,
+             out: Tensor = None):
     if residual is None:
-        rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon)
+        rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon, quant_dtype)
         if out is None:
             out = rms_norm_out
         else:
             out.copy_(rms_norm_out)
         return out
     else:
-        return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon)
+        return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon, quant_dtype)
diff --git a/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py
new file mode 100644
index 000000000..528d52e23
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py
@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import dlinfer.ops as ext_ops
+import torch
+from torch import Tensor
+
+
+def per_token_quant_int8(x):
+    """Perform per-token quantization on an input tensor `x`.
+
+    It converts the tensor values into signed 8-bit integers and returns the quantized tensor along with the scaling
+    factor used for quantization.
+    """
+    input_quant, input_scale = ext_ops.per_token_quant_int8(x)
+    return input_quant, input_scale
+
+
+def smooth_quant_matmul(
+    a: Tensor,
+    a_scale: Optional[torch.Tensor],
+    b: Tensor,
+    b_scale: Optional[torch.Tensor],
+    out_dtype: torch.dtype,
+    bias: Tensor = None,
+    all_reduce: bool = False,
+):
+    """Perform matrix multiplication with dynamic quantization.
+
+    It takes two quantized input tensors `a` and `b`, dequantizes them with `a_scale` and `b_scale`, and optionally
+    adds a `bias`. The output is returned in the specified `out_dtype`.
+    """
+    return ext_ops.smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias, all_reduce)
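For context on the activation-quantization contract used above: per-token symmetric int8 quantization conventionally derives one scale per token from that row's absolute maximum. A plain-PyTorch reference of that convention, illustrative only; the actual numerics live inside `ext_ops.per_token_quant_int8` and may differ in rounding or clamping details:

    import torch

    def per_token_quant_int8_ref(x: torch.Tensor):
        # One scale per token: scale = max(|x|) / 127 over the hidden dim.
        absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-7)
        scale = absmax / 127.0
        # Symmetric quantization into int8; the float scale is kept for dequant.
        q = (x / scale).round().clamp(-127, 127).to(torch.int8)
        return q, scale.to(torch.float32)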