Labels
awq (for any issue / PR related to AWQ support), qwen (for any PR / issue related to Qwen support)
Description
Script:
import base64
from io import BytesIO
import torch
from datasets import load_dataset
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.utils import dispatch_for_generation
MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct"
OUTPUT_DIR = MODEL_ID.split("/")[-1] + "-AWQ-W4A16"
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
DATASET_ID = "lmms-lab/flickr30k"
NUM_CALIBRATION_SAMPLES = 256
DATASET_SPLIT = f"test[:{NUM_CALIBRATION_SAMPLES}]"
MAX_SEQUENCE_LENGTH = 1024
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42)
def preprocess_and_tokenize(example):
    # Encode the image as a base64 data URI for the chat template.
    buffered = BytesIO()
    example["image"].save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue())
    base64_image = f"data:image;base64,{encoded_image.decode('utf-8')}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": base64_image},
                {"type": "text", "text": "What does the image show?"},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        images=[example["image"]],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )
    return inputs
ds = ds.map(preprocess_and_tokenize, remove_columns=ds.column_names)
def data_collator(batch):
    # The sequential pipeline calibrates one sample at a time.
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}
recipe = AWQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=["re:.*lm_head", "re:.*visual.*", "re:.*mlp.gate$"],
    duo_scaling=False,
)
oneshot(
    model=model,
    processor=processor,
    recipe=recipe,
    dataset=ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    data_collator=data_collator,
    pipeline="sequential",
)
dispatch_for_generation(model)
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": "http://images.cocodataset.org/train2017/000000231895.jpg"},
        {"type": "text", "text": "Please describe the animal in this image\n"},
    ],
}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
    text=[prompt],
    images=["http://images.cocodataset.org/train2017/000000231895.jpg"],
    padding=False,
    max_length=MAX_SEQUENCE_LENGTH,
    truncation=True,
    return_tensors="pt",
).to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
processor.decode(output[0], skip_special_tokens=True)
model.save_pretrained(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)
root@308ee78b4020:# export TOKENIZERS_PARALLELISM=false
root@308ee78b4020:# python3 te.py
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:10<00:00, 1.23it/s]
2025-10-16T09:33:02.908921+0000 | reset | INFO - Compression lifecycle reset
2025-10-16T09:33:02.917289+0000 | _create_default_logger | INFO - Logging all LLM Compressor modifier-level logs to sparse_logs/16-10-2025_09.33.02.log
2025-10-16T09:33:02.917643+0000 | from_modifiers | INFO - Creating recipe from modifiers
2025-10-16T09:33:02.985165+0000 | on_initialize | INFO - No AWQModifier.mappings provided, inferring from model...
2025-10-16T09:33:02.985360+0000 | get_layer_mappings_from_architecture | INFO - Architecture Qwen3VLMoeForConditionalGeneration not found in mappings. Using default mappings: [AWQMapping(smooth_layer='re:.*input_layernorm$', balance_layers=['re:.*q_proj$', 're:.*k_proj$', 're:.*v_proj$']), AWQMapping(smooth_layer='re:.*v_proj$', balance_layers=['re:.*o_proj$']), AWQMapping(smooth_layer='re:.*post_attention_layernorm$', balance_layers=['re:.*gate_proj$', 're:.*up_proj$']), AWQMapping(smooth_layer='re:.*up_proj$', balance_layers=['re:.*down_proj$'])]
Resolving mapping 1/4 (0 skipped): : 48it [00:00, 4149.27it/s]
Resolving mapping 2/4 (47 skipped): : 48it [00:00, 8543.10it/s]
Resolving mapping 3/4 (0 skipped): : 48it [00:00, 5982.96it/s]
0it [00:00, ?it/s]
2025-10-16T09:33:03.014494+0000 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
2025-10-16T09:33:53.253637+0000 | trace_subgraphs | WARNING - Expected 75 subgraphs, but only traced 49. This is likely due to having wrapped code which calls sequential targets
Preparing cache: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:31<00:00, 8.22it/s]
(1/49): Calibrating: 0%| | 0/256 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/pipelines/sequential/helpers.py", line 73, in forward
outputs = forward_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<string>", line 19, in forward
File "Qwen3VLMoeModel_7739583652820_autowrapped", line 33, in wrapped_2
File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 1227, in get_image_features
image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 769, in forward
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 738, in fast_pos_embed_interpolate
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 170, in new_forward
args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/compressed_tensors/utils/offload.py", line 574, in keep_onload_pre_forward
ret = original_pre_forward(self, module, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 369, in pre_forward
return send_to_device(args, self.execution_device), send_to_device(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py", line 169, in send_to_device
return honor_type(
^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py", line 81, in honor_type
return type(obj)(generator)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py", line 170, in <genexpr>
tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/accelerate/utils/operations.py", line 153, in send_to_device
return tensor.to(device, non_blocking=non_blocking)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/root/te.py", line 58, in <module>
oneshot(
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/entrypoints/oneshot.py", line 330, in oneshot
one_shot()
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/entrypoints/oneshot.py", line 158, in __call__
self.apply_recipe_modifiers(
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/entrypoints/oneshot.py", line 201, in apply_recipe_modifiers
pipeline(
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/pipelines/sequential/pipeline.py", line 104, in __call__
subgraph.forward(model, **inputs)
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/pipelines/sequential/helpers.py", line 75, in forward
raise RuntimeError(
RuntimeError: Raised an exception during execution of the following code:
1
2 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_0")
3 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_1")
4 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_2")
5 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_3")
6 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_5")
7 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_4")
8 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_6")
9 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_7")
10 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_8")
11 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_9")
12 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_10")
13 torch.fx._symbolic_trace.wrap("transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_11")
14
15 def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor, pixel_values : torch.Tensor, image_grid_thw : torch.Tensor):
16 wrapped_0 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_0(input_ids, None); wrapped_0 = None
17 wrapped_1 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_1(input_ids, None)
18 getitem = wrapped_1[0]; wrapped_1 = None
19 wrapped_2 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_2(image_grid_thw, None, input_ids, getitem, pixel_values); getitem = pixel_values = None
20 getitem_1 = wrapped_2[0]
21 getitem_2 = wrapped_2[1]
22 getitem_3 = wrapped_2[2]; getitem_3 = None
23 getitem_4 = wrapped_2[3]
24 getitem_5 = wrapped_2[4]; wrapped_2 = None
25 wrapped_3 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_3(getitem_1, input_ids, getitem_5, None, None, None); getitem_1 = getitem_5 = None
26 getitem_6 = wrapped_3[0]
27 getitem_7 = wrapped_3[1]
28 getitem_8 = wrapped_3[2]
29 getitem_9 = wrapped_3[3]; getitem_9 = None
30 getitem_10 = wrapped_3[4]; wrapped_3 = None
31 wrapped_5 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_5(getitem_6, attention_mask, None, image_grid_thw, input_ids, getitem_8, None, None, None); getitem_6 = image_grid_thw = input_ids = None
32 wrapped_6 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_4(None, getitem_8); wrapped_6 = None
33 wrapped_7 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_6(None, getitem_8); getitem_8 = None
34 wrapped_4 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_7(getitem_2, getitem_7, None, getitem_4, getitem_10, None); getitem_2 = getitem_7 = getitem_4 = getitem_10 = None
35 getitem_20 = wrapped_5[0]; getitem_20 = None
36 getitem_21 = wrapped_5[1]; getitem_21 = None
37 getitem_22 = wrapped_5[2]; getitem_22 = None
38 getitem_23 = wrapped_5[3]; getitem_23 = None
39 getitem_24 = wrapped_5[4]
40 getitem_25 = wrapped_5[5]; getitem_25 = None
41 getitem_26 = wrapped_5[6]; getitem_26 = None
42 getitem_27 = wrapped_5[7]; getitem_27 = None
43 getitem_28 = wrapped_5[8]; wrapped_5 = getitem_28 = None
44 getitem_29 = wrapped_7[0]; wrapped_7 = None
45 getitem_11 = wrapped_4[0]
46 getitem_12 = wrapped_4[1]; getitem_12 = None
47 getitem_13 = wrapped_4[2]; getitem_13 = None
48 getitem_14 = wrapped_4[3]; getitem_14 = None
49 getitem_15 = wrapped_4[4]; getitem_15 = None
50 getitem_16 = wrapped_4[5]; getitem_16 = None
51 getitem_17 = wrapped_4[6]; getitem_17 = None
52 getitem_18 = wrapped_4[7]; getitem_18 = None
53 getitem_19 = wrapped_4[8]; wrapped_4 = None
54 wrapped_8 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_8(None, getitem_29, None)
55 getitem_30 = wrapped_8[0]
56 getitem_31 = wrapped_8[1]; wrapped_8 = getitem_31 = None
57 wrapped_9 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_9(getitem_30, getitem_29, getitem_24); getitem_24 = None
58 getitem_32 = wrapped_9[0]; wrapped_9 = None
59 wrapped_10 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_10(getitem_32); getitem_32 = None
60 getitem_33 = wrapped_10[0]
61 getitem_34 = wrapped_10[1]; wrapped_10 = None
62 model_language_model_rotary_emb = self.model.language_model.rotary_emb(getitem_29, getitem_33); getitem_33 = None
63 wrapped_11 = transformers_models_qwen3_vl_moe_modeling_qwen3_vl_moe_wrapped_11(attention_mask, getitem_30, getitem_29, None, getitem_34); attention_mask = None
64 model_language_model_layers_0 = getattr(self.model.language_model.layers, "0")(getitem_29, attention_mask = wrapped_11, position_ids = getitem_34, past_key_values = None, cache_position = getitem_30, position_embeddings = model_language_model_rotary_emb); getitem_29 = None
65 return {'getitem_11': getitem_11, 'getitem_19': getitem_19, 'getitem_30': getitem_30, 'getitem_34': getitem_34, 'model_language_model_rotary_emb': model_language_model_rotary_emb, 'wrapped_11': wrapped_11, 'model_language_model_layers_0': model_language_model_layers_0}
66
I don't know where to start anymore; I always end up at the same error.
I'm running on RunPod with an 80 GB A100.
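For what it's worth, the traceback fails when accelerate tries to move the vision-tower inputs to its execution device and hits a meta tensor, so my guess (only a guess) is that part of the model is still sitting on the meta device from the device_map="auto" dispatch when the sequential pipeline replays self.visual. Below is a minimal sketch of what I plan to try next, under that assumption; it is not part of the script above and not a verified fix. The reload without device_map just leaves device placement to the sequential pipeline.
# Sketch of next debugging steps (assumptions, not a confirmed fix).
# 1) Check whether any weights were left on the meta device after loading.
meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"]
print(f"{len(meta_params)} parameters still on the meta device")
# 2) Reload without device_map="auto", on the assumption that the offloaded
#    visual tower is what triggers "Cannot copy out of meta tensor; no data!".
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
    MODEL_ID,
    dtype="auto",  # `torch_dtype` is deprecated per the warning above
    trust_remote_code=True,
)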