Commit c15e352

committed
Format
1 parent 00a8cd8 commit c15e352

1 file changed, +20 -9 lines changed

auto_fp8/modeling.py

@@ -7,6 +7,7 @@
 from llmcompressor.transformers import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
 
+
 class BaseQuantizeConfig:
     """Configuration for model quantization.
@@ -24,6 +25,7 @@ class BaseQuantizeConfig:
         By default, "lm_head" is included to ignore the embedding
         Linear layer usually at the end of decoder LLMs
     """
+
     def __init__(
         self,
         quant_method: str = "fp8",
@@ -36,32 +38,41 @@ def __init__(
 
 
 class AutoFP8ForCausalLM:
-    def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig):
+    def __init__(
+        self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig
+    ):
         self.model = model
         self.model_type = self.model.config.model_type
         self.config = self.model.config
         self.quantize_config = quantize_config
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs):
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        quantize_config: BaseQuantizeConfig,
+        **kwargs,
+    ):
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
         model = SparseAutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path,
             config=config,
             device_map="auto",
             torch_dtype="auto",
-            **kwargs
+            **kwargs,
         )
         return cls(model, quantize_config)
 
     def quantize(self, dataset: Optional[Dataset] = None):
-        assert self.quantize_config.activation_scheme == "static"
-        assert dataset is not None, "Calibration tokens required for static activation quantization"
+        assert (
+            self.quantize_config.activation_scheme == "static"
+        ), "Dynamic isn't supported yet"
+        assert (
+            dataset is not None
+        ), "Calibration tokens required for static activation quantization"
 
         recipe = QuantizationModifier(
-            targets="Linear",
-            scheme="FP8",
-            ignore=self.quantize_config.ignore_patterns
+            targets="Linear", scheme="FP8", ignore=self.quantize_config.ignore_patterns
         )
 
         oneshot(
@@ -73,7 +84,7 @@ def quantize(self, dataset: Optional[Dataset] = None):
     def save_quantized(self, save_directory: str):
         self.save_pretrained(save_directory, save_compressed=True)
 
-    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
         self.model.save_pretrained(save_directory, save_compressed=save_compressed)
         tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
         tokenizer.save_pretrained(save_directory)
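For context, a minimal usage sketch of the API this diff reformats. The checkpoint name, the calibration text, and the activation_scheme keyword argument on BaseQuantizeConfig are assumptions inferred from the diff, not guaranteed by the commit; only from_pretrained, quantize, and save_quantized appear in the changed code itself.

# Usage sketch only; names marked below are illustrative assumptions.
from datasets import Dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

model_id = "facebook/opt-125m"  # assumed small example checkpoint

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",  # quantize() currently asserts this scheme
)

# Static activation quantization needs calibration tokens, so build a tiny
# tokenized Dataset; real calibration data should be representative text.
tokenizer = AutoTokenizer.from_pretrained(model_id)
samples = ["Quantization reduces memory bandwidth at inference time."] * 8
calibration = Dataset.from_dict(
    tokenizer(samples, padding="max_length", max_length=64, truncation=True)
)

model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
model.quantize(calibration)           # applies the one-shot FP8 recipe via llmcompressor
model.save_quantized("opt-125m-fp8")  # save_pretrained(..., save_compressed=True)

As the last hunk shows, save_quantized simply calls save_pretrained with save_compressed=True, so the output directory holds the compressed FP8 checkpoint alongside the tokenizer files.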
