diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index 8c104d9ea..2ffb738c4 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -61,7 +61,7 @@ jobs: if: needs.check-file-changes.outputs.any_changed == 'true' # Runner list at https://github.com/nv-gha-runners/enterprise-runner-configuration/blob/main/docs/runner-groups.md runs-on: linux-amd64-gpu-l4-latest-1 - timeout-minutes: 90 + timeout-minutes: 120 container: &gpu_container image: nvcr.io/nvidia/pytorch:25.06-py3 env: @@ -80,7 +80,7 @@ jobs: if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} # Runner list at https://github.com/nv-gha-runners/enterprise-runner-configuration/blob/main/docs/runner-groups.md runs-on: linux-amd64-gpu-h100-latest-1 - timeout-minutes: 90 + timeout-minutes: 120 container: *gpu_container steps: *gpu_steps gpu-pr-required-check: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d59ec07cd..49f599e6e 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,7 @@ Model Optimizer Changelog (Linux) - Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration. - Add support for MCore MoE PTQ/QAT/QAD. - Add support for multi-node PTQ and export with FSDP2 in ``examples/llm_ptq/multinode_ptq.py``. See `examples/llm_ptq/README.md `_ for more details. +- Add support for Nemotron Nano VL v1 & v2 models in FP8/NVFP4 PTQ workflow. **Documentation** diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index d6ae283a1..c313ccb1a 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -39,6 +39,91 @@ SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] +def run_nemotron_vl_preview( + full_model, tokenizer, input_ids, pyt_ckpt_path, stage_name, allow_fallback=False +): + """Run text-only and VL preview generation for Nemotron VL models. 
+ + Args: + full_model: The full VL model + tokenizer: The tokenizer + input_ids: Input tensor for generation + pyt_ckpt_path: Path to the model checkpoint + stage_name: Description of the stage (e.g., "before quantization", "after quantization") + allow_fallback: Whether to allow fallback to standard generate on failure + + Returns: + Generated text response or None if generation failed + """ + from vlm_utils import run_text_only_generation, run_vl_preview_generation + + print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...") + question = tokenizer.decode(input_ids[0], skip_special_tokens=True) + generation_config = { + "max_new_tokens": 100, + "do_sample": False, + "eos_token_id": tokenizer.eos_token_id, + } + + # Try text-only generation + text_response = run_text_only_generation( + full_model, tokenizer, question, generation_config, pyt_ckpt_path + ) + + if text_response is not None: + print(f"✅ Text-only generation successful: {text_response[:100]}...") + generated_ids = text_response + elif allow_fallback: + print("Text-only generation failed, falling back to standard generate...") + generated_ids = full_model.generate(input_ids, max_new_tokens=100) + else: + generated_ids = None + + # Run additional VL test with images + print(f"Running additional VL test with images ({stage_name})...") + run_vl_preview_generation(full_model, tokenizer, pyt_ckpt_path, stage_name) + + return generated_ids + + +def _is_multimodal_config(config): + """Check if a config indicates a multimodal model (config-only version of is_multimodal_model).""" + return ( + hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL) + or getattr(config, "model_type", "") == "phi4mm" # Phi-4 multimodal + or hasattr(config, "vision_lora") # Vision LoRA configurations + or hasattr(config, "audio_processor") # Audio processing capabilities + or ( + hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") + ) # Image embedding layers + ) + + +def is_nemotron_vl(model_or_config): + """Check if model or config indicates a Nemotron VL model. + + Args: + model_or_config: Either a model instance or a config object. + + Returns: + bool: True if it's a Nemotron VL model, False otherwise. + """ + # Try to get config from model, or use directly if it's a config + if hasattr(model_or_config, "config"): + config = model_or_config.config + from modelopt.torch.export.model_utils import is_multimodal_model + + if not is_multimodal_model(model_or_config): + return False + else: + config = model_or_config + if not _is_multimodal_config(config): + return False + + architectures = getattr(config, "architectures", []) + return any("nemotron" in arch.lower() for arch in architectures) + + def build_quant_cfg( qformat, kv_cache_qformat, @@ -185,7 +270,21 @@ def get_model( if device == "cpu": device_map = "cpu" + # Prepare config kwargs for loading config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} + + # Load config once and handle VL model detection + try: + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + if is_nemotron_vl(hf_config): + print( + "Detected Nemotron VL model from config. " + "Disabling automatic device mapping for compatibility." 
+ ) + device_map = None + except Exception as e: + print(f"Error: Could not load config from {ckpt_path}: {e}") + raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e if attn_implementation is not None: config_kwargs["attn_implementation"] = attn_implementation @@ -207,11 +306,6 @@ def get_model( ) model = hf_vila.llm else: - hf_config = AutoConfig.from_pretrained( - ckpt_path, - **config_kwargs, - ) - if use_seq_device_map: device_map = "sequential" # If we use sequential, set max_memory limit to ensure that the model does not occupy the full GPU @@ -282,6 +376,12 @@ def get_model( **model_kwargs, ) model.eval() + + # If device_map was disabled (None), manually move model to target device + if device_map is None and device != "cpu": + print(f"Moving model to {device} device...") + model = model.to(device) + if device == "cuda" and not is_model_on_gpu(model): print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM") diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index dcd3e0f66..0f3c231df 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -30,6 +30,8 @@ get_processor, get_tokenizer, is_enc_dec, + is_nemotron_vl, + run_nemotron_vl_preview, ) from transformers import ( AutoConfig, @@ -48,7 +50,7 @@ export_tensorrt_llm_checkpoint, get_model_type, ) -from modelopt.torch.export.model_utils import is_multimodal_model +from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import need_calibration from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import is_quantized @@ -283,6 +285,9 @@ def main(args): full_model = model + # Detect if this is a Nemotron VL model using architecture-based detection + is_nemotron_vl_model = is_nemotron_vl(full_model) + if model_type == "mllama": processor = get_processor( args.pyt_ckpt_path, @@ -312,15 +317,8 @@ def main(args): tokenizer.padding_side = "left" # We only quantize the language model for VLMs other than the type supported above. - if hasattr(model, "language_model"): - parent_model = model # llama4 case - if isinstance(type(model).__dict__.get("language_model"), property): - assert hasattr(model, "model") and hasattr(model.model, "language_model"), ( - "Expected language_model in model.model, but attribute not found. " - "This may indicate an unsupported model structure." 
- ) - parent_model = model.model # gemma3, qwen2.5 VL case - + language_model, parent_model = get_language_model_from_vl(model) + if language_model is not None: disabled_quant_cfg = { "quant_cfg": {"default": {"enable": False}}, "algorithm": "max", @@ -331,7 +329,7 @@ def main(args): if name != "language_model": mtq.quantize(child, disabled_quant_cfg, forward_loop=None) - model = model.language_model + model = language_model model_type = get_model_type(model) if model_type == "phi4mm": @@ -458,34 +456,65 @@ def main(args): KV_QUANT_CFG_CHOICES, ) + # For Nemotron VL models, disable quantization of vision components + if is_nemotron_vl_model: + print("Disabling quantization for vision components in Nemotron VL model") + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + # Also disable radio model components specifically + quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} + quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} + if not model_is_already_quantized or calibration_only: # Only run single sample for preview input_ids = next(iter(calib_dataloader))[ "input_features" if model_type == "whisper" else "input_ids" ][0:1] - try: - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) - except Exception as e: - print( - "Error during model generation. Please check if your transformers version is " - "compatible with the model." + + # Generate preview before quantization + if is_nemotron_vl_model and tokenizer is not None: + generated_ids_before_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + input_ids, + args.pyt_ckpt_path, + "before quantization", + allow_fallback=True, ) - print(f"Error details: {e}") - raise + else: + # Standard generation for non-Nemotron VL models + generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": print("Applying nvfp4 quantization (MoE only) for gpt-oss") # quantize the model model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only) + + # For VL models, update full_model to use the quantized language model + if is_nemotron_vl_model: + _, parent_model = get_language_model_from_vl(full_model) + if parent_model is not None: + print("Updating full_model with quantized language_model...") + parent_model.language_model = model + if args.verbose: mtq.print_quant_summary(model) # Run some samples torch.cuda.empty_cache() generated_ids_after_ptq = None - if model_type != "llama4": + if model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100) + elif is_nemotron_vl_model and tokenizer is not None: + generated_ids_after_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + input_ids, + args.pyt_ckpt_path, + "after quantization", + allow_fallback=False, + ) else: warnings.warn( "Llama4 Maverick generation after quantization has a bug. Skipping generation sample." 
@@ -518,15 +547,25 @@ def output_decode(generated_ids, input_shape): if generated_ids_after_ptq is not None: print("--------") - print(f"example test input: {input_decode(input_ids)}") - print("--------") - print( - f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}" - ) - print("--------") - print( - f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}" - ) + if is_nemotron_vl_model: + # For Nemotron VL models, generated_ids are text strings from model.chat() + print("Nemotron VL model text-only generation results:") + print(f"Text response before quantization: {generated_ids_before_ptq}") + print("--------") + print(f"Text response after quantization: {generated_ids_after_ptq}") + print("--------") + print("Note: Additional VL tests with images were run separately above") + else: + # For regular LLMs, generated_ids are token tensors that need decoding + print(f"example test input: {input_decode(input_ids)}") + print("--------") + print( + f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}" + ) + print("--------") + print( + f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}" + ) else: warnings.warn("Skipping quantization: model is already quantized.") @@ -548,9 +587,12 @@ def output_decode(generated_ids, input_shape): # Save original model config and the processor config to the export path for VLMs. print(f"Saving original model config to {export_path}") - AutoConfig.from_pretrained( - args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code - ).save_pretrained(export_path) + config_kwargs = {"trust_remote_code": args.trust_remote_code} + if args.attn_implementation is not None: + config_kwargs["attn_implementation"] = args.attn_implementation + AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained( + export_path + ) # Try to save processor config if available try: @@ -748,7 +790,7 @@ def output_decode(generated_ids, input_shape): parser.add_argument( "--attn_implementation", help=( - "Specify the attention implementation to use." + "Specify the attention implementation to use. " "This arg will be passed to the HF model loading if specified." ), default=None, diff --git a/examples/llm_ptq/vlm_utils.py b/examples/llm_ptq/vlm_utils.py new file mode 100644 index 000000000..6c9d921b8 --- /dev/null +++ b/examples/llm_ptq/vlm_utils.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for Vision-Language Model (VLM) inference and testing.""" + +import os + +from PIL import Image +from transformers import AutoImageProcessor, AutoProcessor + + +def run_vl_preview_generation(model, tokenizer, model_path, stage_name): + """Run preview generation for VL models using sample images. 
+ + Args: + model: The VL model + tokenizer: The tokenizer + model_path: Path to the model (for loading image processor) + stage_name: Description of the stage (e.g., "before quantization") + + Returns: + Generated response text for logging/comparison + """ + try: + print(f"Loading sample images for {stage_name} preview...") + + # Load sample images from the images directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + images_dir = os.path.join(script_dir, "images") + + # Check if images directory exists + if not os.path.exists(images_dir): + print(f"❌ Warning: Images directory not found at {images_dir}") + print(" VL preview generation requires sample images to test vision capabilities.") + print(" Skipping VL preview generation.") + return None + + # Use single image for VL preview to avoid shape mismatch issues + image_files = ["example1a.jpeg", "example1b.jpeg", "example.jpg", "test.jpg", "sample.png"] + image = None + missing_files = [] + for img_file in image_files: + img_path = os.path.join(images_dir, img_file) + if os.path.exists(img_path): + try: + image = Image.open(img_path) + print(f" ✅ Successfully loaded: {img_file}") + break # Use the first available image + except Exception as e: + print(f" ⚠️ Warning: Could not open {img_file}: {e}") + missing_files.append(f"{img_file} (corrupted)") + else: + missing_files.append(img_file) + + if image is None: + print(f"❌ Warning: No valid sample images found in {images_dir}") + print(f" Searched for: {', '.join(image_files)}") + if missing_files: + print(f" Missing/invalid files: {', '.join(missing_files)}") + print(" VL preview generation requires sample images to test vision capabilities.") + print(" Skipping VL preview generation.") + return None + + # Generate response + question = "Describe this image briefly." 
# Updated for single image + generation_config = { + "max_new_tokens": 50, + "do_sample": False, + "eos_token_id": tokenizer.eos_token_id, + } + + print(f"Generating VL response ({stage_name})...") + + # Try to detect the VL model has chat method or generate method + if hasattr(model, "chat"): + image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True) + + image_features = image_processor([image]) # Pass as list with single image + + # Move image features to the same device as the model + model_device = model.device + for key, value in image_features.items(): + if hasattr(value, "to"): # Check if it's a tensor + image_features[key] = value.to(model_device) + print(f" Moved {key} to {model_device}") + + response = model.chat( + tokenizer=tokenizer, + question=question, + generation_config=generation_config, + **image_features, + ) + else: + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + messages = [ + {"role": "system", "content": "/no_think"}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "", + }, + { + "type": "text", + "text": question, + }, + ], + }, + ] + + # Apply chat template + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + # Process inputs using the processor with single image + inputs = processor( + text=[prompt], + images=[image], # Pass single image as list + return_tensors="pt", + ) + + # Move inputs to the same device as the model + model_device = model.device + inputs = inputs.to(model_device) + print(f" Moved inputs to {model_device}") + + # Generate response using model.generate + generated_ids = model.generate( + pixel_values=inputs.pixel_values, + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + **generation_config, + ) + + # Decode the response (trim input tokens like in the working example) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + response = output_text[0] + + print(f"✅ VL generation {stage_name} successful!") + print(f"Question: {question}") + print(f"Response: {response}") + + # Return the response for comparison/logging + return response + + except Exception as e: + print(f"❌ VL preview generation {stage_name} failed: {e}") + print("This may indicate issues with the quantized model") + return None + + +def run_text_only_generation(model, tokenizer, question, generation_config, model_path): + """Run text-only generation for VL models, supporting both chat and generate methods. 
+ + Args: + model: The VL model + tokenizer: The tokenizer + question: The text question to ask + generation_config: Generation configuration + model_path: Path to the model (for loading processor if needed) + + Returns: + Generated response text or None if failed + """ + try: + if hasattr(model, "chat"): + # Use model.chat with None for images (text-only mode) + response = model.chat(tokenizer, None, question, generation_config, history=None) + return response + else: + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + # Create text-only messages + messages = [ + {"role": "system", "content": "/no_think"}, + { + "role": "user", + "content": [ + { + "type": "text", + "text": question, + }, + ], + }, + ] + + # Apply chat template + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + # Process text-only inputs + inputs = processor( + text=[prompt], + images=None, # No images for text-only + return_tensors="pt", + ) + + # Move inputs to the same device as the model + model_device = model.device + inputs = inputs.to(model_device) + + # Generate response using model.generate + generated_ids = model.generate( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + **generation_config, + ) + + # Decode the response (trim input tokens like in the working example) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + return output_text[0] + + except Exception as e: + print(f"Text-only generation failed: {e}") + return None diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 5ce630168..44f3c185c 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -60,7 +60,7 @@ {MODEL_NAME_TO_TYPE=} """ -__all__ = ["get_model_type", "is_multimodal_model"] +__all__ = ["get_language_model_from_vl", "get_model_type", "is_multimodal_model"] def get_model_type(model): @@ -109,3 +109,49 @@ def is_multimodal_model(model): hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") ) # Image embedding layers ) + + +def get_language_model_from_vl(model): + """Extract the language model component from a Vision-Language Model (VLM). + + This function handles the common patterns for accessing the language model component + in various VLM architectures. It checks multiple possible locations where the + language model might be stored. + + Args: + model: The VLM model instance to extract the language model from + + Returns: + tuple: (language_model, parent_model) where: + - language_model: The extracted language model component, or None if not found + - parent_model: The parent model containing the language_model attribute + + Examples: + >>> # For LLaVA-style models + >>> lang_model, parent = get_language_model_from_vl(vlm_model) + >>> if lang_model is not None: + ... # Work with the language model component + ... quantized_lang_model = quantize(lang_model) + ... # Update the parent model + ... 
parent.language_model = quantized_lang_model + """ + # Pattern 1: Direct language_model attribute (e.g., LLaVA, some Nemotron models) + if hasattr(model, "language_model"): + # Check if it's a property that might need special handling + if isinstance(type(model).__dict__.get("language_model"), property): + # Some models have language_model as a property that points to model.model.language_model + if hasattr(model, "model") and hasattr(model.model, "language_model"): + return model.model.language_model, model.model + else: + # Property exists but no nested structure found + return model.language_model, model + else: + # Direct attribute access + return model.language_model, model + + # Pattern 2: Nested in model.model.language_model (e.g., some Gemma3, Qwen2.5-VL models) + elif hasattr(model, "model") and hasattr(model.model, "language_model"): + return model.model.language_model, model.model + + # Pattern 3: No language_model found + return None, None diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 954707e8d..36520f9fc 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -20,6 +20,7 @@ import re import tempfile import warnings +from builtins import ValueError from collections import defaultdict from pathlib import Path from typing import Any @@ -55,6 +56,7 @@ QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_NVFP4_FP8, ) +from .model_utils import get_language_model_from_vl, is_multimodal_model from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only from .quant_utils import ( fuse_prequant_layernorm, @@ -134,6 +136,10 @@ def _output_hook(module, input, output): with torch.no_grad(): fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device) decoder_fake_input = fake_input + + # Check if this is a VL model that needs special input handling + is_vl_model = is_multimodal_model(model) + if model_type.startswith("whisper"): # For Whisper models, we need to pass a fake input with the specific sequence length from transformers import AutoFeatureExtractor @@ -149,6 +155,23 @@ def _output_hook(module, input, output): if getattr(model.config, "is_encoder_decoder", False): # For encoder-decoder models, we need to pass both the encoder and decoder input ids model(fake_input, decoder_input_ids=decoder_fake_input) + elif is_vl_model and "nemotron" in model_type: + # For Nemotron VL models, try to run optimization on just the language model part + language_model, _ = get_language_model_from_vl(model) + + if language_model is not None: + # Run optimization on just the language model with the same input format as regular LLMs + # Use the same fake_input tensor that regular LLMs use + print( + f"Running optimization on language model with fake_input shape: {fake_input.shape}" + ) + language_model(fake_input) + else: + raise ValueError( + f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " + "This is required for requantization/resmoothing optimization. " + "Please ensure the model architecture is supported or file an issue." + ) else: model(fake_input)
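
The sketches below are illustrative usage notes for the pieces this diff adds; they are not part of the patch. First, a minimal sketch of how the new `is_nemotron_vl` helper in `examples/llm_ptq/example_utils.py` gates device mapping: Nemotron VL checkpoints are loaded without `device_map="auto"` and moved to the target device afterwards, mirroring the change in `get_model()`. The checkpoint name is a placeholder, and `AutoModelForCausalLM` is used only for illustration (`get_model()` picks the exact auto class).

```python
from example_utils import is_nemotron_vl
from transformers import AutoConfig, AutoModelForCausalLM

ckpt_path = "nvidia/Nemotron-Nano-VL"  # placeholder checkpoint id
device = "cuda"

# Config-only detection, as done at the top of get_model() in this diff.
hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=True)
device_map = None if is_nemotron_vl(hf_config) else "auto"

model = AutoModelForCausalLM.from_pretrained(
    ckpt_path,
    torch_dtype="auto",
    device_map=device_map,
    trust_remote_code=True,
)
if device_map is None and device != "cpu":
    # Mirrors the manual move added after model loading in get_model().
    model = model.to(device)
model.eval()
```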
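Next, a sketch of the vision-component exclusion applied in `hf_ptq.py` for Nemotron VL models: wildcard patterns in `quant_cfg` disable quantizers whose module names match the vision/image/radio/visual towers, so calibration only touches the language model. `FP8_DEFAULT_CFG` is used here as an illustrative base config; `model` and `forward_loop` are assumed to come from the surrounding script.

```python
import copy

import modelopt.torch.quantization as mtq

quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
for pattern in ("*vision*", "*image*", "*radio*", "*visual*"):
    # Same wildcard keys added in hf_ptq.py for Nemotron VL models.
    quant_cfg["quant_cfg"][pattern] = {"enable": False}

# model = mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
```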
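A sketch of the language-model-only quantization path now shared by `hf_ptq.py`: `get_language_model_from_vl()` locates the language model across the supported VLM layouts (`model.language_model` or `model.model.language_model`) and also returns its parent, so the quantized module can be written back onto the full VL model. `full_model`, `quant_cfg`, and `forward_loop` are assumed from the surrounding script.

```python
import modelopt.torch.quantization as mtq
from modelopt.torch.export.model_utils import get_language_model_from_vl

language_model, parent_model = get_language_model_from_vl(full_model)
if language_model is not None:
    language_model = mtq.quantize(language_model, quant_cfg, forward_loop=forward_loop)
    # Reattach so generate()/export on the full VL model sees the quantized
    # weights, matching the "Updating full_model with quantized language_model"
    # step in hf_ptq.py.
    parent_model.language_model = language_model
```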
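A sketch of the before/after preview comparison `hf_ptq.py` now runs for Nemotron VL models. `run_nemotron_vl_preview()` normally returns decoded text rather than token ids (token ids only appear on the `generate()` fallback), which is why the results are printed directly instead of going through `output_decode()`. `full_model`, `tokenizer`, `input_ids`, and `args` are assumed from the surrounding script.

```python
from example_utils import run_nemotron_vl_preview

text_before = run_nemotron_vl_preview(
    full_model, tokenizer, input_ids, args.pyt_ckpt_path,
    "before quantization", allow_fallback=True,
)
# ... quantize the language model and reattach it here ...
text_after = run_nemotron_vl_preview(
    full_model, tokenizer, input_ids, args.pyt_ckpt_path,
    "after quantization", allow_fallback=False,
)
print(f"Text response before quantization: {text_before}")
print(f"Text response after quantization: {text_after}")
```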
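The two `vlm_utils.py` helpers can also be exercised on their own: a text-only smoke test that prefers `model.chat()` and otherwise falls back to a processor plus `generate()` path, followed by an image-grounded preview that reads sample images from `examples/llm_ptq/images/` and returns `None` (skipping gracefully) if none are found. `full_model`, `tokenizer`, and `ckpt_path` are assumed to exist.

```python
from vlm_utils import run_text_only_generation, run_vl_preview_generation

generation_config = {
    "max_new_tokens": 50,
    "do_sample": False,
    "eos_token_id": tokenizer.eos_token_id,
}

text = run_text_only_generation(
    full_model, tokenizer, "Describe what a vision-language model does.",
    generation_config, ckpt_path,
)
print("text-only preview:", text)

vl_response = run_vl_preview_generation(full_model, tokenizer, ckpt_path, "smoke test")
print("VL preview:", vl_response)
```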
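Finally, a sketch of the export-time warm-up added to `unified_export_hf.py`: the requantization/resmoothing pass only needs a forward through the language model, so the same `[1, 2]` fake input used for plain LLMs is routed there, and export fails loudly if the language model cannot be extracted. `model` is assumed to be a quantized Nemotron VL model already on the target device.

```python
import torch

from modelopt.torch.export.model_utils import get_language_model_from_vl

language_model, _ = get_language_model_from_vl(model)
if language_model is None:
    raise ValueError("Cannot extract language_model from Nemotron VL model")

with torch.no_grad():
    fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device)
    language_model(fake_input)
```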