 from llmcompressor.transformers import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
 
+
 class BaseQuantizeConfig:
     """Configuration for model quantization.
 
@@ -24,6 +25,7 @@ class BaseQuantizeConfig:
             By default, "lm_head" is included to ignore the embedding
             Linear layer usually at the end of decoder LLMs
     """
+
     def __init__(
         self,
         quant_method: str = "fp8",
@@ -36,32 +38,41 @@ def __init__(
 
 
 class AutoFP8ForCausalLM:
-    def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig):
+    def __init__(
+        self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig
+    ):
         self.model = model
         self.model_type = self.model.config.model_type
         self.config = self.model.config
         self.quantize_config = quantize_config
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs):
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        quantize_config: BaseQuantizeConfig,
+        **kwargs,
+    ):
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
         model = SparseAutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path,
             config=config,
             device_map="auto",
             torch_dtype="auto",
-            **kwargs
+            **kwargs,
         )
         return cls(model, quantize_config)
 
     def quantize(self, dataset: Optional[Dataset] = None):
-        assert self.quantize_config.activation_scheme == "static"
-        assert dataset is not None, "Calibration tokens required for static activation quantization"
+        assert (
+            self.quantize_config.activation_scheme == "static"
+        ), "Dynamic isn't supported yet"
+        assert (
+            dataset is not None
+        ), "Calibration tokens required for static activation quantization"
 
         recipe = QuantizationModifier(
-            targets="Linear",
-            scheme="FP8",
-            ignore=self.quantize_config.ignore_patterns
+            targets="Linear", scheme="FP8", ignore=self.quantize_config.ignore_patterns
         )
 
         oneshot(
@@ -73,7 +84,7 @@ def quantize(self, dataset: Optional[Dataset] = None):
     def save_quantized(self, save_directory: str):
         self.save_pretrained(save_directory, save_compressed=True)
 
-    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
         self.model.save_pretrained(save_directory, save_compressed=save_compressed)
         tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
         tokenizer.save_pretrained(save_directory)
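
For reference, a minimal end-to-end usage sketch of the refactored class. This is not part of the diff: the auto_fp8 import path, model id, and calibration text are assumptions for illustration, and the BaseQuantizeConfig arguments are inferred from the attributes referenced in quantize().

from datasets import Dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig  # import path assumed

# Illustrative model id and calibration text; substitute your own.
model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# A tiny tokenized calibration set: quantize() asserts dataset is not None
# when activation_scheme == "static".
texts = ["The quick brown fox jumps over the lazy dog."] * 16
calibration = Dataset.from_dict(dict(tokenizer(texts)))

quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
model.quantize(calibration)
model.save_quantized("opt-125m-fp8")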