[torchao safetensors] integrate torchao safetensors support with transformers #40735
Conversation
[For maintainers] Suggested jobs to run (before merge): run-slow: torchao_integration
@@ -727,6 +729,7 @@ def _load_state_dict_into_meta_model(
     keep_in_fp32_regex: Optional[re.Pattern] = None,
     unexpected_keys: Optional[list[str]] = None,  # passing `unexpected` for cleanup from quantization items
     device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
+    metadata: Optional[dict] = None,
Let's use `metadata_dict` here for clarity.
if isinstance(state_dict, tuple):
    state_dict, metadata = state_dict
This might be a bit confusing, I feel; I think we should use a different API to get `tensor_data_dict` and `metadata_dict`.
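For illustration, a minimal sketch of that split using the existing `safetensors.safe_open` API; the helper name `load_tensor_data_and_metadata` is hypothetical and not part of this PR:

```python
from safetensors import safe_open

def load_tensor_data_and_metadata(path: str):
    """Hypothetical helper: return the tensors and the file-level metadata
    separately, instead of packing both into a (state_dict, metadata) tuple."""
    with safe_open(path, framework="pt") as f:
        metadata_dict = f.metadata() or {}
        tensor_data_dict = {key: f.get_tensor(key) for key in f.keys()}
    return tensor_data_dict, metadata_dict
```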
@@ -4286,7 +4301,8 @@ def save_pretrained(
         if safe_serialization:
             # At some point we will need to deal better with save_function (used for TPU and other distributed
             # joyfulness), but for now this enough.
-            safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+            metadata["format"] = "pt"
Should this be the default value at L4022?
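A small sketch of what that default could look like (hedged; `shard`, `save_directory`, and `shard_file` are the variables from the diff above):

```python
import os

from safetensors.torch import save_file as safe_save_file

# Sketch only: default the metadata so callers that pass nothing
# still end up with {"format": "pt"}, matching the previous behavior.
metadata = metadata if metadata is not None else {}
metadata.setdefault("format", "pt")
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
```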
@@ -279,6 +286,9 @@ def create_quantized_param(

         quantize_(module, self.quantization_config.get_apply_tensor_subclass())

+    def transform_state_dict(self, tensor_data, metadata):
Nit: I think we can make this a bit clearer, e.g. `transform_state_dict_before_saving`.
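The rename would only touch the signature, along the lines of:

```python
def transform_state_dict_before_saving(self, tensor_data, metadata):
    # Same behavior as transform_state_dict; the name now signals
    # that it runs on the save path.
    ...
```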
@@ -297,10 +307,13 @@ def _process_model_after_weight_loading(self, model, **kwargs):

     def is_serializable(self, safe_serialization=None) -> bool:
         if safe_serialization:
+            from torchao.quantization import Float8WeightOnlyConfig
Also `Float8DynamicActivationFloat8WeightConfig`?
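A hedged sketch of the check with both float8 configs included (assuming the torchao config instance lives on `self.quantization_config.quant_type`, and eliding the existing non-safetensors checks):

```python
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
)

def is_serializable(self, safe_serialization=None) -> bool:
    if safe_serialization:
        # Sketch only: allow safetensors for the configs whose tensor
        # subclasses this PR knows how to flatten.
        return isinstance(
            self.quantization_config.quant_type,
            (Float8WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig),
        )
    ...  # existing (non-safetensors) serializability checks unchanged
```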
def setUp(self):
    self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
    dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
    from torchao.quantization import Float8WeightOnlyConfig
Maybe add a new test instead of overriding the old one?
with tempfile.TemporaryDirectory() as tmpdirname:
    self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
    self.quantized_model.save_pretrained(tmpdirname, safe_serialization=True)
We can add a new test for `Float8WeightOnlyConfig` and `Float8DynamicActivationFloat8WeightConfig`, I think, and revert all the changes to the previous tests.
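A hedged sketch of such a test; the class name and model id are illustrative, and the harness mirrors the existing setup:

```python
import tempfile
import unittest

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig


class TorchAoSafeSerializationTest(unittest.TestCase):
    # Illustrative model id; the real suite uses its own fixture model.
    model_name = "facebook/opt-125m"

    def _save_and_reload(self, torchao_config):
        quant_config = TorchAoConfig(quant_type=torchao_config)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.bfloat16, quantization_config=quant_config
        )
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, safe_serialization=True)
            return AutoModelForCausalLM.from_pretrained(tmpdirname)

    def test_float8_weight_only_safetensors(self):
        from torchao.quantization import Float8WeightOnlyConfig

        self.assertIsNotNone(self._save_and_reload(Float8WeightOnlyConfig()))

    def test_float8_dynamic_activation_safetensors(self):
        from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

        self.assertIsNotNone(self._save_and_reload(Float8DynamicActivationFloat8WeightConfig()))
```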
Context

Currently, we need to use `safe_serialization=False` while saving models, as shown here. This PR enables safetensors support for torchao so that users can now save and load checkpoints using safetensors. Currently, only `Float8Tensor` is supported, but allowing other subclasses should involve minimal code changes.

Summary

Changes to the transformers code include:
- In `TorchAoHfQuantizer`, we provide `get_state_dict` and `transform_state_dict`, which flatten/unflatten a model state dict with tensor subclasses by calling functionality built out in this PR.
- In `modeling_utils.py`, we make the changes needed to propagate the metadata from tensor subclasses. We also add logic, similar to `hqq` and `bnb`, to load directly onto `cpu` rather than `meta`.

Test Plan

Modified the unit test to allow safe serialization. Run using:

python tests/quantization/torchao_integration/test_torchao.py
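For reference, a hedged end-to-end sketch of the workflow this PR enables (model id and save path are illustrative):

```python
import torch
from torchao.quantization import Float8WeightOnlyConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

# Quantize with torchao through transformers.
quant_config = TorchAoConfig(quant_type=Float8WeightOnlyConfig())
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.bfloat16, quantization_config=quant_config
)

# Previously this required safe_serialization=False; with this PR the
# Float8Tensor subclass is flattened into safetensors-compatible form.
model.save_pretrained("opt-125m-float8", safe_serialization=True)

# On load, the saved metadata is used to reconstruct the tensor subclasses.
reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-float8")
```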