
Commit 4ce1bdf

[Quantization] Attention / KV Cache Refactor (#1651)
## Purpose ##

* Support fully-expressive attention and KV cache quantization
* Support running KV cache quantization evals with HF transformers
* Resolves #1949
* Resolves #1928

For example:

```python3
recipe = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="tensor"
            ),
        )
    }
)
```

Example serialized `quantization_config`:

```json
{
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "format": null,
        "input_activations": {
          "dynamic": false,
          "num_bits": 8,
          "observer": "minmax",
          "strategy": "tensor",
          "symmetric": true,
          "type": "float"
        },
        "output_activations": null,
        "targets": [
          "LlamaAttention"
        ],
        "weights": null
      }
    },
    "format": "dense",
    "ignore": [],
    "kv_cache_scheme": {
      "dynamic": false,
      "group_size": null,
      "num_bits": 8,
      "observer": "minmax",
      "strategy": "tensor",
      "symmetric": true,
      "type": "float"
    },
    "quant_method": "compressed-tensors",
    "quantization_status": "frozen"
  }
}
```

## Prerequisites ##

* Must be merged at the same time as vllm-project/compressed-tensors#436

## Changes ##

* Replace hooks
  * Remove `calibrate_kv_cache_input_hook`, `calibrate_kv_cache_output_hook`, and `initialize_quantized_kv_cache`
  * Add `calibrate_query_hook`, `calibrate_key_hook`, and `calibrate_value_hook`
  * `QuantizationMixin` now initializes "q", "k", and "v" observers ([depending on the attached submodules](https://github.com/vllm-project/llm-compressor/pull/1651/files#diff-33303ae48e185b2fbb14dc45c2052805837deb5723248367b9579321c4c4e974R263-R270)) and adds the appropriate hooks
* Miscellaneous
  * Fix a minor shape bug in `_flatten_attention`
  * Add support for the "attn_head" strategy in `_flatten_attention` (a rough per-head sketch follows this description)
* Tests
  * Removed the old QuantizationKVCache tests (these classes are now tested [here](https://github.com/neuralmagic/compressed-tensors/pull/436/files#diff-6e33ff48047dc4f7c9d969293f87e32e4d5ec3f3e8b741ea757780c8c0aab775))
  * Updated scale names to avoid using the enum
  * Avoid unnecessary tokenization to reduce runtime

## Testing ##

* KV cache regression tests pass
* Able to quantize attention with scripts (will add to examples once loadable in vLLM)
  * kylesayrs/Llama-3.2-1B-Instruct-attention-fp8-head
  * kylesayrs/Llama-3.2-1B-Instruct-attention-nvfp4-head
* Nightly passes (in progress)

## Evaluation ##

<details><summary>eval.py</summary>

```python
import sys

import lm_eval

model_id = sys.argv[1]
print(model_id)

results = lm_eval.simple_evaluate(
    # hf serialized model
    model="hf",
    model_args={
        "pretrained": model_id,
        "add_bos_token": False,
        "dtype": "auto",
        "device_map": "cuda",
        # "max_length": 128000,
    },
    device="cuda",
    # tasks=["gsm8k_platinum", "mmlu_llama", "longbench2_single"],
    tasks=["gsm8k_platinum"],
    batch_size=64,
    apply_chat_template=True,
    fewshot_as_multiturn=True,
)
print(model_id)
print(lm_eval.utils.make_table(results))
```

</details>

<details><summary>compress.py</summary>

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# Select model and load it.
# model_id = "Qwen/Qwen2.5-14B-Instruct-1M"
model_id = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
args = QuantizationArgs(
    num_bits=8,
    type="float",
    strategy="attn_head",
    symmetric=True,
    observer="static_minmax",
)
recipe = QuantizationModifier(
    # config_groups={
    #     "attention": QuantizationScheme(
    #         # targets=["Qwen2Attention"],
    #         targets=["LlamaAttention"],
    #         input_activations=args,
    #     )
    # },
    kv_cache_scheme=args,
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + f"-KV-FP8-{args.strategy}-{args.observer}"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

</details>

Model | GSM8K
-- | --
nm-testing/Llama-3.1-8B-Instruct | 0.8337
nm-testing/Llama-3.1-8B-Instruct-KV-FP8-Tensor | 0.8271
nm-testing/Llama-3.1-8B-Instruct-KV-FP8-Head | 0.8354
nm-testing/Llama-3.1-8B-Instruct-QKV-FP8-Tensor | 0.8321
nm-testing/Llama-3.1-8B-Instruct-QKV-FP8-Head | 0.8238

---------

Signed-off-by: Kyle Sayers <[email protected]>
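
As referenced under Changes, the "attn_head" strategy produces one quantization scale per attention head. Below is a rough, illustrative sketch of that idea; the (batch, num_heads, seq_len, head_dim) layout, the helper name, and the FP8 scale arithmetic are assumptions for illustration, not this commit's `_flatten_attention` implementation.

```python
# Rough illustration of an "attn_head" (per-head) strategy: reduce an attention
# state of shape (batch, num_heads, seq_len, head_dim) to one min/max pair, and
# hence one scale, per head. Names and shapes here are illustrative assumptions.
import torch


def per_head_minmax(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # x: (batch, num_heads, seq_len, head_dim)
    num_heads = x.shape[1]
    # Flatten everything except the head dimension so each head gets its own range.
    flat = x.transpose(0, 1).reshape(num_heads, -1)
    return flat.amin(dim=1), flat.amax(dim=1)


# Example: FP8 per-head scales for a fake key tensor.
key_states = torch.randn(2, 8, 16, 64)
k_min, k_max = per_head_minmax(key_states)
scales = torch.maximum(k_max.abs(), k_min.abs()) / torch.finfo(torch.float8_e4m3fn).max
print(scales.shape)  # torch.Size([8]) -- one scale per attention head
```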
1 parent 6db28bc commit 4ce1bdf

File tree

10 files changed: +195 −557 lines changed


experimental/llama3_attention.py

Lines changed: 87 additions & 0 deletions
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
recipe = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="attn_head"
            ),
        )
    }
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-attention-fp8-head"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
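
The script above saves a compressed checkpoint to `SAVE_DIR`, and the Purpose section notes that such checkpoints are meant to be evaluable with HF transformers. A minimal reload-and-generate sketch follows, assuming a transformers install with compressed-tensors support; the directory name mirrors what the script writes.

```python
# Minimal sketch (assumption): reload the compressed checkpoint written by the
# script above and run a short generation as a sanity check.
from transformers import AutoModelForCausalLM, AutoTokenizer

SAVE_DIR = "Meta-Llama-3-8B-Instruct-attention-fp8-head"  # produced by the script above

model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR)

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0]))
```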
src/llmcompressor/modifiers/quantization/__init__.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,5 +1,4 @@
 # ruff: noqa

-from .cache import *
 from .gptq import *
 from .quantization import *
```

src/llmcompressor/modifiers/quantization/cache.py

Lines changed: 0 additions & 218 deletions
This file was deleted.
