@@ -182,6 +182,8 @@ def validate_awq_after(model: "AWQModifier") -> "AWQModifier":
182182 ), "In AWQ, all config groups must use the same configuration for group_size"
183183
184184 model ._group_size = next (iter (group_size_set ))
185+ if model ._group_size is None :
186+ model ._group_size = - 1
185187
186188 in_num_bits_set = set (
187189 group .input_activations .num_bits
@@ -460,14 +462,16 @@ def _apply_smoothing(self, model: Module) -> None:
460462 weight = torch .cat ([bl .weight for bl in balance_layers ], dim = 0 )
461463 org_shape = weight .shape
462464 # The weights are reshaped to be organised by quantization group
463- weight = weight .view (- 1 , self ._group_size )
464- # Calculates the relative magnitude of the weights within
465- # each of the quantization groups, and rescales each group
466- # individually so that each group has weights on a 0-1 scale.
467- weight .abs_ ()
468- weight .div_ (weight .amax (dim = 1 , keepdim = True ) + 1e-6 )
469- # Resizes the rescaled weight matrix back up to its original dimensions
470- weight = weight .view (org_shape )
465+ if self ._group_size > 0 :
466+ weight = weight .view (- 1 , self ._group_size )
467+ # Calculates the relative magnitude of the weights within
468+ # each of the quantization groups, and rescales each group
469+ # individually so that each group has weights on a 0-1 scale.
470+ weight .abs_ ()
471+ weight .div_ (weight .amax (dim = 1 , keepdim = True ) + 1e-6 )
472+ # Resizes the rescaled weight matrix back up to
473+ # its original dimensions
474+ weight = weight .view (org_shape )
471475 # Gets the average rescaled magnitude for each output channel
472476 w_mean = weight .mean (0 )
473477 del weight
0 commit comments