diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 002045215c..f81b105ea2 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -93,9 +93,16 @@ def run_evaluation(
             assert groupsize in [32, 64, 128, 256], (
                 f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             )
+            int4_packing_format = (
+                "plain_int32" if device == "xpu" else "tile_packed_to_4d"
+            )
             quantize_(
                 model.to(device),
-                Int4WeightOnlyConfig(group_size=groupsize, use_hqq=use_hqq, version=1),
+                Int4WeightOnlyConfig(
+                    group_size=groupsize,
+                    use_hqq=use_hqq,
+                    int4_packing_format=int4_packing_format,
+                ),
             )
         if "uintx" in quantization:
             # uintx-nbits-groupsize
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index fc3d371139..6b4667a442 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -430,9 +430,16 @@ def ffn_or_attn_only(mod, fqn):
             ], (
                 f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
             )
+            int4_packing_format = (
+                "plain_int32" if device == "xpu" else "tile_packed_to_4d"
+            )
             quantize_(
                 model,
-                Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1),
+                Int4WeightOnlyConfig(
+                    group_size=group_size,
+                    use_hqq=use_hqq,
+                    int4_packing_format=int4_packing_format,
+                ),
             )
         elif "int4dq-" in quantization:
             from torchao.dtypes import CutlassInt4PackedLayout
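
The change is identical in both files: the int4 packing format is chosen from the target device ("plain_int32" on XPU, "tile_packed_to_4d" otherwise) instead of pinning `version=1`. Below is a minimal standalone sketch of the resulting call, not taken from the patch; it assumes a machine with an XPU or CUDA device, and the toy model and `group_size=128` are placeholder values.

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Placeholder setup: a toy bf16 model on whichever accelerator is available.
device = "xpu" if torch.xpu.is_available() else "cuda"
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(device, torch.bfloat16)

# Device-dependent packing format, as introduced by this diff.
int4_packing_format = "plain_int32" if device == "xpu" else "tile_packed_to_4d"

quantize_(
    model,
    Int4WeightOnlyConfig(group_size=128, int4_packing_format=int4_packing_format),
)
```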