From 49d8ac4b031bbc132668b7b7f090174ec646442a Mon Sep 17 00:00:00 2001 From: JackWeiw <51255903105@stu.ecnu.edu.cn> Date: Thu, 16 Jan 2025 13:29:08 +0800 Subject: [PATCH 1/5] [dlinfer]add w8a8 support --- .../pytorch/backends/dlinfer/op_backend.py | 6 ++ lmdeploy/pytorch/backends/dlinfer/qmodules.py | 97 +++++++++++++++++++ .../pytorch/kernels/dlinfer/w8a8_kernels.py | 49 ++++++++++ 3 files changed, 152 insertions(+) create mode 100644 lmdeploy/pytorch/backends/dlinfer/qmodules.py create mode 100644 lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py index a0f04f34b..c8d1f9e51 100644 --- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py @@ -52,6 +52,12 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.RotaryEmbedding: from .rotary_embedding import DlinferRotaryEmbeddingBuilder return DlinferRotaryEmbeddingBuilder + elif layer_type == OpType.LinearW8A8: + from .qmodules import DlinferLinearW8A8Builder + return DlinferLinearW8A8Builder + elif layer_type == OpType.RMSNormW8A8: + from .qmodules import DlinferRMSNormW8A8Builder + return DlinferRMSNormW8A8Builder else: logger.debug( f'Op {layer_type} fallback to default implementation.') diff --git a/lmdeploy/pytorch/backends/dlinfer/qmodules.py b/lmdeploy/pytorch/backends/dlinfer/qmodules.py new file mode 100644 index 000000000..409ff3363 --- /dev/null +++ b/lmdeploy/pytorch/backends/dlinfer/qmodules.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch + +from lmdeploy.pytorch.kernels.dlinfer.rms_norm import rms_norm +from lmdeploy.pytorch.kernels.dlinfer.w8a8_kernels import ( + per_token_quant_int8, smooth_quant_matmul) +from lmdeploy.pytorch.models.q_modules import QTensor + +from ..qmodules import (LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder, + RMSNormW8A8Impl) + + +class DlinferRMSNormW8A8Impl(RMSNormW8A8Impl): + """Dlinfer RMS norm w8a8 implementation api.""" + + def __init__(self, + hidden_size: int, + eps: float = 1e-6, + quant_dtype: torch.dtype = torch.int8): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.quant_dtype = quant_dtype + + def forward(self, + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor = None): + """forward.""" + if residual is None: + (x, rms_scale) = rms_norm(x, weight, self.eps, None, + self.quant_dtype) + x = QTensor(x, rms_scale) + return x + else: + (x, rms_scale, residual) = rms_norm(x, weight, self.eps, residual, + self.quant_dtype) + x = QTensor(x, rms_scale) + return x, residual + + +class DlinferRMSNormW8A8Builder(RMSNormW8A8Builder): + """Dlinfer RMS norm w8a8 implementation builder.""" + + @staticmethod + def build(hidden_size: int, + eps: float = 1e-6, + quant_dtype: torch.dtype = torch.int8): + """build.""" + return DlinferRMSNormW8A8Impl(hidden_size, eps, quant_dtype) + + +class DlinferLinearW8A8Impl(LinearW8A8Impl): + """Dlinfer linear w8a8 implementation.""" + + def __init__( + self, + in_features: int, + out_features: int, + out_dtype: torch.dtype = torch.float16, + ): + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + + def forward(self, + x, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): + """forward.""" + + if isinstance(x, torch.Tensor): + input_quant, input_scale = 
per_token_quant_int8(x) + else: + assert isinstance(x, QTensor) + input_quant, input_scale = x.tensor, x.scale + + out = smooth_quant_matmul(input_quant, input_scale, weight, scale, + self.out_dtype, bias, all_reduce) + return out + + +class DlinferLinearW8A8Builder(LinearW8A8Builder): + """Dlinfer linear w8a8 implementation builder.""" + + @staticmethod + def build(in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + quant_dtype: torch.dtype = torch.int8): + """build.""" + return DlinferLinearW8A8Impl(in_features, out_features, dtype) diff --git a/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py new file mode 100644 index 000000000..955720463 --- /dev/null +++ b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import dlinfer.ops as ext_ops +import torch +from torch import Tensor + + +def per_token_quant_int8(x): + """Function to perform per-token quantization on an input tensor `x`. + + It converts the tensor values into signed 8-bit integers and returns the + quantized tensor along with the scaling factor used for quantization. + """ + input_quant, input_scale = ext_ops.per_token_quant_int8(x) + return input_quant, input_scale + + +def smooth_quant_matmul( + a: Tensor, + a_scale: Optional[torch.Tensor], + b: Tensor, + b_scale: Optional[torch.Tensor], + out_dtype: torch.dtype, + bias: Tensor = None, + all_reduce: bool = False, +): + """This function performs matrix multiplication with dynamic quantization. + + It takes two input tensors `a` and `b`, scales them with `rms_scale` and + `linear_scale`, and optionally adds a `bias`. The output is returned in the + specified `output_dtype`. + """ + return ext_ops.smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias, + all_reduce) + + +def rms_norm_w8a8( + hidden_states: Tensor, + weight: Tensor, + epsilon: float, + residual: Tensor = None, +): + """rms norm kernel.""" + if residual is None: + return ext_ops.rms_norm_w8a8(hidden_states, weight, epsilon) + else: + return ext_ops.add_rms_norm_w8a8(hidden_states, residual, weight, + epsilon) From 6536371f858a742cc8fab84f7e3466c0af39d7fe Mon Sep 17 00:00:00 2001 From: JackWeiw <51255903105@stu.ecnu.edu.cn> Date: Thu, 16 Jan 2025 16:16:14 +0800 Subject: [PATCH 2/5] [dlinfer]modify w8a8 rmsnorm --- lmdeploy/pytorch/kernels/dlinfer/rms_norm.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py b/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py index 37e651e1b..5fd4237f4 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py +++ b/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py @@ -1,19 +1,24 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Optional
+
 import dlinfer.ops as ext_ops
-from torch import Tensor
+from torch import Tensor, dtype
 
 
 def rms_norm(hidden_states: Tensor,
              weight: Tensor,
              epsilon: float = 1e-6,
              residual: Tensor = None,
+             quant_dtype: Optional[dtype] = None,
              out: Tensor = None):
     if residual is None:
-        rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon)
+        rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon,
+                                        quant_dtype)
         if out is None:
             out = rms_norm_out
         else:
             out.copy_(rms_norm_out)
         return out
     else:
-        return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon)
+        return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon,
+                                    quant_dtype)

From 0c5bc2340a490c99109eae6686b6c0d623106839 Mon Sep 17 00:00:00 2001
From: JackWeiw <51255903105@stu.ecnu.edu.cn>
Date: Tue, 21 Jan 2025 18:54:05 +0800
Subject: [PATCH 3/5] solve conflict

---
 lmdeploy/pytorch/kernels/dlinfer/rms_norm.py | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py b/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py
index 5fd4237f4..04e168e35 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py
@@ -1,24 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional
-
 import dlinfer.ops as ext_ops
-from torch import Tensor, dtype
+from torch import Tensor
 
 
-def rms_norm(hidden_states: Tensor,
-             weight: Tensor,
-             epsilon: float = 1e-6,
-             residual: Tensor = None,
-             quant_dtype: Optional[dtype] = None,
-             out: Tensor = None):
+def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6, residual: Tensor = None, out: Tensor = None):
     if residual is None:
-        rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon,
-                                        quant_dtype)
+        rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon)
         if out is None:
             out = rms_norm_out
         else:
             out.copy_(rms_norm_out)
         return out
     else:
-        return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon,
-                                    quant_dtype)
+        return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon)
\ No newline at end of file

From cc6bbf88c7093a723f3638c6e202c2bc3acfbdfd Mon Sep 17 00:00:00 2001
From: JackWeiw <51255903105@stu.ecnu.edu.cn>
Date: Tue, 21 Jan 2025 19:00:33 +0800
Subject: [PATCH 4/5] solve conflict

---
 lmdeploy/pytorch/backends/dlinfer/qmodules.py | 29 +++++--------------
 lmdeploy/pytorch/kernels/dlinfer/rms_norm.py  | 15 +++++++---
 .../pytorch/kernels/dlinfer/w8a8_kernels.py   | 15 ++++------
 3 files changed, 25 insertions(+), 34 deletions(-)

diff --git a/lmdeploy/pytorch/backends/dlinfer/qmodules.py b/lmdeploy/pytorch/backends/dlinfer/qmodules.py
index 409ff3363..52fd4918e 100644
--- a/lmdeploy/pytorch/backends/dlinfer/qmodules.py
+++ b/lmdeploy/pytorch/backends/dlinfer/qmodules.py
@@ -4,39 +4,29 @@
 import torch
 
 from lmdeploy.pytorch.kernels.dlinfer.rms_norm import rms_norm
-from lmdeploy.pytorch.kernels.dlinfer.w8a8_kernels import (
-    per_token_quant_int8, smooth_quant_matmul)
+from lmdeploy.pytorch.kernels.dlinfer.w8a8_kernels import per_token_quant_int8, smooth_quant_matmul
 from lmdeploy.pytorch.models.q_modules import QTensor
 
-from ..qmodules import (LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder,
-                        RMSNormW8A8Impl)
+from ..qmodules import LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder, RMSNormW8A8Impl
 
 
 class DlinferRMSNormW8A8Impl(RMSNormW8A8Impl):
     """Dlinfer RMS norm w8a8 implementation api."""
 
-    def __init__(self,
-                 hidden_size: int,
-                 eps: float = 1e-6,
-                 
quant_dtype: torch.dtype = torch.int8): + def __init__(self, hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8): super().__init__() self.hidden_size = hidden_size self.eps = eps self.quant_dtype = quant_dtype - def forward(self, - x: torch.Tensor, - weight: torch.Tensor, - residual: torch.Tensor = None): + def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None): """forward.""" if residual is None: - (x, rms_scale) = rms_norm(x, weight, self.eps, None, - self.quant_dtype) + (x, rms_scale) = rms_norm(x, weight, self.eps, None, self.quant_dtype) x = QTensor(x, rms_scale) return x else: - (x, rms_scale, residual) = rms_norm(x, weight, self.eps, residual, - self.quant_dtype) + (x, rms_scale, residual) = rms_norm(x, weight, self.eps, residual, self.quant_dtype) x = QTensor(x, rms_scale) return x, residual @@ -45,9 +35,7 @@ class DlinferRMSNormW8A8Builder(RMSNormW8A8Builder): """Dlinfer RMS norm w8a8 implementation builder.""" @staticmethod - def build(hidden_size: int, - eps: float = 1e-6, - quant_dtype: torch.dtype = torch.int8): + def build(hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8): """build.""" return DlinferRMSNormW8A8Impl(hidden_size, eps, quant_dtype) @@ -79,8 +67,7 @@ def forward(self, assert isinstance(x, QTensor) input_quant, input_scale = x.tensor, x.scale - out = smooth_quant_matmul(input_quant, input_scale, weight, scale, - self.out_dtype, bias, all_reduce) + out = smooth_quant_matmul(input_quant, input_scale, weight, scale, self.out_dtype, bias, all_reduce) return out diff --git a/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py b/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py index 04e168e35..8c2b89744 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py +++ b/lmdeploy/pytorch/kernels/dlinfer/rms_norm.py @@ -1,15 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import dlinfer.ops as ext_ops -from torch import Tensor +from torch import Tensor, dtype -def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6, residual: Tensor = None, out: Tensor = None): +def rms_norm(hidden_states: Tensor, + weight: Tensor, + epsilon: float = 1e-6, + residual: Tensor = None, + quant_dtype: Optional[dtype] = None, + out: Tensor = None): if residual is None: - rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon) + rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon, quant_dtype) if out is None: out = rms_norm_out else: out.copy_(rms_norm_out) return out else: - return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon) \ No newline at end of file + return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon, quant_dtype) diff --git a/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py index 955720463..4a2f7f414 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py +++ b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py @@ -9,8 +9,8 @@ def per_token_quant_int8(x): """Function to perform per-token quantization on an input tensor `x`. - It converts the tensor values into signed 8-bit integers and returns the - quantized tensor along with the scaling factor used for quantization. + It converts the tensor values into signed 8-bit integers and returns the quantized tensor along with the scaling + factor used for quantization. 
""" input_quant, input_scale = ext_ops.per_token_quant_int8(x) return input_quant, input_scale @@ -27,12 +27,10 @@ def smooth_quant_matmul( ): """This function performs matrix multiplication with dynamic quantization. - It takes two input tensors `a` and `b`, scales them with `rms_scale` and - `linear_scale`, and optionally adds a `bias`. The output is returned in the - specified `output_dtype`. + It takes two input tensors `a` and `b`, scales them with `rms_scale` and `linear_scale`, and optionally adds a + `bias`. The output is returned in the specified `output_dtype`. """ - return ext_ops.smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias, - all_reduce) + return ext_ops.smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias, all_reduce) def rms_norm_w8a8( @@ -45,5 +43,4 @@ def rms_norm_w8a8( if residual is None: return ext_ops.rms_norm_w8a8(hidden_states, weight, epsilon) else: - return ext_ops.add_rms_norm_w8a8(hidden_states, residual, weight, - epsilon) + return ext_ops.add_rms_norm_w8a8(hidden_states, residual, weight, epsilon) From f5cebfc4807d4562ce9db76526c4e0192b261457 Mon Sep 17 00:00:00 2001 From: JackWeiw <51255903105@stu.ecnu.edu.cn> Date: Wed, 22 Jan 2025 15:05:36 +0800 Subject: [PATCH 5/5] [dlinfer]rm useless w8a8rmsnorm --- lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py index 4a2f7f414..528d52e23 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py +++ b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py @@ -31,16 +31,3 @@ def smooth_quant_matmul( `bias`. The output is returned in the specified `output_dtype`. """ return ext_ops.smooth_quant_matmul(a, a_scale, b, b_scale, out_dtype, bias, all_reduce) - - -def rms_norm_w8a8( - hidden_states: Tensor, - weight: Tensor, - epsilon: float, - residual: Tensor = None, -): - """rms norm kernel.""" - if residual is None: - return ext_ops.rms_norm_w8a8(hidden_states, weight, epsilon) - else: - return ext_ops.add_rms_norm_w8a8(hidden_states, residual, weight, epsilon)