72 changes: 51 additions & 21 deletions .github/scripts/torchao_model_releases/quantize_and_upload.py
@@ -44,9 +44,9 @@ def _get_username():
return username


def _untie_weights_and_save_locally(model_id):
def _untie_weights_and_save_locally(model_id, device):
untied_model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype="auto", device_map="cuda:0"
model_id, torch_dtype="auto", device_map=device
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -150,7 +150,7 @@ def _untie_weights_and_save_locally(model_id):
inputs = tokenizer(
templated_prompt,
return_tensors="pt",
).to("cuda")
).to("{device}")
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
@@ -222,15 +222,15 @@ def _untie_weights_and_save_locally(model_id):
from torchao.quantization import Int4WeightOnlyConfig
quant_config = Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
"""

_fp8_quant_code = """
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
"""

@@ -251,7 +251,7 @@ def _untie_weights_and_save_locally(model_id):
)
quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}})
quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True, modules_to_not_convert=[])
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
"""

@@ -274,7 +274,7 @@ def _untie_weights_and_save_locally(model_id):
)
quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}})
quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True, modules_to_not_convert=[])
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
"""

@@ -322,7 +322,7 @@ def _untie_weights_and_save_locally(model_id):
from torchao._models._eval import TransformerEvalWrapper
model = AutoModelForCausalLM.from_pretrained(
model_to_quantize,
device_map="cuda:0",
device_map="{device}",
torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -405,7 +405,7 @@ def _untie_weights_and_save_locally(model_id):
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="cuda:0"
device_map="{device}"
)

# prepare the model input
@@ -466,7 +466,7 @@ def _untie_weights_and_save_locally(model_id):

# use "{base_model}" or "{quantized_model}"
model_id = "{quantized_model}"
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.bfloat16)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="{device}", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

torch.cuda.reset_peak_memory_stats()
@@ -489,7 +489,7 @@ def _untie_weights_and_save_locally(model_id):
inputs = tokenizer(
templated_prompt,
return_tensors="pt",
).to("cuda")
).to("{device}")
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
@@ -569,7 +569,7 @@ def _untie_weights_and_save_locally(model_id):
import torch

model_id = "{base_model}"
untied_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0")
untied_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="{device}")
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(untied_model)
@@ -660,6 +660,7 @@ def quantize_and_upload(
push_to_hub: bool,
push_to_user_id: str,
populate_model_card_template: bool,
device: str,
):
is_mobile = quant in ["INT8-INT4", "INT8-INT4-HQQ"]

@@ -711,16 +712,15 @@
# preparation
model_to_quantize = model_id
if is_mobile:
model_to_quantize = _untie_weights_and_save_locally(model_to_quantize)
model_to_quantize = _untie_weights_and_save_locally(model_to_quantize, device)

# quantization

if "AWQ" in quant:
# awq will use torchao API directly
assert quant == "AWQ-INT4", "Only support AWQ-INT4 for now"
model = AutoModelForCausalLM.from_pretrained(
model_to_quantize,
device_map="cuda:0",
device_map=device,
torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -766,7 +766,7 @@ def filter_fn_skip_lmhead(module, fqn):
config=quantized_model.config,
quantization_config=None,
dtype=torch.bfloat16,
device_map="cuda:0",
device_map=device,
weights_only=True,
user_agent={
"file_type": "model",
@@ -842,16 +842,19 @@ def filter_fn_skip_lmhead(module, fqn):
quantized_model=quantized_model_id,
model_type=quantized_model.config.model_type,
quant=quant,
quant_code=quant_to_quant_code[quant],
quant_code=quant_to_quant_code[quant].format(device=device),
safe_serialization=safe_serialization,
device=device,
# server specific recipes
server_inference_recipe=""
if is_mobile
else _server_inference_recipe.format(quantized_model=quantized_model_id),
else _server_inference_recipe.format(
quantized_model=quantized_model_id, device=device
),
server_peak_memory_usage=""
if is_mobile
else _server_peak_memory_usage.format(
base_model=model_id, quantized_model=quantized_model_id
base_model=model_id, quantized_model=quantized_model_id, device=device
),
server_model_performance=""
if is_mobile
@@ -860,7 +863,11 @@ def filter_fn_skip_lmhead(module, fqn):
),
# mobile specific recipes
untied_model=untied_model_path if is_mobile else model_id,
untie_embedding_recipe=_untie_embedding_recipe if is_mobile else "",
untie_embedding_recipe=_untie_embedding_recipe.format(
base_model=model_id, device=device
)
if is_mobile
else "",
mobile_inference_recipe=_mobile_inference_recipe.format(
quantized_model=quantized_model_id
)
@@ -920,11 +927,15 @@ def filter_fn_skip_lmhead(module, fqn):
description="Evaluate a model with the specified parameters."
)
parser.add_argument(
"--model_id", type=str, help="Huggingface hub model ID of the model."
"--model_id",
type=str,
required=True,
help="Huggingface hub model ID of the model.",
)
parser.add_argument(
"--quant",
type=str,
required=True,
help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4, SmoothQuant-INT8-INT8",
)
parser.add_argument(
@@ -964,7 +975,25 @@ def filter_fn_skip_lmhead(module, fqn):
default=False,
help="Flag to indicate whether push model card to huggingface hub or not",
)
parser.add_argument(
"--device",
type=str,
default="cuda:0",
help="Device to run the model on (e.g., 'cuda', 'cuda:0', 'cpu'). Default is 'cuda:0'",
)
args = parser.parse_args()

# Validate the requested device
try:
device = torch.device(args.device)
if device.type == "cuda":
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available on this system")
if device.index is not None and device.index >= torch.cuda.device_count():
raise RuntimeError(
f"CUDA device {device.index} is not available. Available devices: {torch.cuda.device_count()}"
)
except RuntimeError as e:
parser.error(f"Invalid device '{args.device}': {e}")
quantize_and_upload(
args.model_id,
args.quant,
@@ -974,4 +1003,5 @@ def filter_fn_skip_lmhead(module, fqn):
args.push_to_hub,
args.push_to_user_id,
args.populate_model_card_template,
args.device,
)
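
For reference, a minimal sketch of how the new --device flag is meant to be used and validated, assuming the script path and flag names shown in the diff; the model id, the device index, and the validate_device helper below are illustrative only, not part of the PR. The helper mirrors the validation added at the bottom of the script: parse the string with torch.device, and for CUDA devices confirm that CUDA is available and that the requested index exists, before the value is passed to device_map and the {device} placeholders in the model-card templates.

# Hypothetical invocation of the updated script (model id is a placeholder):
#   python .github/scripts/torchao_model_releases/quantize_and_upload.py \
#       --model_id <hf_model_id> --quant FP8 --device cuda:1 --push_to_hub
import torch

def validate_device(device_str: str) -> torch.device:
    # Parse the device string; torch.device raises RuntimeError on malformed input.
    device = torch.device(device_str)
    if device.type == "cuda":
        # Only CUDA devices need the availability / index checks.
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available on this system")
        if device.index is not None and device.index >= torch.cuda.device_count():
            raise RuntimeError(
                f"CUDA device {device.index} is not available. "
                f"Available devices: {torch.cuda.device_count()}"
            )
    return device

print(validate_device("cpu"))  # -> device(type='cpu'); CUDA checks are skipped for non-CUDA devices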