Commit 8c3f4b5

add qwen3-vl-30b-a3b-Instruct-example

Signed-off-by: JartX <[email protected]>
1 parent 5061adf commit 8c3f4b5

1 file changed: +147 −0 lines changed

@@ -0,0 +1,147 @@
import torch
from datasets import load_dataset
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.utils import dispatch_for_generation

# NOTE: Requires a minimum of transformers 4.57.0.

MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct"

# Load model.
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map=None,
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
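# Prepare the MoE blocks for calibration: an llm-compressor utility that, as I
# understand it, rewrites the expert modules so their linear layers are visible
# to the calibration hooks.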
model = replace_modules_for_calibration(model)

DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 8192

ds = load_dataset(
    DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]"
)
ds = ds.shuffle(seed=42)


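# Text-only calibration: the "LLM" config carries chat-style `messages`
# records, which are rendered through the model's chat template below.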
def preprocess_function(example):
    messages = []
    for message in example["messages"]:
        messages.append(
            {
                "role": message["role"],
                "content": [{"type": "text", "text": message["content"]}],
            }
        )

    return processor.apply_chat_template(
        messages,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        tokenize=True,
        add_special_tokens=False,
        return_dict=True,
        add_generation_prompt=False,
    )


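# With tokenize=True and return_dict=True, apply_chat_template returns a dict
# of input_ids/attention_mask; ds.map keeps those as the only columns.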
ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)


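# Single-sample collator: oneshot calibrates one example at a time, so unwrap
# the batch; pixel_values (if present) are cast to bfloat16 to match the model.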
def data_collator(batch):
    assert len(batch) == 1
    return {
        key: (
            torch.tensor(value)
            if key != "pixel_values"
            else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
        )
        for key, value in batch[0].items()
    }


# Configure AWQ quantization with smoothing and balancing mappings.
recipe = AWQModifier(
    ignore=[
        "re:.*embed_tokens",
        "re:.*input_layernorm$",
        "re:.*mlp[.]gate$",
        "re:.*post_attention_layernorm$",
        "re:.*norm$",
        "re:model[.]visual.*",
        "re:visual.*",
        "lm_head",
    ],
    mappings=[
        {
            "smooth_layer": "re:.*input_layernorm$",
            "balance_layers": ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
        },
        {
            "smooth_layer": "re:.*v_proj$",
            "balance_layers": ["re:.*o_proj$"],
        },
        {
            "smooth_layer": "re:.*post_attention_layernorm$",
            "balance_layers": ["re:.*gate_proj$", "re:.*up_proj$"],
        },
        {
            "smooth_layer": "re:.*up_proj$",
            "balance_layers": ["re:.*down_proj$"],
        },
    ],
    duo_scaling=True,
    config_groups={
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "num_bits": 8,
                "type": "int",
                "symmetric": True,
                "group_size": 32,
                "strategy": "group",
                "block_structure": None,
                "dynamic": False,
                "actorder": None,
                "observer": "mse",
                "observer_kwargs": {},
            },
            "input_activations": None,
            "output_activations": None,
            "format": None,
        }
    },
)

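# The recipe above yields W8A16: int8 symmetric weights with group size 32 and
# an MSE observer for scale search, while activations stay in bf16. The vision
# tower (visual.*), MoE router gates (mlp.gate), norms, embeddings, and lm_head
# are left unquantized; each smooth/balance mapping pairs a scale-absorbing
# layer with the linear layers AWQ rescales against it.
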
# Apply AWQ quantization.
oneshot(
    model=model,
    processor=processor,
    recipe=recipe,
    dataset=ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    data_collator=data_collator,
)

print("========== SAMPLE GENERATION ==============")
137+
dispatch_for_generation(model)
138+
input_ids = processor(text="Hello my name is",
139+
return_tensors="pt").input_ids.to("cuda")
140+
output = model.generate(input_ids, max_new_tokens=20)
141+
print(processor.decode(output[0]))
142+
print("==========================================")
143+
144+
# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-AWQ-W8A16-mse-seq"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
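
# A minimal follow-up sketch for serving the result (assumption: a vLLM build
# with compressed-tensors support is installed; sampling settings below are
# illustrative, not part of this example):
#
#   from vllm import LLM, SamplingParams
#
#   llm = LLM(model=SAVE_DIR)
#   sampling = SamplingParams(temperature=0.0, max_tokens=32)
#   print(llm.generate(["Hello my name is"], sampling)[0].outputs[0].text)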
