diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 3ac167db2..1493c0aa8 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -136,6 +136,17 @@ def get_model(
     if device == "cpu":
         device_map = "cpu"
 
+    # Special handling for vision-language models that may have device mapping issues.
+    # Check whether this is a VL model by looking at the model path.
+    is_vl_model = any(
+        vl_keyword in ckpt_path.lower() for vl_keyword in ["vl", "vision", "nemotron-nano-vl"]
+    )
+    if is_vl_model:
+        print(
+            "Detected vision-language model. Disabling automatic device mapping to avoid device_map errors."
+        )
+        device_map = None
+
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
     if attn_implementation is not None:
         config_kwargs["attn_implementation"] = attn_implementation
@@ -235,6 +246,12 @@ def get_model(
         **model_kwargs,
     )
     model.eval()
+
+    # If device_map was disabled (None), manually move the model to the target device
+    if device_map is None and device != "cpu":
+        print(f"Moving model to {device} device...")
+        model = model.to(device)
+
     if device == "cuda" and not is_model_on_gpu(model):
         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
 
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 81f4b6392..4a9f45c56 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -26,6 +26,7 @@
 from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
 from transformers import (
     AutoConfig,
+    AutoImageProcessor,
     AutoModelForCausalLM,
     AutoProcessor,
     PreTrainedTokenizer,
@@ -83,6 +84,86 @@
 mto.enable_huggingface_checkpointing()
 
 
+def _run_vl_preview_generation(model, tokenizer, model_path, stage_name):
+    """Run preview generation for VL models using sample images.
+
+    Args:
+        model: The VL model
+        tokenizer: The tokenizer
+        model_path: Path to the model (for loading the image processor)
+        stage_name: Description of the stage (e.g., "before quantization")
+
+    Returns:
+        Generated response text for logging/comparison, or None if generation fails.
+    """
+    import os
+
+    from PIL import Image
+    from transformers import AutoImageProcessor
+
+    try:
+        print(f"Loading sample images for {stage_name} preview...")
+
+        # Load image processor
+        image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
+
+        # Load sample images from the images directory
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+        images_dir = os.path.join(script_dir, "images")
+
+        image_files = ["example1a.jpeg", "example1b.jpeg"]
+        images = []
+        for img_file in image_files:
+            img_path = os.path.join(images_dir, img_file)
+            if os.path.exists(img_path):
+                images.append(Image.open(img_path))
+                print(f" Loaded: {img_file}")
+            else:
+                print(f" Warning: {img_file} not found")
+
+        if not images:
+            print("No sample images found - skipping VL preview generation")
+            return None
+
+        # Process images
+        image_features = image_processor(images)
+
+        # Move image features to the same device as the model
+        model_device = model.device
+        for key, value in image_features.items():
+            if hasattr(value, "to"):  # Check if it's a tensor
+                image_features[key] = value.to(model_device)
+                print(f" Moved {key} to {model_device}")
+
+        # Generate response
+        question = "Describe these images briefly."
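+        # Use greedy decoding (do_sample=False) so the before- and after-quantization
+        # previews can be compared directly.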
+        generation_config = {
+            "max_new_tokens": 50,
+            "do_sample": False,
+            "eos_token_id": tokenizer.eos_token_id,
+        }
+
+        print(f"Generating VL response ({stage_name})...")
+        response = model.chat(
+            tokenizer=tokenizer,
+            question=question,
+            generation_config=generation_config,
+            **image_features,
+        )
+
+        print(f"✅ VL generation {stage_name} successful!")
+        print(f"Question: {question}")
+        print(f"Response: {response}")
+
+        # Return the response for comparison/logging
+        return response
+
+    except Exception as e:
+        print(f"❌ VL preview generation {stage_name} failed: {e}")
+        print("This may indicate issues with the quantized model")
+        return None
+
+
 def auto_quantize(
     model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1
 ):
@@ -281,6 +362,16 @@
     model_type = get_model_type(model)
 
+    # Special handling for multimodal models (e.g., Nemotron VL) that aren't detected by the
+    # standard model type detection. For HF export, we want to keep vision unquantized, so we
+    # treat the model as a regular language model and only quantize the language components.
+    if model_type != "mllama" and is_multimodal_model(model):
+        print(
+            f"Detected multimodal model: {type(model).__name__}. "
+            "For HF export, will quantize language components only, keeping vision unquantized."
+        )
+        # Keep as regular model type to use text-only calibration
+
     device = model.device
     if hasattr(model, "model"):
         device = model.model.device
@@ -487,34 +578,126 @@
                 "Please set the default input_mode to InputMode.LANGUAGE before quantizing."
             )
 
+        # For Nemotron VL models, disable quantization of vision components
+        is_nemotron_vl = (
+            "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+        )
+        if is_nemotron_vl:
+            print("Disabling quantization for vision components in Nemotron VL model")
+            quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+            # Also disable radio model components specifically
+            quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
+
     if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview
         input_ids = next(iter(calib_dataloader))[
             "input_features" if model_type == "whisper" else "input_ids"
         ][0:1]
-        try:
-            generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
-        except Exception as e:
-            print(
-                "Error during model generation. Please check if your transformers version is "
-                "compatible with the model."
+
+        # For Nemotron VL models, try text-only generation first, then VL generation as additional test
+        is_nemotron_vl = (
+            "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+        )
+        if is_nemotron_vl:
+            print("Running text-only preview generation for Nemotron VL model...")
+            try:
+                # Try text-only generation using model.chat with None for images
+                if tokenizer is None:
+                    raise ValueError("Tokenizer is required for Nemotron VL text generation")
+
+                question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+                generation_config = {
+                    "max_new_tokens": 100,
+                    "do_sample": False,
+                    "eos_token_id": tokenizer.eos_token_id,
+                }
+
+                # Use model.chat with None for images (text-only mode)
+                text_response = full_model.chat(
+                    tokenizer, None, question, generation_config, history=None
+                )
+                generated_ids_before_ptq = text_response  # Store text response
+                print(f"✅ Text-only generation successful: {text_response[:100]}...")
+
+            except Exception as e:
+                print(f"Text-only generation failed: {e}")
+                print("Falling back to standard generate() method...")
+                try:
+                    generated_ids_before_ptq = full_model.generate(
+                        input_ids, max_new_tokens=100
+                    )
+                except Exception as e2:
+                    print(f"Standard generation also failed: {e2}")
+                    generated_ids_before_ptq = None
+
+            # Run additional VL test with images
+            print("Running additional VL test with images...")
+            _run_vl_preview_generation(
+                full_model, tokenizer, args.pyt_ckpt_path, "before quantization (VL test)"
             )
-            print(f"Error details: {e}")
-            raise
+
+        else:
+            try:
+                generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
+            except Exception as e:
+                print(
+                    "Error during model generation. Please check if your transformers version is "
+                    "compatible with the model."
+                )
+                print(f"Error details: {e}")
+                raise
 
         if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
             print("Applying nvfp4 quantization (MoE only) for gpt-oss")
 
         # quantize the model
         model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)
+
+        # For VL models, update full_model to use the quantized language model
+        if is_nemotron_vl and hasattr(full_model, "language_model"):
+            print("Updating full_model with quantized language_model...")
+            full_model.language_model = model
 
         if args.verbose:
             mtq.print_quant_summary(model)
 
         # Run some samples
         torch.cuda.empty_cache()
         generated_ids_after_ptq = None
-        if model_type != "llama4":
+        if model_type != "llama4" and not is_nemotron_vl:
             # Our fake quantizer may not be fully compatible with torch.compile.
             generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
+        elif is_nemotron_vl:
+            print("Running text-only preview generation for quantized Nemotron VL model...")
+            try:
+                # Try text-only generation using model.chat with None for images
+                if tokenizer is None:
+                    raise ValueError("Tokenizer is required for Nemotron VL text generation")
+
+                question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+                generation_config = {
+                    "max_new_tokens": 100,
+                    "do_sample": False,
+                    "eos_token_id": tokenizer.eos_token_id,
+                }
+
+                # Use model.chat with None for images (text-only mode)
+                text_response = full_model.chat(
+                    tokenizer, None, question, generation_config, history=None
+                )
+                generated_ids_after_ptq = text_response  # Store text response
+                print(f"✅ Text-only generation successful: {text_response[:100]}...")
+
+            except Exception as e:
+                print(f"Text-only generation failed: {e}")
+                generated_ids_after_ptq = None
+
+            # Run additional VL test with images
+            print("Running additional VL test with images...")
+            _run_vl_preview_generation(
+                full_model, tokenizer, args.pyt_ckpt_path, "after quantization (VL test)"
+            )
+
         else:
             warnings.warn(
                 "Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
             )
@@ -547,15 +730,25 @@ def output_decode(generated_ids, input_shape):
 
         if generated_ids_after_ptq is not None:
             print("--------")
-            print(f"example test input: {input_decode(input_ids)}")
-            print("--------")
-            print(
-                f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}"
-            )
-            print("--------")
-            print(
-                f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}"
-            )
+            if is_nemotron_vl:
+                # For Nemotron VL models, generated_ids are text strings from model.chat()
+                print("Nemotron VL model text-only generation results:")
+                print(f"Text response before quantization: {generated_ids_before_ptq}")
+                print("--------")
+                print(f"Text response after quantization: {generated_ids_after_ptq}")
+                print("--------")
+                print("Note: Additional VL tests with images were run separately above")
+            else:
+                # For regular LLMs, generated_ids are token tensors that need decoding
+                print(f"example test input: {input_decode(input_ids)}")
+                print("--------")
+                print(
+                    f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}"
+                )
+                print("--------")
+                print(
+                    f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}"
+                )
     else:
         warnings.warn("Skipping quantization: model is already quantized.")
 
@@ -577,19 +770,55 @@
         # Save original model config and the processor config to the export path for VLMs.
print(f"Saving original model config to {export_path}") - AutoConfig.from_pretrained( - args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code - ).save_pretrained(export_path) + config_kwargs = {"trust_remote_code": args.trust_remote_code} + if args.attn_implementation is not None: + config_kwargs["attn_implementation"] = args.attn_implementation + AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained( + export_path + ) - # Try to save processor config if available - try: - print(f"Saving processor config to {export_path}") - AutoProcessor.from_pretrained( - args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code - ).save_pretrained(export_path) - except Exception as e: - print(f"Warning: Could not save processor config: {e}") - print("This is normal for some VLM architectures that don't use AutoProcessor") + # Try to save processor config if available (skip for Nemotron VL models) + if not is_nemotron_vl: + try: + print(f"Saving processor config to {export_path}") + AutoProcessor.from_pretrained( + args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code + ).save_pretrained(export_path) + except Exception as e: + print(f"Warning: Could not save processor config: {e}") + print("This is normal for some VLM architectures that don't use AutoProcessor") + else: + print("Skipping AutoProcessor for Nemotron VL (uses separate AutoImageProcessor)") + + # For Nemotron VL models, save image processor using proper HuggingFace APIs + if is_nemotron_vl: + import os + import shutil + + # Try to save image processor config using HuggingFace API + try: + print("Saving image processor config using AutoImageProcessor...") + image_processor = AutoImageProcessor.from_pretrained( + args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code + ) + image_processor.save_pretrained(export_path) + print(" ✅ Image processor config saved successfully") + except Exception as e: + print(f" Warning: Could not save image processor config: {e}") + + # Manually copy image_processing.py as it contains custom code that save_pretrained doesn't handle + print("Copying custom image processing implementation...") + src_path = os.path.join(args.pyt_ckpt_path, "image_processing.py") + dst_path = os.path.join(export_path, "image_processing.py") + + if os.path.exists(src_path): + try: + shutil.copy2(src_path, dst_path) + print(" ✅ Copied: image_processing.py") + except Exception as copy_e: + print(f" Warning: Could not copy image_processing.py: {copy_e}") + else: + print(" Warning: image_processing.py not found in source model") if model_type == "mllama": full_model_config = model.config @@ -758,10 +987,10 @@ def output_decode(generated_ids, input_shape): parser.add_argument( "--attn_implementation", help=( - "Specify the attention implementation to use." + "Specify the attention implementation to use. " "This arg will be passed to the HF model loading if specified." 
         ),
-        default=None,
+        default="eager",
         type=str,
     )
 
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index f514e660d..5be3ec460 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -131,6 +131,14 @@ def _output_hook(module, input, output):
     with torch.no_grad():
         fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device)
         decoder_fake_input = fake_input
+
+    # Check if this is a VL model that needs special input handling
+    is_vl_model = (
+        hasattr(model.config, "vision_config")
+        or hasattr(model, "vision_model")
+        or "nemotron" in getattr(model, "name_or_path", "").lower()
+    )
+
     if model_type.startswith("whisper"):
         # For Whisper models, we need to pass a fake input with the specific sequence length
         from transformers import AutoFeatureExtractor
@@ -139,6 +147,9 @@
         fake_input = torch.ones(
             [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype
         ).to(model.device)
+    elif is_vl_model:
+        # For VL models, run optimization on language model component only
+        print("Detected VL model during export - optimizing language model component")
 
     # Run forward pass so that all modules sharing the same input are collected using forward hook.
 
@@ -146,6 +157,35 @@
         if getattr(model.config, "is_encoder_decoder", False):
             # For encoder-decoder models, we need to pass both the encoder and decoder input ids
             model(fake_input, decoder_input_ids=decoder_fake_input)
+        elif is_vl_model:
+            # For VL models, try to run optimization on just the language model part
+            language_model = None
+            if hasattr(model, "language_model"):
+                language_model = model.language_model
+                print(
+                    "Found language_model attribute - running optimization on language model only"
+                )
+            elif hasattr(model, "model") and hasattr(model.model, "language_model"):
+                language_model = model.model.language_model
+                print(
+                    "Found language_model in model.model - running optimization on language model only"
+                )
+
+            if language_model is not None:
+                # Run optimization on just the language model with the same input format as regular LLMs
+                # Use the same fake_input tensor that regular LLMs use
+                print(
+                    f"Running optimization on language model with fake_input shape: {fake_input.shape}"
+                )
+                try:
+                    language_model(fake_input)
+                    print("✅ Language model optimization completed successfully")
+                except Exception as e:
+                    print(f"Language model optimization failed: {e}")
+                    print("Continuing with export...")
+            else:
+                print("Warning: No language_model found in VL model - skipping optimization")
+                print("This is unexpected for most VL models")
         else:
             model(fake_input)
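Both example_utils.get_model() and hf_ptq.py gate the new behavior on plain substring checks over the checkpoint path: any of "vl", "vision", or "nemotron-nano-vl" marks a checkpoint as vision-language, while "nemotron" together with "vl" marks it as Nemotron VL. A minimal standalone sketch of those two heuristics follows; the checkpoint names are hypothetical illustrations and the helper names are not part of the patch.

# Sketch of the path-keyword heuristics used in the patch (illustrative only).


def is_vl_checkpoint(ckpt_path: str) -> bool:
    # Mirrors get_model(): any VL keyword in the path disables automatic device mapping.
    return any(kw in ckpt_path.lower() for kw in ["vl", "vision", "nemotron-nano-vl"])


def is_nemotron_vl_checkpoint(ckpt_path: str) -> bool:
    # Mirrors hf_ptq.py: "nemotron" and "vl" must both appear in the path.
    path = ckpt_path.lower()
    return "nemotron" in path and "vl" in path


if __name__ == "__main__":
    for path in [
        "nvidia/Nemotron-Nano-VL-8B",  # hypothetical Nemotron VL checkpoint: both checks are True
        "meta-llama/Llama-3.1-8B",  # plain LLM: neither check matches
        "my-org/pavlov-chat",  # hypothetical: "vl" appears as a substring, so is_vl_checkpoint is True
    ]:
        print(f"{path}: vl={is_vl_checkpoint(path)}, nemotron_vl={is_nemotron_vl_checkpoint(path)}")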