- 
                Notifications
    You must be signed in to change notification settings 
- Fork 271
[AWQ][Qwen3 VL] Add qwen3-vl-30b-a3b-Instruct-example #1947
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
            kylesayrs
  merged 9 commits into
  vllm-project:main
from
JartX:feature/qwen3-vl-30b-a3b-Instruct-example
  
      
      
   
  Oct 22, 2025 
      
    
      
        
          +120
        
        
          −0
        
        
          
        
      
    
  
  
     Merged
                    Changes from all commits
      Commits
    
    
            Show all changes
          
          
            9 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      8c3f4b5
              
                add qwen3-vl-30b-a3b-Instruct-example
              
              
                JartX a11434d
              
                format
              
              
                JartX f68090a
              
                applied ruff format
              
              
                JartX 5387449
              
                Remove layer balancing mappings from config
              
              
                JartX 79d63ea
              
                Update num_bits for AWQ quantization
              
              
                JartX 59ec83c
              
                Apply suggestion from @brian-dellabetta
              
              
                JartX e3b373e
              
                Apply suggestion from @brian-dellabetta
              
              
                JartX 9fb1145
              
                format
              
              
                JartX 7ed1c42
              
                Merge branch 'main' into feature/qwen3-vl-30b-a3b-Instruct-example
              
              
                brian-dellabetta File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
          Some comments aren't visible on the classic Files Changed page.
        
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| import torch | ||
| from datasets import load_dataset | ||
| from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration | ||
|  | ||
| from llmcompressor import oneshot | ||
| from llmcompressor.modeling import replace_modules_for_calibration | ||
| from llmcompressor.modifiers.awq import AWQModifier | ||
| from llmcompressor.utils import dispatch_for_generation | ||
|  | ||
| # NOTE: Requires a minimum of transformers 4.57.0 | ||
|  | ||
| MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" | ||
|  | ||
| # Load model. | ||
| model = Qwen3VLMoeForConditionalGeneration.from_pretrained( | ||
| MODEL_ID, torch_dtype=torch.bfloat16, device_map=None, trust_remote_code=True | ||
| ) | ||
| processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) | ||
| model = replace_modules_for_calibration(model) | ||
|  | ||
| DATASET_ID = "neuralmagic/calibration" | ||
| NUM_CALIBRATION_SAMPLES = 256 | ||
| MAX_SEQUENCE_LENGTH = 8192 | ||
|  | ||
| ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]") | ||
| ds = ds.shuffle(seed=42) | ||
|  | ||
|  | ||
| def preprocess_function(example): | ||
| messages = [] | ||
| for message in example["messages"]: | ||
| messages.append( | ||
| { | ||
| "role": message["role"], | ||
| "content": [{"type": "text", "text": message["content"]}], | ||
| } | ||
| ) | ||
|  | ||
| return processor.apply_chat_template( | ||
| messages, | ||
| return_tensors="pt", | ||
| padding=False, | ||
| truncation=True, | ||
| max_length=MAX_SEQUENCE_LENGTH, | ||
| tokenize=True, | ||
| add_special_tokens=False, | ||
| return_dict=True, | ||
| add_generation_prompt=False, | ||
| ) | ||
|  | ||
|  | ||
| ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names) | ||
|  | ||
|  | ||
| def data_collator(batch): | ||
| assert len(batch) == 1 | ||
| return { | ||
| key: ( | ||
| torch.tensor(value) | ||
|         
                  JartX marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| if key != "pixel_values" | ||
| else torch.tensor(value, dtype=torch.bfloat16).squeeze(0) | ||
| ) | ||
| for key, value in batch[0].items() | ||
| } | ||
|  | ||
|  | ||
| # Configure AWQ quantization with smoothing and balancing | ||
| # NOTE: This recipe uses W4A16 quantization with group_size=32 | ||
| # rather than the default preset with group_size=128 | ||
| recipe = AWQModifier( | ||
| ignore=[ | ||
| "re:.*embed_tokens", | ||
| "re:.*input_layernorm$", | ||
| "re:.*mlp[.]gate$", | ||
| "re:.*post_attention_layernorm$", | ||
| "re:.*norm$", | ||
| "re:model[.]visual.*", | ||
| "re:visual.*", | ||
| "lm_head", | ||
| ], | ||
| duo_scaling=True, | ||
| config_groups={ | ||
|         
                  brian-dellabetta marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| "group_0": { | ||
| "targets": ["Linear"], | ||
| "weights": { | ||
| "num_bits": 4, | ||
| "type": "int", | ||
| "symmetric": True, | ||
| "group_size": 32, | ||
| "strategy": "group", | ||
| "dynamic": False, | ||
| "actorder": None, | ||
| "observer": "mse", | ||
| }, | ||
| } | ||
| }, | ||
| ) | ||
|  | ||
| # Apply AWQ quantization. | ||
| oneshot( | ||
| model=model, | ||
| processor=processor, | ||
| recipe=recipe, | ||
| dataset=ds, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| data_collator=data_collator, | ||
| ) | ||
|  | ||
| print("========== SAMPLE GENERATION ==============") | ||
| dispatch_for_generation(model) | ||
| input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") | ||
| output = model.generate(input_ids, max_new_tokens=20) | ||
| print(processor.decode(output[0])) | ||
| print("==========================================") | ||
|  | ||
| # Save to disk in compressed-tensors format. | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-AWQ-W8A16-mse-seq" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| processor.save_pretrained(SAVE_DIR) | ||
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.