diff --git a/docs/en/index.rst b/docs/en/index.rst
index b64c230cb8..3eefbd90c6 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -80,6 +80,7 @@ Documentation
    quantization/w4a16.md
    quantization/w8a8.md
    quantization/kv_quant.md
+   quantization/blocked_fp8.md
 
 .. _benchmark:
 .. toctree::
diff --git a/docs/en/quantization/blocked_fp8.md b/docs/en/quantization/blocked_fp8.md
new file mode 100644
index 0000000000..6009bacee4
--- /dev/null
+++ b/docs/en/quantization/blocked_fp8.md
@@ -0,0 +1,59 @@
+# Blocked FP8 Quantization
+
+LMDeploy supports a weight-only blocked FP8 quantization method. This approach quantizes a model's weights to 8-bit floating-point numbers in a blocked format, which reduces the model's memory footprint while maintaining good performance on supported hardware.
+
+Before proceeding, please ensure that lmdeploy is installed by following the [installation guide](../get_started/installation.md). A typical installation command is:
+
+```shell
+pip install lmdeploy[all]
+```
+
+## Quantization
+
+Blocked FP8 quantization is performed with a single command, `lmdeploy lite blocked_fp8`. The script loads the model, quantizes its linear layers to blocked FP8, and saves the resulting weights and configuration to the specified working directory.
+
+Here is an example of how to quantize `OpenGVLab/InternVL3_5-8B`:
+
+```shell
+export HF_MODEL=OpenGVLab/InternVL3_5-8B
+export WORK_DIR=OpenGVLab/InternVL3_5-8B-FP8
+
+lmdeploy lite blocked_fp8 $HF_MODEL \
+    --work-dir $WORK_DIR \
+    --quant-dtype fp8 \
+    --block-size 128
+```
+
+Key arguments for the command:
+
+- `--work-dir`: The directory where the quantized model weights and configuration will be saved; see the check below for how to inspect the result.
+- `--quant-dtype`: The target FP8 format, either `float8_e4m3fn` (equivalent to passing `fp8`; recommended) or `float8_e5m2`.
+- `--block-size`: The block size used for quantization. The default of `128` is generally a good choice.
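+
+After quantization, the chosen settings are recorded under the `quantization_config` field of the exported `config.json`. If you want to double-check what was written, here is a minimal sketch (the path assumes the `WORK_DIR` from the example above):
+
+```python
+import json
+
+# Read the quantization settings saved next to the blocked FP8 weights.
+with open("OpenGVLab/InternVL3_5-8B-FP8/config.json") as f:
+    config = json.load(f)
+
+# Expect quant_method='fp8', fmt='e4m3', activation_scheme='dynamic',
+# weight_block_size=[128, 128] and the list of modules kept unquantized.
+print(config["quantization_config"])
+```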
+
+## Inference
+
+You can run batched offline inference on the quantized model with either the `turbomind` or the `pytorch` backend.
+
+Here is a simple code example:
+
+```python
+from lmdeploy import pipeline
+
+pipe = pipeline("OpenGVLab/InternVL3_5-8B-FP8")
+response = pipe(["Hi, pls intro yourself", "Shanghai is"])
+print(response)
+```
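+
+To pin a backend explicitly instead of relying on auto-selection, pass a backend config. The following is a minimal sketch (assuming the default engine settings):
+
+```python
+from lmdeploy import PytorchEngineConfig, pipeline
+
+# Force the PyTorch engine; swap in TurbomindEngineConfig() to target TurboMind instead.
+pipe = pipeline("OpenGVLab/InternVL3_5-8B-FP8", backend_config=PytorchEngineConfig())
+print(pipe(["Hi, pls intro yourself"]))
+```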
defaults to "./work_dir"') + parser.add_argument('--quant-dtype', + type=str, + default='float8_e4m3fn', + choices=['fp8', 'float8_e4m3fn', 'float8_e5m2'], + help='The quantization data type for weight') + parser.add_argument('--block-size', type=int, default=128, help='Block size for blocked-fp8 quantization') + ArgumentHelper.revision(parser) + ArgumentHelper.download_dir(parser) + @staticmethod def auto_awq(args): """Perform weight quantization using AWQ algorithm.""" @@ -117,6 +139,13 @@ def auto_gptq(args): kwargs = convert_args(args) auto_gptq(**kwargs) + @staticmethod + def blocked_fp8(args): + """Perform weight quantization to blocked fp8 format.""" + from lmdeploy.lite.apis.blocked_fp8 import blocked_fp8 + kwargs = convert_args(args) + blocked_fp8(**kwargs) + @staticmethod def calibrate(args): """Perform calibration on a given dataset.""" @@ -138,3 +167,4 @@ def add_parsers(): SubCliLite.add_parser_auto_gptq() SubCliLite.add_parser_calibrate() SubCliLite.add_parser_smooth_quant() + SubCliLite.add_parser_blocked_fp8() diff --git a/lmdeploy/lite/apis/blocked_fp8.py b/lmdeploy/lite/apis/blocked_fp8.py new file mode 100644 index 0000000000..1736c6af9c --- /dev/null +++ b/lmdeploy/lite/apis/blocked_fp8.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import os +import os.path as osp +from typing import Literal + +import fire +import torch +from torch import nn +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8 +from lmdeploy.lite.utils import collect_target_modules +from lmdeploy.pytorch.models import QLinear + + +def blocked_fp8(model: str, + work_dir: str = './work_dir', + quant_dtype: Literal['fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'float8_e4m3fn', + block_size: int = 128, + revision: str = None, + download_dir: str = None): + if quant_dtype == 'fp8': + quant_dtype = 'float8_e4m3fn' + + q_dtype = getattr(torch, quant_dtype, None) + assert q_dtype is not None + + if not osp.exists(model): + print(f'can\'t find model from local_path {model}, ' + 'try to download from remote') + from lmdeploy.utils import get_model + model_path = get_model(model, revision=revision, download_dir=download_dir) + else: + model_path = model + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, dtype=torch.bfloat16) + model = model.eval().cuda() + + # collect all linear layers + fcs = collect_target_modules(model, nn.Linear) + skip_patterns = [ + 'lm_head', + 'embed_tokens', + 'mlp.gate', # sparse MOE router gate + 'vision_model', # non-HF InternVL, vision part + 'mlp1', # non-HF InternVL, projector + 'mlp2', # non-HF InternVL-Flash, projector + 'vision_tower', # HF InternVL, vision part + 'multi_modal_projector', # HF InternVL, projector + ] + modules_to_not_convert = [] + + # quantize and replace linear layers + for name, linear in tqdm(fcs.items(), desc='Quantizing'): + # skip not to convert modules + if any([x in name for x in skip_patterns]): + modules_to_not_convert.append(name) + continue + + linear.to('cuda') + # quantize weight + q_weight, scales = quant_blocked_fp8(weight=linear.weight, fp8_dtype=q_dtype, block_size=block_size) + + # create and replace with QLinear + q_linear = QLinear.from_float(linear, quant_dtype=q_dtype, initialization=False) + q_linear.weight.data = q_weight + q_linear.weight_scale_inv.data = scales + if linear.bias is not None: + 
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index bd946ba96e..ddb01449c3 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -81,6 +81,7 @@ LMDeploy 工具箱提供以下核心功能:
    quantization/w4a16.md
    quantization/w8a8.md
    quantization/kv_quant.md
+   quantization/blocked_fp8.md
 
 .. _测试基准:
 .. toctree::
diff --git a/docs/zh_cn/quantization/blocked_fp8.md b/docs/zh_cn/quantization/blocked_fp8.md
new file mode 100644
index 0000000000..f60b5a94ec
--- /dev/null
+++ b/docs/zh_cn/quantization/blocked_fp8.md
@@ -0,0 +1,59 @@
+# Blocked FP8 模型量化
+
+LMDeploy 支持一种仅权重的 (weight-only) Blocked FP8 量化方法。该方法将模型的权重以分块 (blocked) 的形式量化为 8-bit 浮点数，可以在支持的硬件上保持良好性能的同时，有效降低模型的显存占用。
+
+在进行量化和推理之前，请确保按照[安装指南](../get_started/installation.md)安装了 lmdeploy。
+
+```shell
+pip install lmdeploy[all]
+```
+
+## 模型量化
+
+仅需执行一条命令 `lmdeploy lite blocked_fp8`，就可以完成模型量化工作。该脚本会加载模型，将线性层量化为 Blocked FP8 格式，并将最终的模型和配置文件保存在指定的工作目录中。
+
+以下是量化 `OpenGVLab/InternVL3_5-8B` 的示例:
+
+```shell
+export HF_MODEL=OpenGVLab/InternVL3_5-8B
+export WORK_DIR=OpenGVLab/InternVL3_5-8B-FP8
+
+lmdeploy lite blocked_fp8 $HF_MODEL \
+    --work-dir $WORK_DIR \
+    --quant-dtype fp8 \
+    --block-size 128
+```
+
+命令行的主要参数说明:
+
+- `--work-dir`: 用于保存量化后的模型权重和配置的工作目录。
+- `--quant-dtype`: 目标 FP8 格式，可以是 `float8_e4m3fn` (与传入 `fp8` 效果相同，推荐) 或 `float8_e5m2`。
+- `--block-size`: 量化的块大小。默认值 `128` 通常是一个不错的选择。
+
+## 模型推理
+
+您可以使用 `turbomind` 或 `pytorch` 后端对量化后的模型进行批量离线推理。
+
+这是一个简单的代码示例:
+
+```python
+from lmdeploy import pipeline
+
+pipe = pipeline("OpenGVLab/InternVL3_5-8B-FP8")
+response = pipe(["Hi, pls intro yourself", "Shanghai is"])
+print(response)
+```
+
+## 推理服务
+
+LMDeploy 的 `api_server` 可用于服务化部署 Blocked FP8 模型。
+
+```shell
+lmdeploy serve api_server OpenGVLab/InternVL3_5-8B-FP8
+```
+
+服务的默认端口是 `23333`。
+
+您可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口，也可直接查阅[文档](../llm/api_server.md)，了解各接口的定义和使用方法。
diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py
index 768ef47544..e5d2fe11b0 100644
--- a/lmdeploy/cli/lite.py
+++ b/lmdeploy/cli/lite.py
@@ -103,6 +103,28 @@ def add_parser_smooth_quant():
         ArgumentHelper.revision(parser)
         ArgumentHelper.download_dir(parser)
 
+    @staticmethod
+    def add_parser_blocked_fp8():
+        """Add parser for blocked_fp8 command."""
+        parser = SubCliLite.subparsers.add_parser('blocked_fp8',
+                                                  formatter_class=DefaultsAndTypesHelpFormatter,
+                                                  description=SubCliLite.blocked_fp8.__doc__,
+                                                  help=SubCliLite.blocked_fp8.__doc__)
+        parser.set_defaults(run=SubCliLite.blocked_fp8)
+        parser.add_argument('model', type=str, help='The name or path of the model to be loaded')
+        parser.add_argument('--work-dir',
+                            type=str,
+                            default='./work_dir',
+                            help='The working directory for outputs. Defaults to "./work_dir"')
+        parser.add_argument('--quant-dtype',
+                            type=str,
+                            default='float8_e4m3fn',
+                            choices=['fp8', 'float8_e4m3fn', 'float8_e5m2'],
+                            help='The quantization data type for weights')
+        parser.add_argument('--block-size', type=int, default=128, help='Block size for blocked FP8 quantization')
+        ArgumentHelper.revision(parser)
+        ArgumentHelper.download_dir(parser)
+
     @staticmethod
     def auto_awq(args):
         """Perform weight quantization using AWQ algorithm."""
@@ -117,6 +139,13 @@ def auto_gptq(args):
         kwargs = convert_args(args)
         auto_gptq(**kwargs)
 
+    @staticmethod
+    def blocked_fp8(args):
+        """Perform weight quantization to blocked fp8 format."""
+        from lmdeploy.lite.apis.blocked_fp8 import blocked_fp8
+        kwargs = convert_args(args)
+        blocked_fp8(**kwargs)
+
     @staticmethod
     def calibrate(args):
         """Perform calibration on a given dataset."""
@@ -138,3 +167,4 @@ def add_parsers():
         SubCliLite.add_parser_auto_gptq()
         SubCliLite.add_parser_calibrate()
         SubCliLite.add_parser_smooth_quant()
+        SubCliLite.add_parser_blocked_fp8()
diff --git a/lmdeploy/lite/apis/blocked_fp8.py b/lmdeploy/lite/apis/blocked_fp8.py
new file mode 100644
index 0000000000..1736c6af9c
--- /dev/null
+++ b/lmdeploy/lite/apis/blocked_fp8.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import os
+import os.path as osp
+from typing import Literal
+
+import fire
+import torch
+from torch import nn
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8
+from lmdeploy.lite.utils import collect_target_modules
+from lmdeploy.pytorch.models import QLinear
+
+
+def blocked_fp8(model: str,
+                work_dir: str = './work_dir',
+                quant_dtype: Literal['fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'float8_e4m3fn',
+                block_size: int = 128,
+                revision: str = None,
+                download_dir: str = None):
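+    """Quantize the linear layers of a model to blocked FP8 and save the result to work_dir."""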
""" hidden_states_quant, rms_scale = rms_norm_dynamic_quant(hidden_states, self.weight, @@ -91,7 +91,7 @@ def __init__(self, self.out_features = out_features self.quant_dtype = quant_dtype self.register_buffer('weight', torch.empty((out_features, in_features), device=device, dtype=quant_dtype)) - self.register_buffer('scale', torch.empty((out_features, 1), device=device, dtype=torch.float32)) + self.register_buffer('weight_scale_inv', torch.empty((out_features, 1), device=device, dtype=torch.float32)) if bias: self.register_buffer('bias', torch.empty(out_features, **factory_kwargs)) else: @@ -112,9 +112,9 @@ def from_float(cls, mod: nn.Module, initialization: bool = True, quant_dtype=tor quant_dtype=quant_dtype) if initialization: - weight_quant, scale = per_channel_quant(mod.weight.detach(), quant_dtype) + weight_quant, weight_scale_inv = per_channel_quant(mod.weight.detach(), quant_dtype) q_mod.weight.data = weight_quant - q_mod.scale = scale + q_mod.weight_scale_inv = weight_scale_inv if mod.bias is not None: q_mod.bias.data = mod.bias.detach() @@ -132,12 +132,12 @@ def forward(self, input): input_quant, input_scale = per_token_quant_int8(input, 1e-7, quant_dtype=self.quant_dtype) else: assert isinstance(input, QTensor) - input_quant, input_scale = input.tensor, input.scale + input_quant, input_scale = input.tensor, input.weight_scale_inv out = matmul_kernel_dynamic_quant(input_quant, self.weight, input_scale, - self.scale, + self.weight_scale_inv, output_dtype=torch.float16, bias=self.bias) return out