Conversation

liangel-02

Context

Currently, we need to use safe_serialization=False while saving models as shown here. This PR enables safetensors support for torchao so that users can save and load checkpoints using safetensors. Only Float8Tensor is supported for now, but extending this to other subclasses should require minimal code changes.

# default for safe_serialization is True
quantized_model.push_to_hub(save_to)

Summary

Changes to the transformers code include:

  1. In TorchAoHfQuantizer, we provide get_state_dict and transform_state_dict, which flatten/unflatten a model state dict containing tensor subclasses by calling functionality built out in this PR.
  2. In modeling_utils.py, we make the changes needed to propagate the metadata from tensor subclasses. We also add logic, similar to hqq and bnb, to load directly onto CPU rather than the meta device.
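The flatten/unflatten step described in item 1 can be sketched roughly as follows. This is a simplified, hypothetical illustration: it uses plain Python lists in place of real tensors and a stand-in class in place of torchao's Float8Tensor; the actual helpers live in torchao and transformers. The key idea is that safetensors can only store plain tensors plus a str-to-str metadata mapping, so subclass attributes must be split out and JSON-encoded.

```python
import json
from dataclasses import dataclass

# Hypothetical stand-in for a torchao tensor subclass such as Float8Tensor:
# it wraps raw tensor data plus extra attributes that safetensors cannot
# store directly.
@dataclass
class FakeFloat8Tensor:
    qdata: list   # stand-in for the raw quantized tensor
    scale: list   # stand-in for the scale tensor
    dtype: str    # non-tensor attribute that must go into metadata

def flatten_state_dict(state_dict):
    """Split a state dict with subclasses into (plain tensor data, str metadata)."""
    tensor_data, metadata = {}, {}
    for name, value in state_dict.items():
        if isinstance(value, FakeFloat8Tensor):
            # Each inner plain tensor gets its own flattened key...
            tensor_data[f"{name}:qdata"] = value.qdata
            tensor_data[f"{name}:scale"] = value.scale
            # ...and non-tensor attributes are JSON-encoded, since safetensors
            # metadata must be a str -> str mapping.
            metadata[name] = json.dumps({"dtype": value.dtype})
        else:
            tensor_data[name] = value
    return tensor_data, metadata

def unflatten_state_dict(tensor_data, metadata):
    """Inverse of flatten_state_dict: rebuild subclass instances from metadata."""
    state_dict = dict(tensor_data)
    for name, meta_json in metadata.items():
        attrs = json.loads(meta_json)
        state_dict[name] = FakeFloat8Tensor(
            qdata=state_dict.pop(f"{name}:qdata"),
            scale=state_dict.pop(f"{name}:scale"),
            dtype=attrs["dtype"],
        )
    return state_dict
```

A round trip through flatten then unflatten recovers the original state dict, which is exactly the property save_pretrained/from_pretrained need.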

Test Plan

Modified the unit test to allow safe serialization. Run with python tests/quantization/torchao_integration/test_torchao.py

Contributor

github-actions bot commented Sep 5, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: torchao_integration

@liangel-02 liangel-02 marked this pull request as draft September 5, 2025 21:40
@@ -727,6 +729,7 @@ def _load_state_dict_into_meta_model(
keep_in_fp32_regex: Optional[re.Pattern] = None,
unexpected_keys: Optional[list[str]] = None, # passing `unexpected` for cleanup from quantization items
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
metadata: Optional[dict] = None,
Contributor


let's use metadata_dict here for clarity

Comment on lines +4025 to +4026
if isinstance(state_dict, tuple):
state_dict, metadata = state_dict
Contributor

@jerryzh168 jerryzh168 Sep 5, 2025


this might be a bit confusing I feel, I think we should use a different API to get tensor_data_dict and metadata_dict
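One way to realize this suggestion (purely illustrative, not the code in this PR) is to return an explicit named structure rather than sometimes returning a tuple from the same variable, so call sites never need an isinstance(state_dict, tuple) check. The helper name and the ".weight" heuristic below are hypothetical:

```python
from typing import Any, Dict, NamedTuple

class FlattenedStateDict(NamedTuple):
    """Explicit container so callers never have to type-sniff a bare tuple."""
    tensor_data: Dict[str, Any]
    metadata: Dict[str, str]

def get_flattened_state_dict(model_state: Dict[str, Any]) -> FlattenedStateDict:
    # Hypothetical: real flattening would pull metadata out of tensor
    # subclasses; here we just tag weight keys to show the shape of the API.
    metadata = {k: "subclass" for k in model_state if k.endswith(".weight")}
    tensor_data = dict(model_state)
    return FlattenedStateDict(tensor_data, metadata)

# Call sites read unambiguously:
#   flat = get_flattened_state_dict(raw)
#   safe_save_file(flat.tensor_data, path, metadata=flat.metadata)
```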

@@ -4286,7 +4301,8 @@ def save_pretrained(
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
metadata["format"] = "pt"
Contributor


this should be the default value in L4022?
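The reviewer's point, that "format": "pt" could be the default rather than patched in at the save site, can be sketched with a small helper. This is a hypothetical illustration, not the PR's code; it also encodes the safetensors constraint that metadata values must be strings, so structured values need to be JSON-encoded by the caller:

```python
from typing import Dict, Optional

def build_safetensors_metadata(extra: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Start from the default {"format": "pt"} and merge per-tensor metadata.

    safetensors requires metadata to be a str -> str mapping, so callers
    JSON-encode structured values before passing them in. (Hypothetical
    helper, sketching the suggestion to make the format key a default.)
    """
    metadata = {"format": "pt"}
    if extra:
        assert all(isinstance(v, str) for v in extra.values()), "metadata must be str -> str"
        metadata.update(extra)
    return metadata
```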

@@ -279,6 +286,9 @@ def create_quantized_param(

quantize_(module, self.quantization_config.get_apply_tensor_subclass())

def transform_state_dict(self, tensor_data, metadata):
Contributor


nit: I think we can make this a bit more clear, e.g. transform_state_dict_before_saving

@@ -297,10 +307,13 @@ def _process_model_after_weight_loading(self, model, **kwargs):

def is_serializable(self, safe_serialization=None) -> bool:
if safe_serialization:
from torchao.quantization import Float8WeightOnlyConfig
Contributor


also Float8DynamicActivationFloat8Config?
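The shape of the check being discussed can be sketched as follows. The config class bodies here are empty stand-ins (the real ones come from torchao.quantization), and the helper name is hypothetical; the point is that safe serialization is gated on an allowlist of Float8 configs, since per the PR only Float8Tensor is supported so far:

```python
# Stand-in config classes; the real ones come from torchao.quantization.
class Float8WeightOnlyConfig: ...
class Float8DynamicActivationFloat8WeightConfig: ...
class Int4WeightOnlyConfig: ...

# Configs whose quantized tensors are currently safe-serializable.
_SAFE_SERIALIZABLE_CONFIGS = (
    Float8WeightOnlyConfig,
    Float8DynamicActivationFloat8WeightConfig,
)

def is_serializable(quant_config, safe_serialization: bool = False) -> bool:
    """Sketch: safetensors serialization only for supported Float8 configs."""
    if safe_serialization:
        return isinstance(quant_config, _SAFE_SERIALIZABLE_CONFIGS)
    return True  # non-safetensors serialization works for all configs
```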


def setUp(self):
self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
from torchao.quantization import Float8WeightOnlyConfig
Contributor


maybe add a new test instead of overriding the old one?

with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=True)
Contributor


we can add a new test for Float8WeightOnlyConfig and Float8DynamicActivationFloat8WeightConfig I think, and revert all the changes to previous tests
