17 changes: 17 additions & 0 deletions examples/llm_ptq/example_utils.py
@@ -136,6 +136,17 @@ def get_model(
if device == "cpu":
device_map = "cpu"

# Special handling for vision-language models that may have device mapping issues
# Check if this is a VL model by looking at the model path
is_vl_model = any(
vl_keyword in ckpt_path.lower() for vl_keyword in ["vl", "vision", "nemotron-nano-vl"]
)
if is_vl_model:
print(
"Detected vision-language model. Disabling automatic device mapping to avoid device_map errors."
)
device_map = None

config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
if attn_implementation is not None:
config_kwargs["attn_implementation"] = attn_implementation
@@ -235,6 +246,12 @@ def get_model(
**model_kwargs,
)
model.eval()

# If device_map was disabled (None), manually move model to target device
if device_map is None and device != "cpu":
print(f"Moving model to {device} device...")
model = model.to(device)

if device == "cuda" and not is_model_on_gpu(model):
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

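Note on the device-map handling above: the VL check is a plain case-insensitive substring match on the checkpoint path, so any path containing "vl", "vision", or "nemotron-nano-vl" disables automatic device mapping, and the model is then moved to the target device explicitly. A minimal sketch of that behavior follows; the helper name and the example paths are illustrative only, not part of the change itself.

def looks_like_vl_model(ckpt_path: str) -> bool:
    # Same substring test as in get_model(): case-insensitive keyword match on the path.
    return any(kw in ckpt_path.lower() for kw in ["vl", "vision", "nemotron-nano-vl"])

for path in ["org/Nemotron-Nano-VL-8B", "org/Llama-3.1-8B-Instruct"]:  # hypothetical paths
    if looks_like_vl_model(path):
        print(f"{path}: device_map=None, model moved later via model.to(device)")
    else:
        print(f"{path}: device_map left unchanged ('cpu' or automatic mapping)")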
293 changes: 261 additions & 32 deletions examples/llm_ptq/hf_ptq.py
@@ -26,6 +26,7 @@
from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoModelForCausalLM,
AutoProcessor,
PreTrainedTokenizer,
@@ -83,6 +84,86 @@
mto.enable_huggingface_checkpointing()


def _run_vl_preview_generation(model, tokenizer, model_path, stage_name):
"""Run preview generation for VL models using sample images.

Args:
model: The VL model
tokenizer: The tokenizer
model_path: Path to the model (for loading image processor)
stage_name: Description of the stage (e.g., "before quantization")

Returns:
Generated response text for logging/comparison
"""
import os

from PIL import Image
from transformers import AutoImageProcessor

try:
print(f"Loading sample images for {stage_name} preview...")

# Load image processor
image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)

# Load sample images from the images directory
script_dir = os.path.dirname(os.path.abspath(__file__))
images_dir = os.path.join(script_dir, "images")

image_files = ["example1a.jpeg", "example1b.jpeg"]
images = []
for img_file in image_files:
img_path = os.path.join(images_dir, img_file)
if os.path.exists(img_path):
images.append(Image.open(img_path))
print(f" Loaded: {img_file}")
else:
print(f" Warning: {img_file} not found")

if not images:
print("No sample images found - skipping VL preview generation")
return None

# Process images
image_features = image_processor(images)

# Move image features to the same device as the model
model_device = model.device
for key, value in image_features.items():
if hasattr(value, "to"): # Check if it's a tensor
image_features[key] = value.to(model_device)
print(f" Moved {key} to {model_device}")

# Generate response
question = "Describe these images briefly."
generation_config = {
"max_new_tokens": 50,
"do_sample": False,
"eos_token_id": tokenizer.eos_token_id,
}

print(f"Generating VL response ({stage_name})...")
response = model.chat(
tokenizer=tokenizer,
question=question,
generation_config=generation_config,
**image_features,
)

print(f"✅ VL generation {stage_name} successful!")
print(f"Question: {question}")
print(f"Response: {response}")

# Return the response for comparison/logging
return response

except Exception as e:
print(f"❌ VL preview generation {stage_name} failed: {e}")
print("This may indicate issues with the quantized model")
return None


def auto_quantize(
model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1
):
@@ -281,6 +362,16 @@ def main(args):

model_type = get_model_type(model)

# Special handling for Nemotron VL models that aren't detected by standard model type detection
# For HF export, we want to keep vision unquantized, so we treat it as a regular language model
# and only quantize the language components
if model_type != "mllama" and is_multimodal_model(model):
print(
f"Detected multimodal model: {type(model).__name__}. "
f"For HF export, will quantize language components only, keeping vision unquantized."
)
# Keep as regular model type to use text-only calibration

device = model.device
if hasattr(model, "model"):
device = model.model.device
@@ -487,34 +578,126 @@ def main(args):
"Please set the default input_mode to InputMode.LANGUAGE before quantizing."
)

# For Nemotron VL models, disable quantization of vision components
is_nemotron_vl = (
"nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
)
if is_nemotron_vl:
print("Disabling quantization for vision components in Nemotron VL model")
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
# Also disable radio model components specifically
quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}

if not model_is_already_quantized or calibration_only:
# Only run single sample for preview
input_ids = next(iter(calib_dataloader))[
"input_features" if model_type == "whisper" else "input_ids"
][0:1]

# For Nemotron VL models, try text-only generation first, then VL generation as additional test
is_nemotron_vl = (
"nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
)
if is_nemotron_vl:
print("Running text-only preview generation for Nemotron VL model...")
try:
# Try text-only generation using model.chat with None for images
if tokenizer is None:
raise ValueError("Tokenizer is required for Nemotron VL text generation")

question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
generation_config = {
"max_new_tokens": 100,
"do_sample": False,
"eos_token_id": tokenizer.eos_token_id,
}

# Use model.chat with None for images (text-only mode)
text_response = full_model.chat(
tokenizer, None, question, generation_config, history=None
)
generated_ids_before_ptq = text_response # Store text response
print(f"✅ Text-only generation successful: {text_response[:100]}...")

except Exception as e:
print(f"Text-only generation failed: {e}")
print("Falling back to standard generate() method...")
try:
generated_ids_before_ptq = full_model.generate(
input_ids, max_new_tokens=100
)
except Exception as e2:
print(f"Standard generation also failed: {e2}")
generated_ids_before_ptq = None

# Run additional VL test with images
print("Running additional VL test with images...")
_run_vl_preview_generation(
full_model, tokenizer, args.pyt_ckpt_path, "before quantization (VL test)"
)
print(f"Error details: {e}")
raise

else:
try:
generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
except Exception as e:
print(
"Error during model generation. Please check if your transformers version is "
"compatible with the model."
)
print(f"Error details: {e}")
raise
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
print("Applying nvfp4 quantization (MoE only) for gpt-oss")

# quantize the model
model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)

# For VL models, update full_model to use the quantized language model
if is_nemotron_vl and hasattr(full_model, "language_model"):
print("Updating full_model with quantized language_model...")
full_model.language_model = model
if args.verbose:
mtq.print_quant_summary(model)

# Run some samples
torch.cuda.empty_cache()
generated_ids_after_ptq = None
if model_type != "llama4":
if model_type != "llama4" and not is_nemotron_vl:
# Our fake quantizer may not be fully compatible with torch.compile.
generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
elif is_nemotron_vl:
print("Running text-only preview generation for quantized Nemotron VL model...")
try:
# Try text-only generation using model.chat with None for images
if tokenizer is None:
raise ValueError("Tokenizer is required for Nemotron VL text generation")

question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
generation_config = {
"max_new_tokens": 100,
"do_sample": False,
"eos_token_id": tokenizer.eos_token_id,
}

# Use model.chat with None for images (text-only mode)
text_response = full_model.chat(
tokenizer, None, question, generation_config, history=None
)
generated_ids_after_ptq = text_response # Store text response
print(f"✅ Text-only generation successful: {text_response[:100]}...")

except Exception as e:
print(f"Text-only generation failed: {e}")
generated_ids_after_ptq = None

# Run additional VL test with images
print("Running additional VL test with images...")
_run_vl_preview_generation(
full_model, tokenizer, args.pyt_ckpt_path, "after quantization (VL test)"
)

else:
warnings.warn(
"Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
@@ -547,15 +730,25 @@ def output_decode(generated_ids, input_shape):

if generated_ids_after_ptq is not None:
print("--------")
print(f"example test input: {input_decode(input_ids)}")
print("--------")
print(
f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}"
)
print("--------")
print(
f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}"
)
if is_nemotron_vl:
# For Nemotron VL models, generated_ids are text strings from model.chat()
print("Nemotron VL model text-only generation results:")
print(f"Text response before quantization: {generated_ids_before_ptq}")
print("--------")
print(f"Text response after quantization: {generated_ids_after_ptq}")
print("--------")
print("Note: Additional VL tests with images were run separately above")
else:
# For regular LLMs, generated_ids are token tensors that need decoding
print(f"example test input: {input_decode(input_ids)}")
print("--------")
print(
f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}"
)
print("--------")
print(
f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}"
)
else:
warnings.warn("Skipping quantization: model is already quantized.")

@@ -577,19 +770,55 @@ def output_decode(generated_ids, input_shape):
# Save original model config and the processor config to the export path for VLMs.
print(f"Saving original model config to {export_path}")

config_kwargs = {"trust_remote_code": args.trust_remote_code}
if args.attn_implementation is not None:
config_kwargs["attn_implementation"] = args.attn_implementation
AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained(
export_path
)

# Try to save processor config if available (skip for Nemotron VL models)
if not is_nemotron_vl:
try:
print(f"Saving processor config to {export_path}")
AutoProcessor.from_pretrained(
args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
).save_pretrained(export_path)
except Exception as e:
print(f"Warning: Could not save processor config: {e}")
print("This is normal for some VLM architectures that don't use AutoProcessor")
else:
print("Skipping AutoProcessor for Nemotron VL (uses separate AutoImageProcessor)")

# For Nemotron VL models, save image processor using proper HuggingFace APIs
if is_nemotron_vl:
import os
import shutil

# Try to save image processor config using HuggingFace API
try:
print("Saving image processor config using AutoImageProcessor...")
image_processor = AutoImageProcessor.from_pretrained(
args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
)
image_processor.save_pretrained(export_path)
print(" ✅ Image processor config saved successfully")
except Exception as e:
print(f" Warning: Could not save image processor config: {e}")

# Manually copy image_processing.py as it contains custom code that save_pretrained doesn't handle
print("Copying custom image processing implementation...")
src_path = os.path.join(args.pyt_ckpt_path, "image_processing.py")
dst_path = os.path.join(export_path, "image_processing.py")

if os.path.exists(src_path):
try:
shutil.copy2(src_path, dst_path)
print(" ✅ Copied: image_processing.py")
except Exception as copy_e:
print(f" Warning: Could not copy image_processing.py: {copy_e}")
else:
print(" Warning: image_processing.py not found in source model")

if model_type == "mllama":
full_model_config = model.config
@@ -758,10 +987,10 @@ def output_decode(generated_ids, input_shape):
parser.add_argument(
"--attn_implementation",
help=(
"Specify the attention implementation to use."
"Specify the attention implementation to use. "
"This arg will be passed to the HF model loading if specified."
),
default="eager",
type=str,
)

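As a rough usage sketch, the text-only preview path exercised above could be reproduced against an exported checkpoint along the following lines. Everything here is an assumption for illustration: the export directory is a placeholder, the checkpoint is assumed to keep the custom remote-code chat(tokenizer, images, question, generation_config, history=...) interface used in hf_ptq.py, and depending on the architecture AutoModel may be needed instead of AutoModelForCausalLM.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

export_path = "./exported_nemotron_vl"  # placeholder for the hf_ptq.py export directory
tokenizer = AutoTokenizer.from_pretrained(export_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    export_path, trust_remote_code=True, torch_dtype=torch.bfloat16
).eval()
model = model.to("cuda")

# Text-only smoke test mirroring the preview generation above: passing None for the
# images argument keeps the vision tower out of the forward pass.
generation_config = {
    "max_new_tokens": 64,
    "do_sample": False,
    "eos_token_id": tokenizer.eos_token_id,
}
response = model.chat(
    tokenizer, None, "What does post-training quantization do?", generation_config, history=None
)
print(response)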