@@ -47,54 +47,54 @@ def infer_quantization_format(
4747 if quantization_format is not None :
4848 return quantization_format
4949
50+ if not save_compressed :
51+ # format will be inferred from config
52+ return None
53+
5054 weight_args , input_args = _get_unique_quant_args (model )
5155 if len (weight_args ) <= 0 :
5256 return None
5357
54- if save_compressed :
55- is_24_structure = (
56- SparsityStructure (sparsity_structure ) == SparsityStructure .TWO_FOUR
58+ is_24_structure = (
59+ SparsityStructure (sparsity_structure ) == SparsityStructure .TWO_FOUR
60+ )
61+ is_weight_only = len (input_args ) == 0 and len (weight_args ) > 0
62+
63+ if (
64+ weight_args [0 ].num_bits == 4
65+ and weight_args [0 ].type == QuantizationType .FLOAT .value
66+ ):
67+ return CompressionFormat .nvfp4_pack_quantized
68+
69+ if is_weight_only : # w4a16 and w8a16
70+ is_valid_pack = all (
71+ weight_arg .num_bits in [4 , 8 ]
72+ and weight_arg .type == QuantizationType .INT .value
73+ for weight_arg in weight_args
5774 )
58- is_weight_only = len (input_args ) == 0 and len (weight_args ) > 0
59-
60- if (
61- weight_args [0 ].num_bits == 4
62- and weight_args [0 ].type == QuantizationType .FLOAT .value
63- ):
64- return CompressionFormat .nvfp4_pack_quantized
65-
66- if is_weight_only : # w4a16 and w8a16
67- is_valid_pack = all (
68- weight_arg .num_bits in [4 , 8 ]
69- and weight_arg .type == QuantizationType .INT .value
70- for weight_arg in weight_args
71- )
72- if not is_valid_pack : # packing only valid for int4 and int 8
73- return CompressionFormat .naive_quantized
74- if is_24_structure :
75- for arg in weight_args :
76- if (
77- arg .strategy is not QuantizationStrategy .CHANNEL .value
78- and arg .strategy is not QuantizationStrategy .GROUP .value
79- ):
80- # marlin24 kernel only applicable for channel/group quantization
81- return CompressionFormat .pack_quantized
82- return CompressionFormat .marlin_24
83- return CompressionFormat .pack_quantized
84- else : # w8a8 float and int
85- if len (weight_args ) == 1 :
75+ if not is_valid_pack : # packing only valid for int4 and int 8
76+ return CompressionFormat .naive_quantized
77+ if is_24_structure :
78+ for arg in weight_args :
8679 if (
87- weight_args [ 0 ]. type == QuantizationType . FLOAT .value
88- and weight_args [ 0 ]. num_bits == 8
80+ arg . strategy is not QuantizationStrategy . CHANNEL .value
81+ and arg . strategy is not QuantizationStrategy . GROUP . value
8982 ):
90- return CompressionFormat .float_quantized
91- if weight_args [0 ].type == QuantizationType .INT .value :
92- return CompressionFormat .int_quantized
83+ # marlin24 kernel only applicable for channel/group quantization
84+ return CompressionFormat .pack_quantized
85+ return CompressionFormat .marlin_24
86+ return CompressionFormat .pack_quantized
87+ else : # w8a8 float and int
88+ if len (weight_args ) == 1 :
89+ if (
90+ weight_args [0 ].type == QuantizationType .FLOAT .value
91+ and weight_args [0 ].num_bits == 8
92+ ):
93+ return CompressionFormat .float_quantized
94+ if weight_args [0 ].type == QuantizationType .INT .value :
95+ return CompressionFormat .int_quantized
9396
94- return CompressionFormat .naive_quantized
95- else :
96- # format will be inferred from config
97- return None
97+ return CompressionFormat .naive_quantized
9898
9999
100100def _get_unique_quant_args (model ):
0 commit comments