diff --git a/.github/scripts/torchao_model_releases/quantize_and_upload.py b/.github/scripts/torchao_model_releases/quantize_and_upload.py
index ab6338aff5..3d5c7f4454 100644
--- a/.github/scripts/torchao_model_releases/quantize_and_upload.py
+++ b/.github/scripts/torchao_model_releases/quantize_and_upload.py
@@ -44,9 +44,9 @@ def _get_username():
     return username
 
 
-def _untie_weights_and_save_locally(model_id):
+def _untie_weights_and_save_locally(model_id, device):
     untied_model = AutoModelForCausalLM.from_pretrained(
-        model_id, torch_dtype="auto", device_map="cuda:0"
+        model_id, torch_dtype="auto", device_map=device
     )
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
@@ -150,7 +150,7 @@ def _untie_weights_and_save_locally(model_id):
 inputs = tokenizer(
     templated_prompt,
     return_tensors="pt",
-).to("cuda")
+).to("{device}")
 generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
 output_text = tokenizer.batch_decode(
     generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
 )
@@ -222,7 +222,7 @@ def _untie_weights_and_save_locally(model_id):
 from torchao.quantization import Int4WeightOnlyConfig
 quant_config = Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")
 quantization_config = TorchAoConfig(quant_type=quant_config)
-quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
@@ -230,7 +230,7 @@ def _untie_weights_and_save_locally(model_id):
 from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
 quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
 quantization_config = TorchAoConfig(quant_type=quant_config)
-quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
@@ -251,7 +251,7 @@ def _untie_weights_and_save_locally(model_id):
 )
 quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}})
 quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True, modules_to_not_convert=[])
-quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
@@ -274,7 +274,7 @@ def _untie_weights_and_save_locally(model_id):
 )
 quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}})
 quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True, modules_to_not_convert=[])
-quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
@@ -322,7 +322,7 @@ def _untie_weights_and_save_locally(model_id):
 from torchao._models._eval import TransformerEvalWrapper
 model = AutoModelForCausalLM.from_pretrained(
     model_to_quantize,
-    device_map="cuda:0",
+    device_map="{device}",
     torch_dtype=torch.bfloat16,
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -405,7 +405,7 @@ def _untie_weights_and_save_locally(model_id):
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     torch_dtype="auto",
-    device_map="cuda:0"
+    device_map="{device}"
 )
 
 # prepare the model input
@@ -466,7 +466,7 @@ def _untie_weights_and_save_locally(model_id):
 
 # use "{base_model}" or "{quantized_model}"
 model_id = "{quantized_model}"
-quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.bfloat16)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="{device}", torch_dtype=torch.bfloat16)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 torch.cuda.reset_peak_memory_stats()
@@ -489,7 +489,7 @@ def _untie_weights_and_save_locally(model_id):
 inputs = tokenizer(
     templated_prompt,
     return_tensors="pt",
-).to("cuda")
+).to("{device}")
 generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
 output_text = tokenizer.batch_decode(
     generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
 )
@@ -569,7 +569,7 @@ def _untie_weights_and_save_locally(model_id):
 import torch
 
 model_id = "{base_model}"
-untied_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0")
+untied_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="{device}")
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 print(untied_model)
@@ -660,6 +660,7 @@ def quantize_and_upload(
     push_to_hub: bool,
     push_to_user_id: str,
     populate_model_card_template: bool,
+    device: str,
 ):
     is_mobile = quant in ["INT8-INT4", "INT8-INT4-HQQ"]
 
@@ -711,16 +712,15 @@ def quantize_and_upload(
     # preparation
     model_to_quantize = model_id
     if is_mobile:
-        model_to_quantize = _untie_weights_and_save_locally(model_to_quantize)
+        model_to_quantize = _untie_weights_and_save_locally(model_to_quantize, device)
 
     # quantization
-
     if "AWQ" in quant:
         # awq will use torchao API directly
         assert quant == "AWQ-INT4", "Only support AWQ-INT4 for now"
         model = AutoModelForCausalLM.from_pretrained(
             model_to_quantize,
-            device_map="cuda:0",
+            device_map=device,
             torch_dtype=torch.bfloat16,
         )
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -766,7 +766,7 @@ def filter_fn_skip_lmhead(module, fqn):
             config=quantized_model.config,
             quantization_config=None,
             dtype=torch.bfloat16,
-            device_map="cuda:0",
+            device_map=device,
             weights_only=True,
             user_agent={
                 "file_type": "model",
@@ -842,16 +842,19 @@ def filter_fn_skip_lmhead(module, fqn):
         quantized_model=quantized_model_id,
         model_type=quantized_model.config.model_type,
         quant=quant,
-        quant_code=quant_to_quant_code[quant],
+        quant_code=quant_to_quant_code[quant].format(device=device),
         safe_serialization=safe_serialization,
+        device=device,
         # server specific recipes
         server_inference_recipe=""
         if is_mobile
-        else _server_inference_recipe.format(quantized_model=quantized_model_id),
+        else _server_inference_recipe.format(
+            quantized_model=quantized_model_id, device=device
+        ),
         server_peak_memory_usage=""
         if is_mobile
         else _server_peak_memory_usage.format(
-            base_model=model_id, quantized_model=quantized_model_id
+            base_model=model_id, quantized_model=quantized_model_id, device=device
         ),
         server_model_performance=""
         if is_mobile
@@ -860,7 +863,11 @@ def filter_fn_skip_lmhead(module, fqn):
         ),
         # mobile specific recipes
         untied_model=untied_model_path if is_mobile else model_id,
-        untie_embedding_recipe=_untie_embedding_recipe if is_mobile else "",
+        untie_embedding_recipe=_untie_embedding_recipe.format(
+            base_model=model_id, device=device
+        )
+        if is_mobile
+        else "",
         mobile_inference_recipe=_mobile_inference_recipe.format(
             quantized_model=quantized_model_id
         )
@@ -920,11 +927,15 @@ def filter_fn_skip_lmhead(module, fqn):
         description="Evaluate a model with the specified parameters."
     )
     parser.add_argument(
-        "--model_id", type=str, help="Huggingface hub model ID of the model."
+        "--model_id",
+        type=str,
+        required=True,
+        help="Huggingface hub model ID of the model.",
     )
     parser.add_argument(
         "--quant",
         type=str,
+        required=True,
         help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4, SmoothQuant-INT8-INT8",
     )
     parser.add_argument(
@@ -964,6 +975,12 @@ def filter_fn_skip_lmhead(module, fqn):
         default=False,
         help="Flag to indicate whether push model card to huggingface hub or not",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0",
+        help="Device to run the model on (e.g., 'cuda', 'cuda:0', 'cpu'). Default is 'cuda:0'",
+    )
     args = parser.parse_args()
     quantize_and_upload(
         args.model_id,
@@ -974,4 +991,5 @@ def filter_fn_skip_lmhead(module, fqn):
         args.push_to_hub,
         args.push_to_user_id,
         args.populate_model_card_template,
+        args.device,
     )
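Note on the two substitution styles in this patch: code that runs inside the script passes the CLI value directly (device_map=device), while the recipe templates that end up in the generated model card keep a literal "{device}" placeholder, which is only filled later via quant_to_quant_code[quant].format(device=device) and the other .format(..., device=device) calls. A minimal, illustrative Python sketch of that template path (simplified names, not the script's actual templates):

# Illustrative only: a trimmed-down recipe template in the same style as the
# templates in the diff. The "{device}" placeholder is plain text inside the
# template string until .format(device=...) substitutes it.
_example_quant_code = """
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_to_quantize,
    device_map="{device}",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
"""


def render_quant_code(template: str, device: str) -> str:
    # Mirrors quant_to_quant_code[quant].format(device=device) in the script.
    return template.format(device=device)


# Prints the snippet with device_map="cuda:1" filled in verbatim.
print(render_quant_code(_example_quant_code, "cuda:1"))

The new --device argument defaults to cuda:0, so existing invocations of quantize_and_upload.py keep their current behavior.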