From 118b3d7e7bab720d8ea9cd95338da60f7512c93a Mon Sep 17 00:00:00 2001
From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Date: Mon, 1 Apr 2024 16:39:43 +0800
Subject: [PATCH] Update TensorRT-LLM (#1387)
---
README.md | 12 +-
benchmarks/python/build.py | 226 +++----
benchmarks/python/gpt_benchmark.py | 5 +-
.../batch_manager/kvCacheManager.h | 3 +-
.../libtensorrt_llm_batch_manager_static.a | 4 +-
...sorrt_llm_batch_manager_static.pre_cxx11.a | 4 +-
.../aarch64-linux-gnu/version.txt | 6 +-
.../libtensorrt_llm_batch_manager_static.a | 4 +-
...sorrt_llm_batch_manager_static.pre_cxx11.a | 4 +-
.../executor/aarch64-linux-gnu/version.txt | 2 +-
.../libtensorrt_llm_executor_static.a | 2 +-
...ibtensorrt_llm_executor_static.pre_cxx11.a | 2 +-
docs/source/performance_analysis.md | 2 +-
examples/baichuan/requirements.txt | 2 +-
examples/bloom/requirements.txt | 2 +-
examples/chatglm/requirements.txt | 2 +-
examples/falcon/requirements.txt | 2 +-
examples/gemma/README.md | 76 +--
examples/gemma/convert_checkpoint.py | 51 +-
examples/gemma/requirements.txt | 7 +-
examples/gpt/requirements.txt | 2 +-
examples/gptneox/requirements.txt | 2 +-
examples/high-level-api/requirements.txt | 2 +-
examples/internlm/requirements.txt | 2 +-
examples/llama/requirements.txt | 2 +-
examples/mamba/requirements.txt | 2 +-
examples/medusa/requirements.txt | 2 +-
examples/mixtral/requirements.txt | 2 +-
examples/mpt/requirements.txt | 2 +-
examples/opt/requirements.txt | 2 +-
examples/phi/requirements.txt | 2 +-
examples/quantization/requirements.txt | 2 +-
examples/qwen/requirements.txt | 2 +-
examples/qwenvl/requirements.txt | 2 +-
examples/skywork/requirements.txt | 2 +-
examples/smaug/requirements.txt | 2 +-
examples/utils.py | 1 -
examples/whisper/build.py | 34 +-
examples/whisper/requirements.txt | 2 +-
requirements-windows.txt | 3 +
requirements.txt | 3 +
tensorrt_llm/_utils.py | 6 +-
tensorrt_llm/layers/moe.py | 4 -
tensorrt_llm/models/__init__.py | 3 -
tensorrt_llm/models/gemma/model.py | 10 +-
tensorrt_llm/models/gemma/smoothquant.py | 9 +-
tensorrt_llm/models/gemma/weight.py | 111 +---
tensorrt_llm/models/llama/weight.py | 93 +--
tensorrt_llm/models/modeling_utils.py | 8 -
tensorrt_llm/models/quantized/__init__.py | 14 -
tensorrt_llm/models/quantized/ammo.py | 137 -----
tensorrt_llm/models/quantized/quant.py | 550 ------------------
tensorrt_llm/quantization/quantize_by_ammo.py | 7 +-
tensorrt_llm/version.py | 2 +-
54 files changed, 207 insertions(+), 1240 deletions(-)
delete mode 100644 tensorrt_llm/models/quantized/__init__.py
delete mode 100644 tensorrt_llm/models/quantized/ammo.py
delete mode 100644 tensorrt_llm/models/quantized/quant.py
diff --git a/README.md b/README.md
index 61d4fe979..b3f83af27 100644
--- a/README.md
+++ b/README.md
@@ -17,9 +17,10 @@ TensorRT-LLM
## Latest News
+* [*Weekly*] Check out **[@NVIDIAAIDev](https://twitter.com/nvidiaaidev?lang=en)** on Twitter and **[NVIDIA AI](https://www.linkedin.com/showcase/nvidia-ai/)** on LinkedIn for the latest updates!
* [2024/02/06] [🚀 Speed up inference with SOTA quantization techniques in TRT-LLM](./docs/source/blogs/quantization-in-TRT-LLM.md)
-* [2024/01/30] [ New **XQA-kernel** provides **2.4x more Llama-70B throughput** within the same latency budget](./docs/source/blogs/XQA-kernel.md)
-* [2023/12/04] [**Falcon-180B** on a **single H200** GPU with INT4 AWQ, and **6.7x faster Llama-70B** over A100](./docs/source/blogs/Falcon180B-H200.md)
+* [2024/01/30] [ New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget](./docs/source/blogs/XQA-kernel.md)
+* [2023/12/04] [Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100](./docs/source/blogs/Falcon180B-H200.md)
* [2023/11/27] [SageMaker LMI now supports TensorRT-LLM - improves throughput by 60%, compared to previous version](https://aws.amazon.com/blogs/machine-learning/boost-inference-performance-for-llms-with-new-amazon-sagemaker-containers/)
* [2023/11/13] [H200 achieves nearly 12,000 tok/sec on Llama2-13B](./docs/source/blogs/H200launch.md)
* [2023/10/22] [🚀 RAG on Windows using TensorRT-LLM and LlamaIndex 🦙](https://github.com/NVIDIA/trt-llm-rag-windows#readme)
@@ -29,13 +30,6 @@ TensorRT-LLM
](https://blogs.nvidia.com/blog/2023/10/17/tensorrt-llm-windows-stable-diffusion-rtx/)
-[2023/11/27 - Amazon Sagemaker](https://aws.amazon.com/blogs/machine-learning/boost-inference-performance-for-llms-with-new-amazon-sagemaker-containers/)
-[2023/11/17 - Perplexity](https://blog.perplexity.ai/blog/turbocharging-llama-2-70b-with-nvidia-h100) ;
-[2023/10/31 - Phind](https://www.phind.com/blog/phind-model-beats-gpt4-fast) ;
-[2023/10/12 - Databricks (MosaicML)](https://www.databricks.com/blog/llm-inference-performance-engineering-best-practices) ;
-[2023/10/04 - Perplexity](https://blog.perplexity.ai/blog/introducing-pplx-api) ;
-[2023/09/27 - CloudFlare](https://www.cloudflare.com/press-releases/2023/cloudflare-powers-hyper-local-ai-inference-with-nvidia/);
-
## Table of Contents
- [TensorRT-LLM](#tensorrt-llm)
diff --git a/benchmarks/python/build.py b/benchmarks/python/build.py
index 993646e6b..280889de0 100644
--- a/benchmarks/python/build.py
+++ b/benchmarks/python/build.py
@@ -33,11 +33,12 @@
from tensorrt_llm.functional import LayerNormPositionType, LayerNormType
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
-from tensorrt_llm.models import PretrainedConfig, quantize_model
-from tensorrt_llm.models.modeling_utils import optimize_model
+from tensorrt_llm.models import PretrainedConfig
+from tensorrt_llm.models.modeling_utils import QuantConfig, optimize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
-from tensorrt_llm.quantization import QuantAlgo, QuantMode
+from tensorrt_llm.quantization import QuantAlgo
+from tensorrt_llm.quantization.quantize import quantize
def parse_arguments():
@@ -159,90 +160,31 @@ def parse_arguments():
return parser.parse_args()
-def get_quant_mode(quantization):
- quant_mode = QuantMode(0)
- use_smooth_quant = False
- per_token = False
- per_channel = False
- weight_only_precision = 'int8'
-
+def get_quant_config(quantization: str):
if quantization == "fp8":
- quant_mode = quant_mode.set_fp8_qdq()
- quant_mode = quant_mode.set_fp8_kv_cache()
-
+ return QuantConfig(quant_algo=QuantAlgo.FP8,
+ kv_cache_quant_algo=QuantAlgo.FP8)
elif quantization == "fp8_gemm":
- quant_mode = quant_mode.set_fp8_qdq()
-
+ return QuantConfig(quant_algo=QuantAlgo.FP8)
elif quantization == "fp8_kv_cache":
- quant_mode = quant_mode.set_fp8_kv_cache()
-
+ return QuantConfig(kv_cache_quant_algo=QuantAlgo.FP8)
elif quantization == "int8_sq_per_tensor":
- use_smooth_quant = True
- quant_mode = QuantMode.use_smooth_quant(per_token, per_channel)
-
+ return QuantConfig(quant_algo=QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN)
elif quantization == "int8_sq_per_token_channel":
- use_smooth_quant = True
- per_token = True
- per_channel = True
- quant_mode = QuantMode.use_smooth_quant(per_token, per_channel)
-
+ return QuantConfig(
+ quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN)
elif quantization == "int8_weight_only":
- use_smooth_quant = False
- weight_only_precision = 'int8'
- quant_mode = QuantMode.use_weight_only(use_int4_weights=False)
-
+ return QuantConfig(quant_algo=QuantAlgo.W8A16)
elif quantization == "int4_weight_only":
- weight_only_precision = 'int4'
- quant_mode = QuantMode.use_weight_only(use_int4_weights=True)
-
+ return QuantConfig(quant_algo=QuantAlgo.W4A16)
elif quantization == "int4_weight_only_awq":
- weight_only_precision = 'int4_awq'
- quant_mode = QuantMode.from_description(quantize_weights=True,
- quantize_activations=False,
- per_token=False,
- per_channel=False,
- per_group=True,
- use_int4_weights=True)
-
+ return QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ)
elif quantization == "int4_weight_only_gptq":
- weight_only_precision = 'int4_gptq'
- quant_mode = QuantMode.from_description(quantize_weights=True,
- quantize_activations=False,
- per_token=False,
- per_channel=False,
- per_group=True,
- use_int4_weights=True)
-
+ return QuantConfig(quant_algo=QuantAlgo.W4A16_GPTQ)
elif quantization is None:
- pass
-
+ return QuantConfig()
else:
- raise Exception(f'Unexpected quantization: {quantization}')
-
- return quant_mode, use_smooth_quant, weight_only_precision
-
-
-def get_quant_algo(quantization):
- if quantization == "fp8":
- return QuantAlgo.FP8, QuantAlgo.FP8
- elif quantization == "fp8_gemm":
- return QuantAlgo.FP8, None
- elif quantization == "fp8_kv_cache":
- return None, QuantAlgo.FP8
- elif quantization == "int8_sq_per_tensor":
- return QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN, None
- elif quantization == "int8_sq_per_token_channel":
- return QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN, None
- elif quantization == "int8_weight_only":
- return QuantAlgo.W8A16, None
- elif quantization == "int4_weight_only":
- return QuantAlgo.W4A16, None
- elif quantization == "int4_weight_only_awq":
- return QuantAlgo.W4A16_AWQ, None
- elif quantization == "int4_weight_only_gptq":
- return QuantAlgo.W4A16_GPTQ, None
- elif quantization is None:
- return None, None
+ raise Exception(f"Unexpected quantization: {quantization}")
def build_gpt(args):
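The new `get_quant_config` helper above replaces both the old `(quant_mode, use_smooth_quant, weight_only_precision)` tuple and the separate `get_quant_algo` lookup with a single `QuantConfig`. Below is a minimal sketch of how it is consumed, assuming a `tensorrt_llm` build that provides the imports used in this patch; the `"fp8"` argument and the abbreviated branch table are illustrative only, not part of the patch:

```python
# Minimal sketch (assumption: tensorrt_llm 0.9.x-style API as imported in this patch).
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo


def get_quant_config(quantization):
    # Reduced to one branch for brevity; the full mapping is in the hunk above.
    if quantization == "fp8":
        return QuantConfig(quant_algo=QuantAlgo.FP8,
                           kv_cache_quant_algo=QuantAlgo.FP8)
    return QuantConfig()


quant_config = get_quant_config("fp8")
quant_mode = quant_config.quant_mode  # derived property; no separate helper needed
print(quant_mode.has_fp8_qdq(), quant_mode.is_weight_only())
```

Deriving `quant_mode` from `quant_config.quant_mode` keeps one source of truth for the plugin selection that `build_gpt` performs later in this file.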
@@ -274,9 +216,11 @@ def build_gpt(args):
if args.max_output_len is None else args.max_output_len
max_beam_width = build_config['max_beam_width'] \
if args.max_beam_width is None else args.max_beam_width
- quant_mode, use_smooth_quant, weight_only_precision = get_quant_mode(
- args.quantization)
- use_weight_only = quant_mode.is_weight_only()
+
+ quant_config = get_quant_config(args.quantization)
+ quant_algo = quant_config.quant_algo
+ kv_cache_quant_algo = quant_config.kv_cache_quant_algo
+ quant_mode = quant_config.quant_mode
builder = Builder()
builder_config = builder.create_builder_config(
@@ -306,8 +250,6 @@ def build_gpt(args):
engine_name = get_engine_name(args.model, args.dtype, world_size,
runtime_rank)
- kv_dtype = str_dtype_to_trt(args.dtype)
-
# Initialize Module
family = get_model_family(args.model)
if family == "gpt":
@@ -317,7 +259,7 @@ def build_gpt(args):
build_config['inter_size'] = build_config['hidden_size'] * 4
if build_config['position_embedding_type'] is None:
build_config['position_embedding_type'] = 'learned_absolute'
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
+
config = {
'architecture': 'GPTForCausalLM',
'dtype': args.dtype,
@@ -349,6 +291,7 @@ def build_gpt(args):
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.GPTForCausalLM(config)
+
elif family == "opt":
config = {
'architecture': 'OPTForCausalLM',
@@ -368,14 +311,14 @@ def build_gpt(args):
'embedding_sharding_dim': 0,
'do_layer_norm_before': build_config['do_layer_norm_before'],
'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
'group_size': 128
}
}
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
- config['quantization']['quant_algo'] = quant_algo
- config['quantization']['kv_cache_quant_algo'] = kv_cache_quant_algo
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.OPTForCausalLM(config)
+
elif family == "llama":
config = {
'architecture':
@@ -402,6 +345,8 @@ def build_gpt(args):
'hidden_act':
build_config['hidden_act'],
'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
'group_size': 128
},
'mapping': {
@@ -413,11 +358,9 @@ def build_gpt(args):
'moe_top_k':
build_config["moe_top_k"],
}
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
- config['quantization']['quant_algo'] = quant_algo
- config['quantization']['kv_cache_quant_algo'] = kv_cache_quant_algo
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(config)
+
elif family == "gptj":
config = {
'architecture': 'GPTJForCausalLM',
@@ -438,14 +381,14 @@ def build_gpt(args):
'embedding_sharding_dim': 0,
'do_layer_norm_before': build_config['do_layer_norm_before'],
'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
'group_size': 128
}
}
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
- config['quantization']['quant_algo'] = quant_algo
- config['quantization']['kv_cache_quant_algo'] = kv_cache_quant_algo
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.GPTJForCausalLM(config)
+
elif family == "gptneox":
config = {
'architecture':
@@ -482,16 +425,15 @@ def build_gpt(args):
'embedding_sharding_dim':
0,
'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
'group_size': 128,
}
}
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
- config['quantization']['quant_algo'] = quant_algo
- config['quantization']['kv_cache_quant_algo'] = kv_cache_quant_algo
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.GPTNeoXForCausalLM(config)
+
elif family == "chatglm":
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
config = {
'architecture': 'ChatGLMForCausalLM',
'dtype': args.dtype,
@@ -525,7 +467,6 @@ def build_gpt(args):
tensorrt_llm_model = tensorrt_llm.models.ChatGLMForCausalLM(config)
elif family in ["chatglm2", "chatglm3"]:
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
config = {
'architecture': 'ChatGLMForCausalLM',
'dtype': args.dtype,
@@ -576,14 +517,14 @@ def build_gpt(args):
'share_embedding_table': False,
'embedding_sharding_dim': 0,
'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
'group_size': 128
}
}
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
- config['quantization']['quant_algo'] = quant_algo
- config['quantization']['kv_cache_quant_algo'] = kv_cache_quant_algo
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(config)
+
elif family == "falcon":
config = {
'architecture':
@@ -609,6 +550,8 @@ def build_gpt(args):
'hidden_act':
build_config['hidden_act'],
'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
'group_size': 128
},
'mapping': {
@@ -622,9 +565,6 @@ def build_gpt(args):
'new_decoder_architecture':
build_config['new_decoder_architecture'],
}
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
- config['quantization']['quant_algo'] = quant_algo
- config['quantization']['kv_cache_quant_algo'] = kv_cache_quant_algo
if quant_mode.is_weight_only() and quant_mode.has_per_group_scaling():
config['quantization'].update({
'has_zero_point': False,
@@ -633,6 +573,7 @@ def build_gpt(args):
})
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(config)
+
elif family == "baichuan":
config = {
'architecture':
@@ -660,6 +601,8 @@ def build_gpt(args):
'position_embedding_type':
'alibi_with_scale' if '7b' in args.model else 'rope_gpt_neox',
'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
'group_size': 128
},
'mapping': {
@@ -667,12 +610,10 @@ def build_gpt(args):
'tp_size': world_size,
},
}
-
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(config)
- elif family == "internlm":
- quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
+ elif family == "internlm":
config = {
'architecture':
'LLaMAForCausalLM',
@@ -707,40 +648,50 @@ def build_gpt(args):
build_config['bias'],
}
if quant_mode.is_weight_only():
- if weight_only_precision == 'int4_awq':
+ if 'awq' in args.quantization:
config['quantization'].update({
"group_size": 128,
"has_zero_point": False,
"pre_quant_scale": True,
"exclude_modules": [],
})
- elif weight_only_precision == 'int4_gptq':
+ elif 'gptq' in args.quantization:
config['quantization'].update({
"group_size": 128,
"has_zero_point": True,
"pre_quant_scale": False,
})
-
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(config)
+
elif family == "qwen":
- tensorrt_llm_model = tensorrt_llm.models.QWenForCausalLM(
- num_layers=build_config['num_layers'],
- num_heads=build_config['num_heads'],
- num_kv_heads=num_kv_heads,
- hidden_size=build_config['hidden_size'],
- seq_length=2048,
- vocab_size=build_config['vocab_size'],
- hidden_act=build_config['hidden_act'],
- max_position_embeddings=build_config['n_positions'],
- dtype=kv_dtype,
- mlp_hidden_size=build_config['inter_size'],
- neox_rotary_style=True,
- mapping=tensorrt_llm.Mapping(world_size=world_size,
- tp_size=world_size), # TP only
- use_parallel_embedding=False,
- embedding_sharding_dim=1,
- quant_mode=quant_mode)
+ config = {
+ 'architecture': 'QWenForCausalLM',
+ 'dtype': args.dtype,
+ 'num_hidden_layers': build_config['num_layers'],
+ 'num_attention_heads': build_config['num_heads'],
+ 'hidden_size': build_config['hidden_size'],
+ 'intermediate_size': build_config['inter_size'],
+ 'num_key_value_heads': num_kv_heads,
+ 'vocab_size': build_config['vocab_size'],
+ 'position_embedding_type': 'rope_gpt_neox',
+ 'max_position_embeddings': build_config['n_positions'],
+ 'hidden_act': build_config['hidden_act'],
+ 'rotary_base': 10000.0,
+ 'norm_epsilon': 1e-06,
+ 'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
+ 'group_size': 128
+ },
+ 'mapping': {
+ 'world_size': world_size,
+ 'tp_size': world_size,
+ },
+ }
+ config = PretrainedConfig.from_dict(config)
+ tensorrt_llm_model = tensorrt_llm.models.QWenForCausalLM(config)
+
elif family == "mamba":
config = {
'architecture': 'MambaLMHeadModel',
@@ -761,15 +712,10 @@ def build_gpt(args):
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.MambaLMHeadModel(config)
+
else:
raise Exception(f'Unexpected model: {args.model}')
- if family not in [
- 'gpt', 'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
- 'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3'
- ]:
- tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode)
-
if family in ['llama']:
tensorrt_llm_model = optimize_model(tensorrt_llm_model,
use_fused_mlp=True)
@@ -792,6 +738,8 @@ def build_gpt(args):
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
# Quantization plugins.
+ use_smooth_quant = quant_mode.has_act_and_weight_quant()
+ use_weight_only = quant_mode.is_weight_only()
if use_smooth_quant:
network.plugin_config.set_smooth_quant_plugins(dtype=args.dtype)
elif use_weight_only:
@@ -825,13 +773,8 @@ def build_gpt(args):
max_seq_len=max_input_len + max_output_len,
use_cache=True,
max_beam_width=max_beam_width)
- if family in [
- 'gpt', 'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
- 'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3'
- ]:
- tensorrt_llm_model(**inputs)
- else:
- tensorrt_llm_model(*inputs)
+
+ tensorrt_llm_model(**inputs)
if args.mode in ['plugin', 'plugin-ifb']:
tensorrt_llm.graph_rewriting.optimize(network)
@@ -1032,7 +975,8 @@ def enc_dec_build_helper(component, config, args):
else:
rescale_before_lm_head = False
- quant_mode, _, _ = get_quant_mode(args.quantization)
+ quant_config = get_quant_config(args.quantization)
+ quant_mode = quant_config.quant_mode
use_weight_only = quant_mode.is_weight_only()
builder = Builder()
@@ -1089,7 +1033,7 @@ def enc_dec_build_helper(component, config, args):
n_layer=config['num_layers'],
dtype=dtype)
if use_weight_only:
- tllm_model = quantize_model(tllm_model, quant_mode)
+ tllm_model = quantize(tllm_model, quant_config)
else:
tllm_model = tensorrt_llm.models.EncoderModel(
num_layers=config['num_layers'],
@@ -1154,7 +1098,7 @@ def enc_dec_build_helper(component, config, args):
logits_dtype=logits_dtype, # by default
fp16_clamping=fp16_clamping)
if use_weight_only and family == 'whisper':
- tllm_model = quantize_model(tllm_model, quant_mode)
+ tllm_model = quantize(tllm_model, quant_config)
# Module -> Network
engine_name = get_engine_name(args.model, args.dtype, world_size,
diff --git a/benchmarks/python/gpt_benchmark.py b/benchmarks/python/gpt_benchmark.py
index 1bb72b522..a5e547113 100644
--- a/benchmarks/python/gpt_benchmark.py
+++ b/benchmarks/python/gpt_benchmark.py
@@ -23,7 +23,7 @@
from allowed_configs import get_build_config, BuildConfig # isort:skip
from base_benchmark import BaseBenchmark # isort:skip
-from build import build_gpt, get_quant_mode # isort:skip
+from build import build_gpt, get_quant_config # isort:skip
def element_size(dtype: str):
@@ -81,7 +81,8 @@ def __init__(self, args, batch_sizes, in_out_lens, rank, world_size):
if args.max_output_len is not None:
self.max_output_len = args.max_output_len
- self.quant_mode, _, _ = get_quant_mode(args.quantization)
+ self.quant_config = get_quant_config(args.quantization)
+ self.quant_mode = self.quant_config.quant_mode
self.enable_fp8 = self.quant_mode.has_fp8_qdq()
self.fp8_kv_cache = self.quant_mode.has_fp8_kv_cache()
if self.quant_mode.has_fp8_kv_cache():
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index aa4c915be..0e9bcc5ae 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -469,7 +469,8 @@ class KVCacheManager
void getBlockPointersOfBatch(
runtime::ITensor& dstPointers, SizeType firstBatchSlotIdx, SizeType batchSize, SizeType beamWidth) const;
- void copyBlockPointers(
+ // returns maxBlockCount of all beams
+ SizeType copyBlockPointers(
runtime::ITensor& dstPointers, SizeType dstSlotOffset, SizeType seqSlotIdx, SizeType beamWidth) const;
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
index ede6c8a5e..86385a528 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:ba545e1931c9405b75028b019ac3949ec5cec57c304aaa10ea6c854f572225b1
-size 2856456
+oid sha256:2309c96080b61795e03130338d9a023c67c4ecba7b7ba9d32797d2ce8fe170aa
+size 2869834
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
index 58817bfba..de1f17710 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:8ef69cd446d54a1c876237f812839e6ecd9174c327edc5ff4f6594bb2b203aae
-size 2885046
+oid sha256:eb623f435dd920799783d353b5062fd4c3dcaa3529d2007d1550cc21ecae39ee
+size 2898404
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
index 59b3025b6..e26cc21c5 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-cd113ef1af7d78ac0791d4323a8ef370 libtensorrt_llm_batch_manager_static.a
-c3583c5524c71f2cd5ae7a3bba864377 libtensorrt_llm_batch_manager_static.pre_cxx11.a
-45a8cb4ea commit
+577dc50763c3152388738bcb22ed9f73 libtensorrt_llm_batch_manager_static.a
+2ae3fab3836c6fd961d82458663888bb libtensorrt_llm_batch_manager_static.pre_cxx11.a
+9fd0ada1f commit
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
index 5583f889c..f664bbd38 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:32ca7c2a6701457ecb537a56d9558fb62d35ec5443905d63f1f1a288d8f48f87
-size 2780748
+oid sha256:26da47fbe5e5a2246db58152011c78b87a56deafb6df38821638b1c27c00af22
+size 2796970
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
index 8585ada27..e8553617c 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:c8fdf3d223bb7e0a5eeffbea0a82e50a8e0ec3815b274cdc95d6fb1c36f2178d
-size 2755044
+oid sha256:04243eaf43b74886b6b8ab8ec13482c2c1c80d3468e8f6f40a8e16a8d4cace6e
+size 2769552
diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
index 73b2862d7..13bab21ee 100644
--- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
a552c727c128f6ca402ddc119d295ab0 libtensorrt_llm_executor_static.a
38ad482b0be0996970bc572622967acf libtensorrt_llm_executor_static.pre_cxx11.a
-45a8cb4ea commit
+9fd0ada1f commit
diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
index 435319970..4eb217cfe 100644
--- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
+++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:cc2d59c4878e74f7e38a65187ed303a77b43a3b71753b3e4dcc99a937ccbcdf8
+oid sha256:5aa174bdc52db4f8f193ccdbc4b764835564204a5576c91fffd36720e9d79fdd
size 884870
diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
index 0538c421a..a68485592 100644
--- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:dc7e967c9aa7ef50227a791c670fe71a9bdef907ce45d3282955ebd5e2ead88f
+oid sha256:e1117c2c02debd8f373d541c6d96cac5f66a7a9d374a10bbb7c89d6455b869ba
size 837988
diff --git a/docs/source/performance_analysis.md b/docs/source/performance_analysis.md
index fe6602055..4827a31a4 100644
--- a/docs/source/performance_analysis.md
+++ b/docs/source/performance_analysis.md
@@ -22,7 +22,7 @@ Toggling the CUDA profiler runtime API on and off:
[TensorRT-LLM][INFO] {"Active Request Count":249,"Context Requests":8,"Free KV cache blocks":0,"Generation Requests":231,"Iteration Counter":90,"Max KV cache blocks":2448,"Max Request Count":256,"MicroBatch ID":0,"Runtime CPU Memory Usage":28784,"Runtime GPU Memory Usage":540173600,"Runtime Pinned Memory Usage":0,"Scheduled Requests":239,"Timestamp":"12-13-2023 14:55:14","Tokens per KV cache block":128,"Total Context Tokens":6904,"Used KV cache blocks":2448}
```
### Inference Time Environment Variables
- * `TLLM_GPTM_PROFILE_START_STOP`, a csv of iterations to trigger start/stop for gptManagerBenchmark (corresponds to "Iteration Counter" in output above
+ * `TLLM_GPTM_PROFILE_START_STOP`, a csv of iterations to trigger start/stop for gptManagerBenchmark (corresponds to "Iteration Counter" in the output above). Each value can be a range using the "-" separator, e.g. 0-10; all iterations in a given range are placed in the same nsys file.
* `TLLM_GPTS_PROFILE_START_STOP`, a csv of static batching iteration indexes to trigger start/stop for gptSessionBenchmark
## Coordinating with NVIDIA Nsight Systems Launch
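As a rough illustration of the `TLLM_GPTM_PROFILE_START_STOP` variable documented above, the Python sketch below wraps an `nsys` launch of gptManagerBenchmark. The benchmark path and the `--engine_dir` flag are placeholders, and `-c cudaProfilerApi` assumes the CUDA-profiler-API capture setup this page describes; treat the whole invocation as an assumption, not as the documented command line:

```python
import os
import subprocess

# Profile iterations 0 through 10 in a single nsys file (range syntax "0-10").
env = dict(os.environ, TLLM_GPTM_PROFILE_START_STOP="0-10")

subprocess.run(
    ["nsys", "profile", "-c", "cudaProfilerApi",        # start/stop driven by the env var above
     "./cpp/build/benchmarks/gptManagerBenchmark",       # placeholder path to the built benchmark
     "--engine_dir", "/path/to/engine"],                 # placeholder flag/value
    env=env, check=True)
```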
diff --git a/examples/baichuan/requirements.txt b/examples/baichuan/requirements.txt
index 17b8f5b54..8b8effb4c 100644
--- a/examples/baichuan/requirements.txt
+++ b/examples/baichuan/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.15.0
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/bloom/requirements.txt b/examples/bloom/requirements.txt
index 803414f24..a6b54863c 100644
--- a/examples/bloom/requirements.txt
+++ b/examples/bloom/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/chatglm/requirements.txt b/examples/chatglm/requirements.txt
index 442393dfe..c69b08836 100644
--- a/examples/chatglm/requirements.txt
+++ b/examples/chatglm/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.14.5
evaluate~=0.4.1
protobuf
diff --git a/examples/falcon/requirements.txt b/examples/falcon/requirements.txt
index 674c286f1..39f581023 100644
--- a/examples/falcon/requirements.txt
+++ b/examples/falcon/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
transformers>=4.31.0
datasets~=2.14.5
evaluate~=0.4.1
diff --git a/examples/gemma/README.md b/examples/gemma/README.md
index a3861337b..c141b9b24 100644
--- a/examples/gemma/README.md
+++ b/examples/gemma/README.md
@@ -147,11 +147,10 @@ In this section, we demonstrate the scripts to convert checkpoint, building engi
#### Run inference under bfloat16 for HF checkpoint
```bash
-git clone git@hf.co:google/gemma-2b
-CKPT_PATH=gemma-2b/
+CKPT_PATH=/tmp/models/hf/gemma/gemma-2b/
UNIFIED_CKPT_PATH=/tmp/ckpt/hf/gemma/2b/1-gpu/
ENGINE_PATH=/tmp/engines/gemma/2B/bf16/1-gpu/
-VOCAB_FILE_PATH=gemma-2b/
+VOCAB_FILE_PATH=/tmp/models/hf/gemma/gemma-2b/
python3 ./examples/gemma/convert_checkpoint.py \
--ckpt-type hf \
@@ -166,7 +165,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -192,15 +190,10 @@ WARNING: This way of running FP8 will introduce noticeable accuracy drop. To avo
In this example, we demonstrate how to run FP8 inference on Gemma. Note that `convert_checkpoint.py` only uses identity activation scales, so the accuracy might be a little worse than higher precision in some cases, but it is still very good because we don't do any calibration. This also shows the stability of FP8 compared to INT8.
```bash
-git clone git@hf.co:google/gemma-2b-it-keras
-GIT_LFS_SKIP_SMUDGE=1 git clone git@hf.co:google/gemma-2b-it-flax # clone tokenizer model
-cd gemma-2b-it-flax
-git lfs pull -I tokenizer.model
-
-CKPT_PATH=gemma-2b-it-keras
+CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_2b_en/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_en_tensorrt_llm/fp8/tp1/
ENGINE_PATH=/tmp/gemma/2B/fp8/1-gpu/
-VOCAB_FILE_PATH=gemma-2b-it-flax/tokenizer.model
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type keras \
@@ -217,7 +210,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -247,11 +239,10 @@ Average accuracy: 0.356
#### Run 2B inference under SmoothQuant for jax checkpoint
```bash
-git clone git@hf.co:google/gemma-2b-it-flax
-CKPT_PATH=gemma-2b-it-flax/2b-it/
+CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/sq/tp1
ENGINE_PATH=/tmp/gemma/2B/int8_sq/1-gpu/
-VOCAB_FILE_PATH=gemma-2b-it-flax/tokenizer.model
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
@@ -268,7 +259,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_input_len 3000 \
--max_output_len 100 \
--enable_xqa enable \
- --lookup_plugin float16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -294,11 +284,10 @@ Available precisions: `int8` and `int4`
* `int8`
```bash
-git clone git@hf.co:google/gemma-2b-it-flax
-CKPT_PATH=gemma-2b-it-flax/2b-it/
+CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/w8_a16/tp1/
ENGINE_PATH=/tmp/gemma/2B/w8_a16/1-gpu/
-VOCAB_FILE_PATH=gemma-2b-it-flax/tokenizer.model
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
@@ -314,7 +303,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_input_len 3000 \
--max_output_len 100 \
--enable_xqa enable \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -336,11 +324,10 @@ python3 ../summarize.py --test_trt_llm \
* `int4`
```bash
-git clone git@hf.co:google/gemma-2b-it-flax
-CKPT_PATH=gemma-2b-it-flax/2b-it/
+CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/w4_a16/tp1/
ENGINE_PATH=/tmp/gemma/2B/w4_a16/1-gpu/
-VOCAB_FILE_PATH=gemma-2b-it-flax/tokenizer.model
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
@@ -356,7 +343,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_input_len 3000 \
--max_output_len 100 \
--enable_xqa enable \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -378,11 +364,10 @@ python3 ../summarize.py --test_trt_llm \
#### Run inference under INT8 KV caches for jax checkpoint
```bash
-git clone git@hf.co:google/gemma-2b-it-flax
-CKPT_PATH=gemma-2b-it-flax/2b-it/
+CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_2b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_it_tensorrt_llm/int8kv/tp1
ENGINE_PATH=/tmp/gemma/2B/int8kv/1-gpu/
-VOCAB_FILE_PATH=gemma-2b-it-flax/tokenizer.model
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
@@ -401,7 +386,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_output_len 100 \
--enable_xqa enable \
--strongly_type \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -427,14 +411,12 @@ python3 ../summarize.py --test_trt_llm \
Since the torch model does not have a model config, we need to add one manually in `CKPT_PATH` with the file name `config.json`.
```bash
-git clone git@hf.co:google/gemma-7b-pytorch
-
-CKPT_PATH=gemma-7b-pytorch/
+CKPT_PATH=/tmp/models/pytorch/ckpt/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/bf16/tp1/
ENGINE_PATH=/tmp/gemma/7B/bf16/1-gpu/
-VOCAB_FILE_PATH=gemma-7b-pytorch/tokenizer.model
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
-python3 ./examples/gemma/convert_checkpoint.py \
+python3 ./convert_checkpoint.py \
--ckpt-type torch \
--model-dir ${CKPT_PATH} \
--dtype bfloat16 \
@@ -447,7 +429,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -492,7 +473,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -514,11 +494,10 @@ python3 ../summarize.py --test_trt_llm \
#### Run 7B inference under SmoothQuant for jax checkpoint
```bash
-git clone git@hf.co:google/gemma-7b-it-flax
-CKPT_PATH=gemma-7b-it-flax/7b-it/
+CKPT_PATH=/tmp/models/gemma_nv/checkpoints/tmp_7b_it
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/sq/tp1
ENGINE_PATH=/tmp/gemma/7B/int8_sq/1-gpu/
-VOCAB_FILE_PATH=gemma-7b-it-flax/tokenizer.model
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type jax \
@@ -535,7 +514,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_input_len 3000 \
--max_output_len 100 \
--enable_xqa enable \
- --lookup_plugin float16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -562,15 +540,10 @@ Available precisions: `int8` and `int4`
* `int8`
```bash
-git clone git@hf.co:google/gemma-7b-it-keras
-GIT_LFS_SKIP_SMUDGE=1 git clone git@hf.co:google/gemma-7b-it-flax # clone tokenizer model
-cd gemma-7b-it-flax
-git lfs pull -I tokenizer.model
-
-CKPT_PATH=gemma-7b-it-keras
+CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_7b_en/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_7b_it_tensorrt_llm/w8_a16/tp1/
ENGINE_PATH=/tmp/gemma/7B/w8_a16/1-gpu/
-VOCAB_FILE_PATH=gemma-7b-it-flax/tokenizer.model
+VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model
python3 ./convert_checkpoint.py \
--ckpt-type keras \
@@ -586,7 +559,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_input_len 3000 \
--max_output_len 100 \
--enable_xqa enable \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -627,7 +599,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_input_len 3000 \
--max_output_len 100 \
--enable_xqa enable \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -671,7 +642,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_output_len 100 \
--enable_xqa enable \
--strongly_type \
- --lookup_plugin bfloat16 \
--output_dir ${ENGINE_PATH}
python3 ../summarize.py --test_trt_llm \
@@ -717,7 +687,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_batch_size 8 \
--max_input_len 3000 \
--max_output_len 100 \
- --lookup_plugin float16 \
--output_dir ${ENGINE_PATH}
```
@@ -731,13 +700,12 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--max_input_len 3000 \
--max_output_len 100 \
--enable_xqa enable \
- --lookup_plugin float16 \
--output_dir ${ENGINE_PATH}
```
#### Accuracy Results on MMLU
| Model | fp8 | int4_awq | int8_sq (AMMO) | int8_sq (Native per-channel) |
-| ------------- | ----- | -------- | -------------- | ---------------------------- |
-| 2B Pretrained | 0.407 | 0.378 | 0.338 | 0.338 |
-| 7B Pretrained | 0.643 | 0.615 | 0.448 | 0.595 |
+|---------------|-------|----------|----------------|------------------|
+| 2B Pretrained | 0.407 | 0.378 | 0.338 | 0.338 |
+| 7B Pretrained | 0.643 | 0.615 | 0.448 | 0.595 |
diff --git a/examples/gemma/convert_checkpoint.py b/examples/gemma/convert_checkpoint.py
index 2562d145d..130915e49 100644
--- a/examples/gemma/convert_checkpoint.py
+++ b/examples/gemma/convert_checkpoint.py
@@ -22,8 +22,7 @@
from transformers import AutoConfig, AutoModelForCausalLM
import tensorrt_llm
-from tensorrt_llm._utils import (np_bfloat16, numpy_to_dtype, numpy_to_torch,
- torch_to_numpy)
+from tensorrt_llm._utils import torch_to_numpy
from tensorrt_llm.models.gemma.smoothquant import *
from tensorrt_llm.models.gemma.weight import (dummy_weights_awq,
load_from_fp8_llama,
@@ -110,7 +109,7 @@ def parse_arguments():
help='tokenizer path; defaults to jax_model_dir if left unspecified')
args = parser.parse_args()
- args.use_embedding_sharing = True
+
return args
@@ -234,11 +233,7 @@ def flatten_params(self, params):
def walk(name, obj):
if isinstance(obj, h5py.Dataset):
- if obj.dtype == "|V2":
- # bfloat16 case
- f_params[name] = np.array(obj).astype(np_bfloat16)
- else:
- f_params[name] = np.array(obj)
+ f_params[name] = np.array(obj)
params.visititems(walk)
return f_params
@@ -411,7 +406,7 @@ def add_trt_llm_weight(weights: typing.Dict[str, np.ndarray],
dtype: typing.Optional[np.dtype] = None):
assert name not in weights, f"{name} is already added."
if dtype is not None:
- param = numpy_to_dtype(param, dtype)
+ param = param.astype(dtype)
param = np.ascontiguousarray(param)
weights[name] = param
@@ -425,9 +420,8 @@ def quantize(param: np.ndarray,
else:
raise ValueError(f"Invalid configuration got quant_mode={quant_mode}")
- if param.dtype == np.dtype("bfloat16") or param.dtype == "|V2":
- param = torch.from_numpy(numpy_to_dtype(param,
- 'float32')).to(torch.bfloat16)
+ if param.dtype == np.dtype("bfloat16"):
+ param = torch.from_numpy(param.astype(np.float32)).to(torch.bfloat16)
else:
param = torch.from_numpy(param)
param = param.t().contiguous()
@@ -440,7 +434,7 @@ def quantize(param: np.ndarray,
param, quant_dtype)
if scales.dtype == torch.bfloat16:
- scales = numpy_to_dtype(scales.to(torch.float32).numpy(), "bfloat16")
+ scales = scales.to(torch.float32).numpy().astype("bfloat16")
else:
scales = scales.numpy()
return quantized_weights.numpy(), scales
@@ -472,9 +466,6 @@ def convert_from_checkpoint(
if trt_llm_name is None: # omit as used with other params
continue
- # TensorRT-LLM does not support bfloat16 datatype of jax
- if param.dtype == flax.jax_utils.jnp.bfloat16:
- param = param.astype(np.float32)
if "attn.q_einsum" in name:
gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads
assert gqa_mode
@@ -802,6 +793,10 @@ def convert_from_checkpoint(
add_trt_llm_weight(weights, "lm_head.weight",
np.copy(lm_head), trt_llm_config.dtype)
+ param = np.multiply(
+ param.astype(np.float32),
+ math.sqrt(trt_llm_config.hidden_size),
+ )
if trt_llm_config.use_parallel_embedding:
assert trt_llm_config.vocab_size % tp_size == 0
param = split_matrix_tp(
@@ -845,13 +840,6 @@ def convert(worker_rank, args, convert_kwargs):
smoother = {}
dataset = load_dataset("ccdv/cnn_dailymail", '3.0.0')
tokenizer = sp.SentencePieceProcessor(model_file=args.tokenizer_dir)
- if "transformer.vocab_embedding.weight" in weights:
- # To use the HF to do SmoothQuant, we need to scale the embedding.
- weights["transformer.vocab_embedding.weight"] = np.multiply(
- weights["transformer.vocab_embedding.weight"].astype(
- np.float32),
- math.sqrt(trt_llm_config.hidden_size),
- )
hf_model = create_model_from_config(trt_llm_config, weights)
act_range = capture_activation_range(hf_model, tokenizer, dataset)
if args.use_smooth_quant_plugin is not None:
@@ -865,18 +853,6 @@ def convert(worker_rank, args, convert_kwargs):
torch.quint4x2, args.use_smooth_quant_plugin is not None,
args.per_channel, args.per_token, args.calibrate_kv_cache,
act_range, qkv_para, smoother)
- if "transformer.vocab_embedding.weight" in weights:
- # Revert the scaling of embedding
- weights["transformer.vocab_embedding.weight"] = torch.divide(
- weights["transformer.vocab_embedding.weight"].to(
- torch.float32),
- math.sqrt(trt_llm_config.hidden_size),
- )
- if trt_llm_config.share_embedding_table and "lm_head.weight" in weights:
- # When share_embedding_table is enabled, we add lm_head into weights
- # to do quantization in HF. Remove lm_head before saving it in unified
- # checkpoint.
- del weights["lm_head.weight"]
safetensors.torch.save_file(
weights, args.output_model_dir / f"rank{rank}.safetensors")
return
@@ -901,9 +877,7 @@ def convert(worker_rank, args, convert_kwargs):
args.fp8_kv_cache, weight_scales)
weights.update(scales)
- for key in weights:
- weights[key] = numpy_to_torch(weights[key])
- safetensors.torch.save_file(
+ safetensors.numpy.save_file(
weights, args.output_model_dir / f"rank{rank}.safetensors")
@@ -993,7 +967,6 @@ def main():
tp_size=args.world_size,
pp_size=1,
quantization=quant_config,
- share_embedding_table=args.use_embedding_sharing,
)
trt_llm_config_dict = trt_llm_config.to_dict()
diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt
index a420e1ed0..30785ce93 100644
--- a/examples/gemma/requirements.txt
+++ b/examples/gemma/requirements.txt
@@ -1,8 +1,11 @@
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
+# "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
+nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
+tensorrt_llm==0.9.0.dev2024040200
flax~=0.8.0
-jax[cuda12_pip]~=0.4.19; platform_system != "Windows"
+# jax[cuda12_pip]~=0.4.19; platform_system != "Windows"
jax~=0.4.19; platform_system == "Windows"
safetensors~=0.4.1
sentencepiece~=0.1.99
diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt
index 90742ca41..73b87b655 100644
--- a/examples/gpt/requirements.txt
+++ b/examples/gpt/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/gptneox/requirements.txt b/examples/gptneox/requirements.txt
index a200458d5..32c35f747 100644
--- a/examples/gptneox/requirements.txt
+++ b/examples/gptneox/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.14.5
rouge_score~=0.1.2
evaluate~=0.4.1
diff --git a/examples/high-level-api/requirements.txt b/examples/high-level-api/requirements.txt
index 24d68697a..e0c321273 100644
--- a/examples/high-level-api/requirements.txt
+++ b/examples/high-level-api/requirements.txt
@@ -1,2 +1,2 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
diff --git a/examples/internlm/requirements.txt b/examples/internlm/requirements.txt
index a11608301..e7461b286 100644
--- a/examples/internlm/requirements.txt
+++ b/examples/internlm/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets==2.14.5
rouge_score~=0.1.2
sentencepiece~=0.1.99
diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt
index 639a1be7a..4791f486f 100644
--- a/examples/llama/requirements.txt
+++ b/examples/llama/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets==2.14.6
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt
index 24d68697a..e0c321273 100644
--- a/examples/mamba/requirements.txt
+++ b/examples/mamba/requirements.txt
@@ -1,2 +1,2 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
diff --git a/examples/medusa/requirements.txt b/examples/medusa/requirements.txt
index 8552bd1b1..9b14bbb46 100644
--- a/examples/medusa/requirements.txt
+++ b/examples/medusa/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.14.5
rouge_score~=0.1.2
sentencepiece~=0.1.99
diff --git a/examples/mixtral/requirements.txt b/examples/mixtral/requirements.txt
index 662a597bf..655f032ed 100644
--- a/examples/mixtral/requirements.txt
+++ b/examples/mixtral/requirements.txt
@@ -1,4 +1,4 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
transformers==4.38.2
accelerate==0.25.0
diff --git a/examples/mpt/requirements.txt b/examples/mpt/requirements.txt
index c998f3cba..d09ef3e42 100644
--- a/examples/mpt/requirements.txt
+++ b/examples/mpt/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/opt/requirements.txt b/examples/opt/requirements.txt
index c998f3cba..d09ef3e42 100644
--- a/examples/opt/requirements.txt
+++ b/examples/opt/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/phi/requirements.txt b/examples/phi/requirements.txt
index 1d2229f6a..9db5406ba 100644
--- a/examples/phi/requirements.txt
+++ b/examples/phi/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/quantization/requirements.txt b/examples/quantization/requirements.txt
index 1db286404..9fdaa4d8a 100644
--- a/examples/quantization/requirements.txt
+++ b/examples/quantization/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets>=2.14.4
nemo-toolkit[all]<=1.20.0,>=1.18.0
rouge_score~=0.1.2
diff --git a/examples/qwen/requirements.txt b/examples/qwen/requirements.txt
index de7d0230e..648387d38 100644
--- a/examples/qwen/requirements.txt
+++ b/examples/qwen/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.16.0
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/qwenvl/requirements.txt b/examples/qwenvl/requirements.txt
index 1beed7c17..199e0f5b4 100644
--- a/examples/qwenvl/requirements.txt
+++ b/examples/qwenvl/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.16.0
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/skywork/requirements.txt b/examples/skywork/requirements.txt
index f6d50ac84..8935683e8 100644
--- a/examples/skywork/requirements.txt
+++ b/examples/skywork/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets~=2.16.1
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/smaug/requirements.txt b/examples/smaug/requirements.txt
index 639a1be7a..4791f486f 100644
--- a/examples/smaug/requirements.txt
+++ b/examples/smaug/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
datasets==2.14.6
evaluate~=0.4.1
rouge_score~=0.1.2
diff --git a/examples/utils.py b/examples/utils.py
index 7917ec0ba..200aa9c46 100644
--- a/examples/utils.py
+++ b/examples/utils.py
@@ -21,7 +21,6 @@
from tensorrt_llm.builder import get_engine_version
-# TODO(enweiz): Update for refactored models
DEFAULT_HF_MODEL_DIRS = {
'BaichuanForCausalLM': 'baichuan-inc/Baichuan-13B-Chat',
'BloomForCausalLM': 'bigscience/bloom-560m',
diff --git a/examples/whisper/build.py b/examples/whisper/build.py
index f70aea601..ea1d1bb90 100644
--- a/examples/whisper/build.py
+++ b/examples/whisper/build.py
@@ -24,10 +24,11 @@
from tensorrt_llm.builder import Builder
from tensorrt_llm.functional import LayerNormPositionType, LayerNormType
from tensorrt_llm.logger import logger
-from tensorrt_llm.models import quantize_model
+from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
-from tensorrt_llm.quantization import QuantMode
+from tensorrt_llm.quantization import QuantAlgo
+from tensorrt_llm.quantization.quantize import quantize
MODEL_ENCODER_NAME = "whisper_encoder"
MODEL_DECODER_NAME = "whisper_decoder"
@@ -164,16 +165,17 @@ def parse_arguments():
)
setattr(args, plugin_arg, args.dtype)
+ quant_algo = None
+ kv_cache_quant_algo = None
if args.use_weight_only:
- args.quant_mode = QuantMode.from_description(
- quantize_weights=True,
- quantize_activations=False,
- use_int4_weights="int4" in args.weight_only_precision)
- else:
- args.quant_mode = QuantMode(0)
-
+ if "int4" in args.weight_only_precision:
+ quant_algo = QuantAlgo.W4A16
+ else:
+ quant_algo = QuantAlgo.W8A16
if args.int8_kv_cache:
- args.quant_mode = args.quant_mode.set_int8_kv_cache()
+ kv_cache_quant_algo = QuantAlgo.INT8
+ args.quant_config = QuantConfig(quant_algo=quant_algo,
+ kv_cache_quant_algo=kv_cache_quant_algo)
return args
@@ -204,7 +206,7 @@ def build_encoder(model, args):
hidden_size=hidden_states,
max_batch_size=max_batch_size,
max_beam_width=args.max_beam_width,
- int8=args.quant_mode.has_act_or_weight_quant(),
+ int8=args.quant_config.quant_mode.has_act_or_weight_quant(),
n_mels=model_metadata['n_mels'],
num_languages=model_metadata['n_vocab'] - 51765 -
int(model_is_multilingual),
@@ -216,8 +218,8 @@ def build_encoder(model, args):
model_metadata['n_audio_layer'], str_dtype_to_trt(args.dtype))
if args.use_weight_only:
- tensorrt_llm_whisper_encoder = quantize_model(
- tensorrt_llm_whisper_encoder, args.quant_mode)
+ tensorrt_llm_whisper_encoder = quantize(tensorrt_llm_whisper_encoder,
+ args.quant_config)
use_gemm_woq_plugin = args.use_gemm_plugin and args.use_weight_only
load_encoder_weight(tensorrt_llm_whisper_encoder, model_metadata,
@@ -294,7 +296,7 @@ def build_decoder(model, args):
cross_attention=True,
has_position_embedding=True,
has_token_type_embedding=False,
- int8=args.quant_mode.has_act_or_weight_quant(),
+ int8=args.quant_config.quant_mode.has_act_or_weight_quant(),
)
tensorrt_llm_whisper_decoder = tensorrt_llm.models.DecoderModel(
@@ -327,8 +329,8 @@ def build_decoder(model, args):
logits_dtype=str_dtype_to_trt(args.dtype))
if args.use_weight_only:
- tensorrt_llm_whisper_decoder = quantize_model(
- tensorrt_llm_whisper_decoder, args.quant_mode)
+ tensorrt_llm_whisper_decoder = quantize(tensorrt_llm_whisper_decoder,
+ args.quant_config)
use_gemm_woq_plugin = args.use_gemm_plugin and args.use_weight_only
load_decoder_weight(tensorrt_llm_whisper_decoder, model_params,
diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt
index 50c73c31e..7ac655eeb 100644
--- a/examples/whisper/requirements.txt
+++ b/examples/whisper/requirements.txt
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.9.0.dev2024032600
+tensorrt_llm==0.9.0.dev2024040200
tiktoken
datasets
kaldialign
diff --git a/requirements-windows.txt b/requirements-windows.txt
index 4acc11967..4bb4e1038 100644
--- a/requirements-windows.txt
+++ b/requirements-windows.txt
@@ -16,6 +16,9 @@ pandas
h5py
pywin32
sentencepiece>=0.1.99
+# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
+# "tensorrt==9.2.0.post12.dev5" specifies "nvidia-cudnn-cu12" but actually requires "nvidia-cudnn-cu12~=8.9".
+nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
tensorrt==9.2.0.post12.dev5
tokenizers>=0.14
# Default torch is CPU-only on Windows, so need to specify a torch version with GPU support
diff --git a/requirements.txt b/requirements.txt
index c3c06e20d..302d299d1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,9 @@ pandas
h5py
StrEnum
sentencepiece>=0.1.99
+# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
+# "tensorrt==9.3.0.post12.dev1" specifies "nvidia-cudnn-cu12" but actually requires "nvidia-cudnn-cu12~=8.9".
+nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
tensorrt==9.3.0.post12.dev1
# https://github.com/pytorch/pytorch/blob/v2.2.1/version.txt still uses 2.2.0a0.
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html#rel-24-02 uses 2.3.0a0.
diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py
index 10c57103d..504d5c7db 100644
--- a/tensorrt_llm/_utils.py
+++ b/tensorrt_llm/_utils.py
@@ -121,11 +121,7 @@ def torch_version():
def str_dtype_to_np(dtype):
- if isinstance(dtype, str):
- ret = _str_to_np_dict.get(dtype)
- else:
- # metadata
- ret = _str_to_np_dict.get(dtype.metadata['dtype'])
+ ret = _str_to_np_dict.get(dtype)
assert ret is not None, f'Unsupported dtype: {dtype}'
return ret
diff --git a/tensorrt_llm/layers/moe.py b/tensorrt_llm/layers/moe.py
index 9f7c1d6dd..b6af362ea 100644
--- a/tensorrt_llm/layers/moe.py
+++ b/tensorrt_llm/layers/moe.py
@@ -71,10 +71,6 @@ def has_moe(self) -> bool:
return self.num_experts > 1
-def is_gated_activation(activation_str):
- return activation_str in ("swiglu", "geglu")
-
-
def _moe_plugin(moe_config,
hidden_states,
routing,
diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py
index 36ee10786..b9e186ee9 100755
--- a/tensorrt_llm/models/__init__.py
+++ b/tensorrt_llm/models/__init__.py
@@ -32,8 +32,6 @@
from .phi.model import PhiForCausalLM, PhiModel
from .qwen.model import QWenForCausalLM
-from .quantized.quant import quantize_model # noqa # isort:skip
-
__all__ = [
'BertModel',
'BertForQuestionAnswering',
@@ -55,7 +53,6 @@
'GPTNeoXForCausalLM',
'PhiModel',
'PhiForCausalLM',
- 'quantize_model',
'ChatGLMForCausalLM',
'ChatGLMModel',
'BaichuanForCausalLM',
diff --git a/tensorrt_llm/models/gemma/model.py b/tensorrt_llm/models/gemma/model.py
index b2ae3990b..008e47104 100644
--- a/tensorrt_llm/models/gemma/model.py
+++ b/tensorrt_llm/models/gemma/model.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import math
from typing import Optional
from ..._utils import pad_vocab_size
@@ -129,7 +128,6 @@ def __init__(self, config: PretrainedConfig) -> None:
self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
eps=config.norm_epsilon,
dtype=config.dtype)
- self.hidden_size = config.hidden_size
def forward(self,
input_ids,
@@ -150,7 +148,6 @@ def forward(self,
if self.mapping.is_first_pp_rank():
hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
- hidden_states = hidden_states * math.sqrt(self.hidden_size)
else:
hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
@@ -185,10 +182,6 @@ def __init__(self, config: PretrainedConfig):
vocab_size_padded = pad_vocab_size(config.vocab_size,
config.mapping.tp_size)
- share_weight = None
- assert config.share_embedding_table, "Gemma only supports share_embedding_table"
- share_weight = transformer.vocab_embedding.weight
-
if config.mapping.is_last_pp_rank():
lm_head = ColumnLinear(config.hidden_size,
vocab_size_padded,
@@ -196,8 +189,7 @@ def __init__(self, config: PretrainedConfig):
dtype=config.dtype,
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
- gather_output=True,
- share_weight=share_weight)
+ gather_output=True)
else:
lm_head = None
self.quant_mode = config.quant_mode
diff --git a/tensorrt_llm/models/gemma/smoothquant.py b/tensorrt_llm/models/gemma/smoothquant.py
index e33f5ddd9..3f933fbce 100644
--- a/tensorrt_llm/models/gemma/smoothquant.py
+++ b/tensorrt_llm/models/gemma/smoothquant.py
@@ -17,8 +17,6 @@
repeat_kv)
from transformers.pytorch_utils import Conv1D
-from tensorrt_llm._utils import numpy_to_dtype
-
def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
"""
@@ -740,8 +738,8 @@ def create_model_from_config(trt_llm_config, weights):
}
for name in list(weights):
if model_config.dtype == "bfloat16":
- param = torch.from_numpy(numpy_to_dtype(
- weights[name], "float32")).to(torch.bfloat16)
+ param = torch.from_numpy(weights[name].astype(np.float32)).to(
+ torch.bfloat16)
else:
param = torch.from_numpy(weights[name])
weights.pop(name)
@@ -761,9 +759,6 @@ def create_model_from_config(trt_llm_config, weights):
weights[new_name.replace('attention.qkv', 'self_attn.v_proj')] = vw
else:
weights[new_name] = param
-
- if "lm_head.weight" not in weights:
- weights["lm_head.weight"] = weights["model.embed_tokens.weight"].clone()
model.load_state_dict(weights)
return model
diff --git a/tensorrt_llm/models/gemma/weight.py b/tensorrt_llm/models/gemma/weight.py
index 42a56d0cd..00ec715d8 100644
--- a/tensorrt_llm/models/gemma/weight.py
+++ b/tensorrt_llm/models/gemma/weight.py
@@ -16,107 +16,16 @@
import os
import time
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Union
import numpy as np
import torch
-from ..._utils import (numpy_to_dtype, numpy_to_torch, pad_vocab_size,
- str_dtype_to_torch, torch_to_numpy)
+from ..._utils import (numpy_to_torch, pad_vocab_size, str_dtype_to_torch,
+ torch_to_numpy)
from ...logger import logger
from ...mapping import Mapping
from ...quantization import QuantMode
-from ..quantized.quant import get_dummy_quant_scales
-
-
-def get_scaling_factors(
- model_path: Union[str, Path],
- num_layers: int,
- quant_mode: Optional[QuantMode] = None,
-) -> Optional[Dict[str, List[int]]]:
- """ Get the scaling factors for LLaMA model
-
- Returns a dictionary of scaling factors for the selected layers of the
- LLaMA model.
-
- Args:
- model_path (str): Path to the quantized LLaMA model
- layers (list): List of layers to get the scaling factors for. If None,
- all layers are selected.
-
- Returns:
- dict: Dictionary of scaling factors for the selected layers of the
- LLaMA model.
-
- example:
-
- {
- 'qkv_act': qkv_act_scale,
- 'qkv_weights': qkv_weights_scale,
- 'qkv_output' : qkv_outputs_scale,
- 'dense_act': dense_act_scale,
- 'dense_weights': dense_weights_scale,
- 'fc_act': fc_act_scale,
- 'fc_weights': fc_weights_scale,
- 'gate_act': gate_act_scale,
- 'gate_weights': gate_weights_scale,
- 'proj_act': proj_act_scale,
- 'proj_weights': proj_weights_scale,
- }
- """
-
- if model_path is None:
- logger.warning(f"--quantized_fp8_model_path not specified. "
- f"Initialize quantization scales automatically.")
- return get_dummy_quant_scales(num_layers)
- weight_dict = np.load(model_path)
- # yapf: disable
- scaling_factor = {
- 'qkv_act': [],
- 'qkv_weights': [],
- 'dense_act': [],
- 'dense_weights': [],
- 'fc_act': [],
- 'fc_weights': [],
- 'gate_act': [],
- 'gate_weights': [],
- 'proj_act': [],
- 'proj_weights': [],
- }
-
- if quant_mode is not None and quant_mode.has_fp8_kv_cache():
- scaling_factor['qkv_output'] = []
-
- for layer in range(num_layers):
- scaling_factor['qkv_act'].append(max(
- weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
- weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
- weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
- ))
- scaling_factor['qkv_weights'].append(max(
- weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
- weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
- weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
- ))
- if quant_mode is not None and quant_mode.has_fp8_kv_cache():
- # Not calibrating KV cache.
- scaling_factor['qkv_output'].append(1.0)
- scaling_factor['dense_act'].append(
- weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
- scaling_factor['dense_weights'].append(
- weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
- scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
- scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
- scaling_factor['gate_act'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:activation_scaling_factor'].item())
- scaling_factor['gate_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:weights_scaling_factor'].item())
- scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
- scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
- # yapf: enable
- for k, v in scaling_factor.items():
- assert len(v) == num_layers, \
- f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'
-
- return scaling_factor
def gen_suffix(rank, use_smooth_quant, quant_per_channel):
@@ -760,8 +669,8 @@ def load_from_hf_gemma(tensorrt_llm_llama: 'GemmaForCausalLM',
def quantize_fp8_weights(weights, num_layers, mapping):
def get_scaling_factor(weight):
- scale = torch_to_numpy(448.0 / numpy_to_torch(weight).max()).reshape(
- [-1])
+ amax = weight.max()
+ scale = 448.0 / amax
return scale
layers_range = mapping.pp_layers(num_layers)
@@ -781,13 +690,9 @@ def get_scaling_factor(weight):
dtype = weights[trt_llm_name].dtype
scale = get_scaling_factor(weight)
scaled_weights[trt_llm_name] = np.ascontiguousarray(
- numpy_to_dtype(
- torch_to_numpy(
- numpy_to_torch(weight).to(torch.float32) *
- numpy_to_torch(scale).to(torch.float32)), dtype))
- scaling_factors[scale_name] = torch_to_numpy(
- (1 / numpy_to_torch(scale)).to(torch.float32))
-
+ (weight * scale).astype(dtype))
+ scaling_factors[scale_name] = np.asarray([1 / scale
+ ]).astype(np.float32)
return scaling_factors
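The rewritten quantize_fp8_weights() keeps the arithmetic in plain NumPy: the per-tensor scale maps the tensor maximum onto 448 (the largest representable FP8 E4M3 value), the weight is multiplied by that scale, and the reciprocal is stored as the dequantization factor. A self-contained sketch of that arithmetic; the tensor and function names here are illustrative, not part of the patch:

    import numpy as np

    def fp8_scale_weight(weight: np.ndarray):
        # Mirror get_scaling_factor() above: 448.0 / max maps the tensor onto the FP8 E4M3 range.
        amax = weight.max()
        scale = 448.0 / amax
        scaled = np.ascontiguousarray((weight * scale).astype(weight.dtype))
        dequant_scale = np.asarray([1.0 / scale]).astype(np.float32)  # stored as the *_scaling_factor entry
        return scaled, dequant_scale

    w = np.abs(np.random.randn(16, 16)).astype(np.float32)
    scaled_w, dq_scale = fp8_scale_weight(w)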
diff --git a/tensorrt_llm/models/llama/weight.py b/tensorrt_llm/models/llama/weight.py
index 98e53bd45..3724f0f71 100644
--- a/tensorrt_llm/models/llama/weight.py
+++ b/tensorrt_llm/models/llama/weight.py
@@ -15,7 +15,7 @@
import configparser
import time
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import List, Union
import numpy as np
import torch
@@ -28,101 +28,10 @@
from ...mapping import Mapping
from ...quantization import QuantMode
from ..modeling_utils import PretrainedConfig
-from ..quantized.quant import get_dummy_quant_scales
from .utils import (iterate_shard_files, load_state_dict,
retrieved_layer_index_from_name)
-def get_scaling_factors(
- model_path: Union[str, Path],
- num_layers: int,
- quant_mode: Optional[QuantMode] = None,
-) -> Optional[Dict[str, List[int]]]:
- """ Get the scaling factors for LLaMA model
-
- Returns a dictionary of scaling factors for the selected layers of the
- LLaMA model.
-
- Args:
- model_path (str): Path to the quantized LLaMA model
- layers (list): List of layers to get the scaling factors for. If None,
- all layers are selected.
-
- Returns:
- dict: Dictionary of scaling factors for the selected layers of the
- LLaMA model.
-
- example:
-
- {
- 'qkv_act': qkv_act_scale,
- 'qkv_weights': qkv_weights_scale,
- 'qkv_output' : qkv_outputs_scale,
- 'dense_act': dense_act_scale,
- 'dense_weights': dense_weights_scale,
- 'fc_act': fc_act_scale,
- 'fc_weights': fc_weights_scale,
- 'gate_act': gate_act_scale,
- 'gate_weights': gate_weights_scale,
- 'proj_act': proj_act_scale,
- 'proj_weights': proj_weights_scale,
- }
- """
-
- if model_path is None:
- logger.warning(f"--quantized_fp8_model_path not specified. "
- f"Initialize quantization scales automatically.")
- return get_dummy_quant_scales(num_layers)
- weight_dict = np.load(model_path)
- # yapf: disable
- scaling_factor = {
- 'qkv_act': [],
- 'qkv_weights': [],
- 'dense_act': [],
- 'dense_weights': [],
- 'fc_act': [],
- 'fc_weights': [],
- 'gate_act': [],
- 'gate_weights': [],
- 'proj_act': [],
- 'proj_weights': [],
- }
-
- if quant_mode is not None and quant_mode.has_fp8_kv_cache():
- scaling_factor['qkv_output'] = []
-
- for layer in range(num_layers):
- scaling_factor['qkv_act'].append(max(
- weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
- weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
- weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
- ))
- scaling_factor['qkv_weights'].append(max(
- weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
- weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
- weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
- ))
- if quant_mode is not None and quant_mode.has_fp8_kv_cache():
- # Not calibrating KV cache.
- scaling_factor['qkv_output'].append(1.0)
- scaling_factor['dense_act'].append(
- weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
- scaling_factor['dense_weights'].append(
- weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
- scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
- scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
- scaling_factor['gate_act'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:activation_scaling_factor'].item())
- scaling_factor['gate_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:weights_scaling_factor'].item())
- scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
- scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
- # yapf: enable
- for k, v in scaling_factor.items():
- assert len(v) == num_layers, \
- f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'
-
- return scaling_factor
-
-
def gen_suffix(rank, use_smooth_quant, quant_per_channel):
suffix = f"{rank}.bin"
if use_smooth_quant:
diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py
index 5f2ba3853..c1d6e3646 100644
--- a/tensorrt_llm/models/modeling_utils.py
+++ b/tensorrt_llm/models/modeling_utils.py
@@ -1028,14 +1028,6 @@ def load_model(
logger.warning(
f"Cannot find {model_path}. Use dummy model weights.")
- if model_config.share_embedding_table:
- if "lm_head.weight" in weights and "transformer.vocab_embedding.weight" in weights:
- assert not (
- weights["lm_head.weight"] -
- weights["transformer.vocab_embedding.weight"]
- ).any(
- ), "When share_embedding_table is enabled, lm_head.weight and transformer.vocab_embedding.weight must be same."
-
# Currently, use_parallel_embedding and share_embedding_table should be enabled before weight loading;
# otherwise, the model will be inconsistent with the weights loaded from checkpoint.
model = optimize_model(
diff --git a/tensorrt_llm/models/quantized/__init__.py b/tensorrt_llm/models/quantized/__init__.py
deleted file mode 100644
index 71bf6d298..000000000
--- a/tensorrt_llm/models/quantized/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/tensorrt_llm/models/quantized/ammo.py b/tensorrt_llm/models/quantized/ammo.py
deleted file mode 100644
index 368901cb2..000000000
--- a/tensorrt_llm/models/quantized/ammo.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from pathlib import Path
-from typing import Dict, Literal, Optional, Union
-
-import torch
-from torch.utils.data import DataLoader
-
-try:
- import ammo.torch.quantization as atq
- from ammo.torch.export import export_model_config
-except ImportError:
- raise ImportError("AMMO toolkit is not installed. Please install it first.")
-
-from ...logger import logger
-
-
-def _register_falcon_linears(model):
- """Register Falcon linear modules as Quantization.
-
- As falcon models could use remote code, which will be loaded dynamically,
- to build their model. Therefore, we need to register the linear on the fly
- before quantization.
-
- """
- if type(model).__name__ in ["RWForCausalLM", "FalconForCausalLM"]:
- from ammo.torch.quantization import tensor_quant
- from ammo.torch.quantization.nn.modules.quant_module import \
- QuantLinearConvBase
-
- linear_type = type(model.transformer.h[0].self_attention.dense)
-
- class QuantFalconLinearRW1B(linear_type,
- QuantLinearConvBase): # type: ignore
- default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
-
- atq.module_mapping.QUANT_MODULE_MAPPING[
- linear_type] = QuantFalconLinearRW1B.convert
-
-
-def _quantize_model(model: torch.nn.Module,
- qformat: Literal['fp8', 'int8_sq', 'int4_awq', 'w4a8_awq'],
- calib_dataloader: DataLoader,
- quant_cfg_dict: Optional[Dict] = None) -> torch.nn.Module:
- assert qformat in ['fp8', 'int8_sq', 'int4_awq', 'w4a8_awq'], \
- f'Got unsupported AMMO quantization format, {qformat} '
- if qformat == "fp8":
- quant_cfg = atq.FP8_DEFAULT_CFG
- elif qformat == "int8_sq":
- quant_cfg = atq.INT8_SMOOTHQUANT_CFG
- elif qformat == "int4_awq":
- quant_cfg = atq.INT4_AWQ_CFG
- elif qformat == "w4a8_awq":
- quant_cfg = atq.W4A8_AWQ_BETA_CFG
- else:
- raise ValueError(f"Unsupported quantization format: {qformat}")
-
- if quant_cfg_dict:
- for name, cfg in quant_cfg_dict.items():
- quant_cfg['quant_cfg'][name] = cfg
-
- def calibrate_loop():
- """Adjusts weights and scaling factors based on selected algorithms."""
- for idx, data in enumerate(calib_dataloader):
- logger.debug(f"Calibrating batch {idx}")
- model(data)
-
- _register_falcon_linears(model)
-
- logger.debug("Starting quantization...")
- print(quant_cfg)
- atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
- logger.debug("Quantization done")
- return model
-
-
-def quantize_and_export(
- model: torch.nn.Module,
- qformat: Literal['fp8', 'int8_sq', 'int4_awq', 'w4a8_awq'],
- calib_dataloader: DataLoader,
- export_path: Optional[Union[str, Path]] = None,
- tensor_parallel_size: int = 1,
- quant_cfg_dict: Optional[Dict] = None) -> torch.nn.Module:
-
- model_cls_name = type(model).__name__
- model_lookup = {
- ("xverse", ): "xverse",
- ("llama", "mistral"): "llama",
- ("gptj", ): "gptj",
- ("falcon", "rw"): "falcon",
- ("baichuan", ): "baichuan",
- ("mpt", ): "mpt",
- ("gpt2", ): "gpt2",
- ("chatglm", ): "chatglm",
- ("qwen", ): "qwen",
- }
- for templates, model_type_target in model_lookup.items():
- if any(t in model_cls_name.lower() for t in templates):
- model_type = model_type_target
- break
- else:
- raise NotImplementedError(
- f"Deploying quantized model {model_cls_name} is not supported")
-
- model = _quantize_model(model,
- qformat=qformat,
- calib_dataloader=calib_dataloader,
- quant_cfg_dict=quant_cfg_dict)
-
- if export_path:
- with torch.inference_mode():
- if qformat == "int4_awq" and model_type == "qwen" or \
- model_type == "chatglm":
- torch.save(model.state_dict(), export_path)
- else:
- export_model_config(
- model,
- model_type,
- torch.float16,
- export_dir=export_path,
- inference_tensor_parallel=tensor_parallel_size,
- export_npz=True,
- )
- logger.info(f"Quantized model exported to :{export_path}")
- return model
diff --git a/tensorrt_llm/models/quantized/quant.py b/tensorrt_llm/models/quantized/quant.py
deleted file mode 100644
index fea6abf0e..000000000
--- a/tensorrt_llm/models/quantized/quant.py
+++ /dev/null
@@ -1,550 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import dataclasses
-from typing import Any
-
-import numpy as np
-
-from ...layers import ColumnLinear, RowLinear
-from ...module import Module
-from ...quantization import QuantMode
-from ...quantization.layers import FP8Linear, FP8RowLinear
-from ...quantization.mode import QuantAlgo
-from ...quantization.quantize import weight_only_quantize
-from ..modeling_utils import QuantConfig
-
-# isort: off
-from ...quantization.layers import (SmoothQuantAttention, SmoothQuantGatedMLP,
- SmoothQuantLayerNorm, SmoothQuantMLP,
- SmoothQuantRmsNorm,
- WeightOnlyGroupwiseQuantColumnLinear,
- WeightOnlyGroupwiseQuantRowLinear)
-# isort: on
-
-
-def _smooth_quantize_gpt(model, quant_mode):
- assert quant_mode.has_act_and_weight_quant()
- for layer_idx, layer in enumerate(model.layers):
- assert hasattr(layer,
- "input_layernorm"), "The layer has no input_layernorm"
- layer.input_layernorm = SmoothQuantLayerNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
- assert hasattr(layer, "attention"), "The layer has no attention"
- layer.attention = SmoothQuantAttention(
- layer_idx=layer_idx,
- hidden_size=layer.hidden_size,
- num_attention_heads=layer.num_attention_heads,
- num_kv_heads=layer.attention.num_attention_kv_heads * layer.tp_size,
- max_position_embeddings=layer.max_position_embeddings,
- num_layers=layer.num_layers,
- apply_query_key_layer_scaling=layer.apply_query_key_layer_scaling,
- dtype=layer.dtype,
- attention_mask_type=layer.attention_mask_type,
- bias=(layer.attention.dense.bias != None),
- qkv_bias_only=(layer.attention.qkv.bias != None
- and layer.attention.dense.bias == None),
- position_embedding_type=layer.position_embedding_type,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- tp_rank=layer.attention.tp_rank,
- quant_mode=quant_mode)
- assert hasattr(layer, "mlp"), "The layer has no mlp"
- layer.mlp = SmoothQuantMLP(hidden_size=layer.hidden_size,
- ffn_hidden_size=layer.hidden_size * 4,
- hidden_act=layer.hidden_act,
- bias=(layer.mlp.fc.bias != None),
- dtype=layer.dtype,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode)
- assert hasattr(layer,
- "post_layernorm"), "The layer has no post_layernorm"
- layer.post_layernorm = SmoothQuantLayerNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
-
- return model
-
-
-def _smooth_quantize_llama(model, quant_mode):
- assert quant_mode.has_act_and_weight_quant()
- for layer_idx, layer in enumerate(model.layers):
- assert hasattr(layer,
- "input_layernorm"), "The layer has no input_layernorm"
- layer.input_layernorm = SmoothQuantRmsNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
- assert hasattr(layer, "attention"), "The layer has no attention"
- layer.attention = SmoothQuantAttention(
- layer_idx=layer_idx,
- hidden_size=layer.hidden_size,
- num_attention_heads=layer.num_attention_heads,
- num_kv_heads=layer.num_kv_heads,
- max_position_embeddings=layer.max_position_embeddings,
- num_layers=model.num_layers,
- dtype=layer.dtype,
- attention_mask_type=layer.attention_mask_type,
- position_embedding_type=layer.position_embedding_type,
- rotary_embedding_base=layer.attention.rotary_embedding_base,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode,
- bias=layer.attention.qkv.bias is not None)
-
- assert hasattr(layer, "mlp"), "The layer has no mlp"
- if hasattr(model, "moe_config"):
- assert not model.moe_config.has_moe(
- ), "MOE does not support smooth quant"
- layer.mlp = SmoothQuantGatedMLP(hidden_size=model.hidden_size,
- ffn_hidden_size=layer.mlp_hidden_size,
- hidden_act=layer.hidden_act,
- dtype=layer.dtype,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode,
- bias=layer.mlp.fc.bias is not None)
- assert hasattr(layer,
- "post_layernorm"), "The layer has no post_layernorm"
- layer.post_layernorm = SmoothQuantRmsNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
-
- return model
-
-
-def _smooth_quantize_bloom(model, quant_mode):
- assert quant_mode.has_act_and_weight_quant()
- for layer_idx, layer in enumerate(model.layers):
- assert hasattr(layer,
- "input_layernorm"), "The layer has no input_layernorm"
- layer.input_layernorm = SmoothQuantLayerNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
- assert hasattr(layer, "attention"), "The layer has no attention"
- layer.attention = SmoothQuantAttention(
- layer_idx=layer_idx,
- hidden_size=layer.hidden_size,
- num_attention_heads=layer.num_attention_heads,
- max_position_embeddings=layer.max_position_embeddings,
- num_layers=layer.num_layers,
- dtype=layer.dtype,
- attention_mask_type=layer.attention_mask_type,
- position_embedding_type=layer.position_embedding_type,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- tp_rank=layer.tp_rank,
- quant_mode=quant_mode)
-
- assert hasattr(layer, "mlp"), "The layer has no mlp"
- layer.mlp = SmoothQuantMLP(hidden_size=layer.hidden_size,
- ffn_hidden_size=layer.hidden_size * 4,
- hidden_act=layer.hidden_act,
- dtype=layer.dtype,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode)
- assert hasattr(layer,
- "post_layernorm"), "The layer has no post_layernorm"
- layer.post_layernorm = SmoothQuantLayerNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
-
- setattr(model, 'quant_mode', quant_mode)
- return model
-
-
-def _smooth_quantize_baichuan(model, quant_mode):
- assert quant_mode.has_act_and_weight_quant()
- for layer_idx, layer in enumerate(model.layers):
- assert hasattr(layer,
- "input_layernorm"), "The layer has no input_layernorm"
- layer.input_layernorm = SmoothQuantRmsNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
- assert hasattr(layer, "attention"), "The layer has no attention"
- layer.attention = SmoothQuantAttention(
- layer_idx=layer_idx,
- hidden_size=layer.hidden_size,
- num_attention_heads=layer.num_attention_heads,
- num_kv_heads=layer.num_kv_heads,
- max_position_embeddings=layer.max_position_embeddings,
- num_layers=model.num_layers,
- dtype=layer.dtype,
- attention_mask_type=layer.attention_mask_type,
- position_embedding_type=layer.position_embedding_type,
- rotary_embedding_base=layer.attention.rotary_embedding_base,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- tp_rank=layer.tp_rank,
- quant_mode=quant_mode,
- bias=layer.attention.qkv.bias is not None)
-
- assert hasattr(layer, "mlp"), "The layer has no mlp"
- if hasattr(model, "moe_config"):
- assert not model.moe_config.has_moe(
- ), "MOE does not support smooth quant"
- layer.mlp = SmoothQuantGatedMLP(hidden_size=model.hidden_size,
- ffn_hidden_size=layer.mlp_hidden_size,
- hidden_act=layer.hidden_act,
- dtype=layer.dtype,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode,
- bias=layer.mlp.fc.bias is not None)
- assert hasattr(layer,
- "post_layernorm"), "The layer has no post_layernorm"
- layer.post_layernorm = SmoothQuantRmsNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
-
- return model
-
-
-def _smooth_quantize_internlm(model, quant_mode):
- assert quant_mode.has_act_and_weight_quant()
- for layer_idx, layer in enumerate(model.layers):
- assert hasattr(layer,
- "input_layernorm"), "The layer has no input_layernorm"
- layer.input_layernorm = SmoothQuantRmsNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
- assert hasattr(layer, "attention"), "The layer has no attention"
- layer.attention = SmoothQuantAttention(
- layer_idx=layer_idx,
- hidden_size=layer.hidden_size,
- num_attention_heads=layer.num_attention_heads,
- num_kv_heads=layer.num_kv_heads,
- max_position_embeddings=layer.max_position_embeddings,
- num_layers=model.num_layers,
- dtype=layer.dtype,
- attention_mask_type=layer.attention_mask_type,
- position_embedding_type=layer.position_embedding_type,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode,
- bias=model.attn_bias)
-
- assert hasattr(layer, "mlp"), "The layer has no mlp"
- layer.mlp = SmoothQuantGatedMLP(hidden_size=model.hidden_size,
- ffn_hidden_size=layer.mlp_hidden_size,
- hidden_act=layer.hidden_act,
- dtype=layer.dtype,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode,
- bias=False)
- assert hasattr(layer,
- "post_layernorm"), "The layer has no post_layernorm"
- layer.post_layernorm = SmoothQuantRmsNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
-
- setattr(model, 'quant_mode', quant_mode)
- return model
-
-
-def _smooth_quantize_qwen(model, quant_mode):
- assert quant_mode.has_act_and_weight_quant()
- for layer_idx, layer in enumerate(model.layers):
- assert hasattr(layer, "ln_1"), "The layer has no ln_1"
- layer.ln_1 = SmoothQuantRmsNorm(normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
- assert hasattr(layer, "attention"), "The layer has no attention"
- layer.attention = SmoothQuantAttention(
- layer_idx=layer_idx,
- hidden_size=layer.hidden_size,
- num_attention_heads=layer.num_attention_heads,
- max_position_embeddings=layer.max_position_embeddings,
- num_layers=layer.num_layers,
- apply_query_key_layer_scaling=layer.apply_query_key_layer_scaling,
- attention_mask_type=layer.attention_mask_type,
- bias=layer.bias,
- qkv_bias_only=True,
- dtype=layer.dtype,
- position_embedding_type=layer.position_embedding_type,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode)
- assert hasattr(layer, "mlp"), "The layer has no mlp"
- layer.mlp = SmoothQuantGatedMLP(hidden_size=layer.hidden_size,
- ffn_hidden_size=layer.mlp_hidden_size //
- 2,
- hidden_act=layer.hidden_act,
- dtype=layer.dtype,
- bias=layer.bias,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode)
- assert hasattr(layer, "ln_2"), "The layer has no ln_2"
- layer.ln_2 = SmoothQuantRmsNorm(normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode)
-
- setattr(model, 'quant_mode', quant_mode)
- return model
-
-
-def _smooth_quantize_chatglm(model, quant_mode):
- assert quant_mode.has_act_and_weight_quant()
- for layer_idx, layer in enumerate(model.layers):
- assert hasattr(layer,
- "input_layernorm"), "The layer has no input_layernorm"
- layer.input_layernorm = SmoothQuantRmsNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode,
- )
- assert hasattr(layer, "attention"), "The layer has no attention"
- layer.attention = SmoothQuantAttention(
- layer_idx=layer_idx,
- hidden_size=layer.hidden_size,
- num_attention_heads=layer.num_heads,
- num_kv_heads=layer.num_kv_heads,
- max_position_embeddings=layer.max_seq_length,
- num_layers=layer.num_layers,
- apply_query_key_layer_scaling=layer.apply_query_key_layer_scaling,
- dtype=layer.dtype,
- attention_mask_type=layer.attention_mask_type,
- position_embedding_type=layer.position_embedding_type,
- rotary_embedding_base=layer.rotary_embedding_base,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode,
- bias=layer.dense_bias,
- qkv_bias_only=layer.bias and not layer.dense_bias,
- )
- assert hasattr(layer, "mlp"), "The layer has no mlp"
- layer.mlp = SmoothQuantMLP(
- hidden_size=layer.hidden_size,
- ffn_hidden_size=layer.ffn_hidden_size,
- hidden_act=layer.hidden_act,
- dtype=layer.dtype,
- tp_group=layer.tp_group,
- tp_size=layer.tp_size,
- quant_mode=quant_mode,
- bias=layer.dense_bias,
- )
- assert hasattr(layer,
- "post_layernorm"), "The layer has no post_layernorm"
- layer.post_layernorm = SmoothQuantRmsNorm(
- normalized_shape=layer.hidden_size,
- dtype=layer.dtype,
- quant_mode=quant_mode,
- )
- return model
-
-
-def _smooth_quantize(model, quant_mode):
- from ...models import (BaichuanForCausalLM, BloomForCausalLM,
- ChatGLMForCausalLM, GPTForCausalLM, LLaMAForCausalLM,
- QWenForCausalLM)
- assert isinstance(model, GPTForCausalLM) or isinstance(model, LLaMAForCausalLM) \
- or isinstance(model, BloomForCausalLM) or isinstance(model, BaichuanForCausalLM) \
- or isinstance(model, QWenForCausalLM) or isinstance(model, ChatGLMForCausalLM), \
- "Only GPTForCausalLM, LLaMAForCausalLM BloomForCausalLM and BaichuanForCausalLM are well tested now"
- if isinstance(model, GPTForCausalLM):
- return _smooth_quantize_gpt(model, quant_mode)
- elif isinstance(model, LLaMAForCausalLM):
- return _smooth_quantize_llama(model, quant_mode)
- elif isinstance(model, BloomForCausalLM):
- return _smooth_quantize_bloom(model, quant_mode)
- elif isinstance(model, BaichuanForCausalLM):
- return _smooth_quantize_baichuan(model, quant_mode)
- elif isinstance(model, QWenForCausalLM):
- return _smooth_quantize_qwen(model, quant_mode)
- elif isinstance(model, ChatGLMForCausalLM):
- return _smooth_quantize_chatglm(model, quant_mode)
- else:
- assert False, f"Model {type(model).__name__} is not supported by SmoothQuant yet"
-
-
-def _weight_only_groupwise_quantize(model,
- quant_mode,
- group_size=128,
- pre_quant_scale=False,
- zero=False,
- weight_only_precision="int4_awq",
- exclude_modules=None,
- current_key_name=None):
- exclude_modules = ['lm_head'
- ] if exclude_modules is None else exclude_modules
-
- for name, module in model.named_children():
- if current_key_name is None:
- current_key_name = []
- current_key_name.append(name)
-
- if len(list(module.children())) > 0:
- _weight_only_groupwise_quantize(module, quant_mode, group_size,
- pre_quant_scale, zero,
- weight_only_precision,
- exclude_modules, current_key_name)
-
- if isinstance(module, ColumnLinear) and name not in exclude_modules:
- if not any(key in '.'.join(current_key_name)
- for key in exclude_modules):
- model._modules[name] = WeightOnlyGroupwiseQuantColumnLinear(
- in_features=module.in_features,
- out_features=module.out_features * module.tp_size,
- group_size=group_size,
- pre_quant_scale=pre_quant_scale,
- zero=zero,
- bias=module.bias is not None,
- use_w4a8_awq=weight_only_precision == 'w4a8_awq',
- dtype=module.dtype,
- tp_group=module.tp_group,
- tp_size=module.tp_size,
- gather_output=module.gather_output)
- elif isinstance(module, RowLinear) and name not in exclude_modules:
- if not any(key in '.'.join(current_key_name)
- for key in exclude_modules):
- model._modules[name] = WeightOnlyGroupwiseQuantRowLinear(
- in_features=module.in_features * module.tp_size,
- out_features=module.out_features,
- group_size=group_size,
- pre_quant_scale=pre_quant_scale,
- zero=zero,
- bias=module.bias is not None,
- use_w4a8_awq=weight_only_precision == 'w4a8_awq',
- dtype=module.dtype,
- tp_group=module.tp_group,
- tp_size=module.tp_size)
-
- current_key_name.pop(-1)
-
- return model
-
-
-def quantize_model(model: Module, quant_mode: QuantMode, **kwargs: Any):
- if quant_mode.is_weight_only():
- if quant_mode.has_per_group_scaling():
- model = _weight_only_groupwise_quantize(model, quant_mode, **kwargs)
- else:
- quant_algo = QuantAlgo.W4A16 if quant_mode.is_int4_weight_only(
- ) else QuantAlgo.W8A16
- model = weight_only_quantize(
- model, dataclasses.replace(QuantConfig(quant_algo), **kwargs))
- elif quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache():
- model = _fp8_quantize(model, quant_mode, **kwargs)
- elif quant_mode.has_act_and_weight_quant():
- model = _smooth_quantize(model, quant_mode)
-
- setattr(model, "quant_mode", quant_mode)
- return model
-
-
-def get_dummy_quant_scales(num_layers):
- return {
- 'lm_head_act': 0.99,
- 'lm_head_weights': 0.99,
- 'fc_act': [0.99 for _ in range(num_layers)],
- 'fc_weights': [0.99 for _ in range(num_layers)],
- 'gate_act': [0.99 for _ in range(num_layers)],
- 'gate_weights': [0.99 for _ in range(num_layers)],
- 'proj_act': [0.99 for _ in range(num_layers)],
- 'proj_weights': [0.99 for _ in range(num_layers)],
- 'qkv_act': [0.99 for _ in range(num_layers)],
- 'qkv_weights': [0.99 for _ in range(num_layers)],
- 'qkv_output': [1.0 for _ in range(num_layers)],
- 'dense_act': [0.99 for _ in range(num_layers)],
- 'dense_weights': [0.99 for _ in range(num_layers)],
- }
-
-
-def _quantize_layer(layer, layer_idx, quant_mode, quant_scales):
- assert hasattr(layer, "mlp"), "The layer has no mlp"
- fake_fp8_sf_dt = np.float32
-
- assert isinstance(layer.mlp.fc, (FP8Linear, FP8RowLinear))
- assert isinstance(layer.mlp.proj, (FP8Linear, FP8RowLinear))
- layer.mlp.fc.activation_scaling_factor.value = np.array(
- [quant_scales['fc_act'][layer_idx]], dtype=fake_fp8_sf_dt)
- layer.mlp.fc.weights_scaling_factor.value = np.array(
- [quant_scales['fc_weights'][layer_idx]], dtype=fake_fp8_sf_dt)
- layer.mlp.proj.activation_scaling_factor.value = np.array(
- [quant_scales['proj_act'][layer_idx]], dtype=fake_fp8_sf_dt)
- layer.mlp.proj.weights_scaling_factor.value = np.array(
- [quant_scales['proj_weights'][layer_idx]], dtype=fake_fp8_sf_dt)
- if hasattr(layer.mlp, 'gate'):
- assert isinstance(layer.mlp.gate, (FP8Linear, FP8RowLinear))
- layer.mlp.gate.activation_scaling_factor.value = np.array(
- [quant_scales['gate_act'][layer_idx]], dtype=fake_fp8_sf_dt)
- layer.mlp.gate.weights_scaling_factor.value = np.array(
- [quant_scales['gate_weights'][layer_idx]], dtype=fake_fp8_sf_dt)
-
- assert hasattr(layer, "attention"), "The layer has no attention"
- assert isinstance(layer.attention.qkv, (FP8Linear, FP8RowLinear))
- assert isinstance(layer.attention.dense, (FP8Linear, FP8RowLinear))
- layer.attention.qkv.activation_scaling_factor.value = np.array(
- [quant_scales['qkv_act'][layer_idx]], dtype=fake_fp8_sf_dt)
- layer.attention.qkv.weights_scaling_factor.value = np.array(
- [quant_scales['qkv_weights'][layer_idx]], dtype=fake_fp8_sf_dt)
- if quant_mode.has_fp8_kv_cache():
- layer.attention.kv_cache_scaling_factor.value = np.array(
- [quant_scales['qkv_output'][layer_idx]], dtype=fake_fp8_sf_dt)
- layer.attention.dense.activation_scaling_factor.value = np.array(
- [quant_scales['dense_act'][layer_idx]], dtype=fake_fp8_sf_dt)
- layer.attention.dense.weights_scaling_factor.value = np.array(
- [quant_scales['dense_weights'][layer_idx]], dtype=fake_fp8_sf_dt)
-
- return layer
-
-
-def _default_fp8_quantize(model,
- quant_mode: QuantMode,
- quant_scales: dict = None):
- """
- Quantize all linear layers (i.e., MLP, Attention QKV/Dense) and KV cache IO with dummy scales
- This is used by benchmark script and therefore is intentionally decoupled from AMMO toolkit
- """
- if quant_scales is None:
- num_layers = getattr(model, '_num_layers',
- getattr(model, 'num_layers', None))
- assert num_layers is not None
- quant_scales = get_dummy_quant_scales(num_layers)
-
- assert model.quant_mode == quant_mode, "Quant setting not consistent with model init setting"
-
- use_fp8_qdq = quant_mode.has_fp8_qdq()
- assert use_fp8_qdq
-
- for layer_idx, layer in enumerate(model.layers):
- layer = _quantize_layer(layer, layer_idx, quant_mode, quant_scales)
-
- # TODO: add lm_head
-
- return model
-
-
-def _fp8_quantize(model, quant_mode: QuantMode, quant_scales: dict = None):
- from ...models import (BaichuanForCausalLM, FalconForCausalLM,
- GPTForCausalLM, GPTJForCausalLM, LLaMAForCausalLM)
- if isinstance(model, (FalconForCausalLM, GPTJForCausalLM, GPTForCausalLM,
- LLaMAForCausalLM, BaichuanForCausalLM)):
- return _default_fp8_quantize(model, quant_mode, quant_scales)
- raise NotImplementedError(
- f"Model {model} is not implemented by fp8_quantize yet")
diff --git a/tensorrt_llm/quantization/quantize_by_ammo.py b/tensorrt_llm/quantization/quantize_by_ammo.py
index 8b309214c..3eaf331ff 100644
--- a/tensorrt_llm/quantization/quantize_by_ammo.py
+++ b/tensorrt_llm/quantization/quantize_by_ammo.py
@@ -318,18 +318,13 @@ def quantize_and_export(*, model_dir, dtype, device, qformat, kv_cache_dtype,
export_path = output_dir
start_time = time.time()
- export_npz = (model_type not in [
- 'gpt2', 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan',
- 'gemma', 'qwen'
- ])
export_model_config(model,
model_type,
getattr(torch, dtype),
export_dir=export_path,
inference_tensor_parallel=tp_size,
inference_pipeline_parallel=pp_size,
- export_tensorrt_llm_config=(not export_npz),
- export_npz=export_npz)
+ export_tensorrt_llm_config=True)
# Workaround for wo quantization
if qformat in ["int8_wo", "int4_wo", "full_prec"]:
diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py
index 4dd337898..6f2cf0654 100644
--- a/tensorrt_llm/version.py
+++ b/tensorrt_llm/version.py
@@ -12,4 +12,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "0.9.0.dev2024032600"
+__version__ = "0.9.0.dev2024040200"