
PARQ quantizer support for torchao's weight-only configs #2091


Open · lisjin wants to merge 10 commits into main from parq

Conversation

@lisjin (Contributor) commented Apr 21, 2025

This is the first step in supporting torchao.quantize_ for PARQ-trained models. For now I target only Int4WeightOnlyConfig and IntxWeightOnlyConfig, since PARQ does not have activation quantization.

Instead of converting the state (e.g., scale, zero point) from PARQ's existing quantizers to torchao format, I decided to create a new quantizer UnifTorchaoQuantizer. This quantizer calls torchao's quantization primitives choose_qparams_affine, quantize_affine, dequantize_affine to ensure parity between the two QAT methods.

@metascroy It would be great if you could check the correctness of how the quantizer in TestUnifTorchaoQuantizer.test_intx_weight_only is initialized. I'm not sure if I missed any subtleties with int8.
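For context, the conversion path boils down to a choose-qparams / quantize / dequantize round trip over the torchao primitives. Below is a minimal sketch, not the actual implementation; the helper name is hypothetical and the primitive signatures follow my reading of torchao.quantization.quant_primitives, which may differ slightly across versions:

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)


def fake_quantize_weight(w: torch.Tensor, b: int = 4, group_size: int = 32) -> torch.Tensor:
    # Hypothetical helper for illustration only. Each quantization block covers
    # group_size consecutive values within a row (per-group quantization).
    block_size = (1, group_size)
    quant_min, quant_max = -(2 ** (b - 1)), 2 ** (b - 1) - 1
    scale, zero_point = choose_qparams_affine(
        w, MappingType.SYMMETRIC, block_size, torch.int8,
        quant_min=quant_min, quant_max=quant_max,
    )
    q = quantize_affine(w, block_size, scale, zero_point, torch.int8, quant_min, quant_max)
    return dequantize_affine(
        q, block_size, scale, zero_point, torch.int8, quant_min, quant_max, output_dtype=w.dtype
    )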

pytorch-bot (bot) commented Apr 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2091

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9e70d6d with merge base cdced21:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Apr 21, 2025
lisjin added the topic: improvement label on Apr 21, 2025
lisjin force-pushed the parq branch 4 times, most recently from 48520cb to bea111c on April 22, 2025
@metascroy (Contributor)

@lisjin can you give a little code snippet of how QAT prepare/convert would work for this API?

I'm having trouble following. Here are some example code snippets from other APIs: https://fb.workplace.com/groups/pytorch.edge2.team/permalink/1186139489308568/

Comment on lines +190 to +191
@common_utils.parametrize("b", [2, 3, 4, 8])
@common_utils.parametrize("group_size", [32, 512])
@lisjin (Contributor, Author):

Added parametrized test cases that loop over (int2, int3, int4, int8) data types.

lisjin requested a review from metascroy on April 23, 2025
lisjin force-pushed the parq branch 2 times, most recently from f6362f7 to fb7f521 on April 24, 2025
@@ -106,6 +108,28 @@ def quantize_(
        quants.copy_(Q)
        return q

    @torch.no_grad()
    def torchao_quantize_(self, model):
@lisjin (Contributor, Author):
@metascroy Here is a new optimizer-level method within PARQ that (1) loads the latent full-precision weights and (2) recursively quantizes all modules using torchao.quantize_. Just to double-check, will it also work on modules besides torch.nn.Linear?

@metascroy (Contributor):

Usually, to quantize non-linear modules, we need to pass a third filter_fn argument to quantize_, e.g., for embedding:

quantize_(model, IntxWeightOnlyConfig(torch.int4), lambda m, fqn: isinstance(m, torch.nn.Embedding))

By default, the filter function in quantize_ matches torch.nn.Linear, which is why you don't usually see it specified.

@andrewor14 (Contributor)

Hi @lisjin, do you mind adding a code snippet to the main README showing what the end-to-end flow would look like? My understanding is that you can just replace the LSBQuantizer there with your new UnifTorchaoQuantizer. Then what happens after training? Do we call quantize_(model, Int4WeightOnlyConfig) as before? It would be good to clarify.

@lisjin (Contributor, Author) commented Apr 25, 2025

@andrewor14 Thanks for the feedback—I removed config from UnifTorchaoQuantizer. In the README, I've also added a side-by-side comparison of PARQ vs. torchao prepare and convert steps. After PARQ training, we call optimizer.torchao_quantize_(model, config). Let me know if there's anything missing.
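For anyone skimming this thread, the overall prepare/convert shape is roughly as follows. This is a condensed sketch rather than a copy of the README example; the torchao.prototype.parq import paths, the quant_bits param-group key, and the constructor arguments reflect my reading of the PARQ prototype and may not match the final code exactly:

import torch
from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
from torchao.prototype.parq.quant import UnifTorchaoQuantizer
from torchao.quantization import IntxWeightOnlyConfig
from torchao.quantization.granularity import PerGroup

model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.Linear(256, 10))

# Prepare: wrap the base optimizer so weights are projected onto the quantization grid
# during training. The "quant_bits" key marks which param group gets quantized.
weights = [p for p in model.parameters() if p.dim() > 1]
others = [p for p in model.parameters() if p.dim() <= 1]
param_groups = [
    {"params": weights, "quant_bits": 4},
    {"params": others, "weight_decay": 0.0},
]
base_optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
optimizer = QuantOptimizer(base_optimizer, UnifTorchaoQuantizer(), ProxHardQuant())

# ... standard training loop: forward, loss.backward(), optimizer.step() ...

# Convert: restore latent full-precision weights and quantize via torchao.quantize_.
config = IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32))
optimizer.torchao_quantize_(model, config)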

@andrewor14 (Contributor)

Looks great, thanks @lisjin! The README is very clear.

One thing I want to discuss is whether we can just use a new PARQConfig instead, so the PARQ flow looks more like the existing torchao QAT flow. This is the current convert flow in the PR:

config = IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32))
optimizer.torchao_quantize_(model, config)

What do you think about something like this instead?

inner_config = IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32))
parq_config = PARQConfig(optimizer, inner_config)
quantize_(model, parq_config)

Also curious if @metascroy has any thoughts on this

Comment on lines +116 to +129
param_set = {
    p.data_ptr()
    for group in self.regularized_param_groups()
    for p in group["params"]
}

def inner_quantize_(model):
    for module in model.children():
        for param in module.parameters(recurse=False):
            if param.data_ptr() in param_set:
                quantize_(module, config)
                break

        inner_quantize_(module)
@lisjin (Contributor, Author):

@andrewor14 I think the quantize_(model, PARQConfig(optimizer, inner_config)) interface looks cleaner, but I see a potential issue with applying this inner_quantize_ function. PARQ uses param_set to filter which modules to quantize, whereas quantize_ has its own recursive traversal that selects modules based on a filter_fn mapping nn.Module -> bool. I'm not sure how to reconcile the two.

@andrewor14 (Contributor):

I see, yeah, it doesn't seem trivial to reconcile these two. Maybe this is OK then.

@andrewor14 (Contributor) commented Apr 25, 2025

Another idea is:

quantize_(model, config, optimizer.filter_fn)

where optimizer.filter_fn checks against the param_set. Then you don't need PARQConfig. What do you think?
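Concretely, optimizer.filter_fn could look something like the sketch below. It is purely illustrative: it reuses the param_set idea from the snippet above and assumes the (module, fqn) signature that quantize_ passes to its filter:

def filter_fn(self, module: torch.nn.Module, fqn: str) -> bool:
    # Illustrative only: True if any parameter owned directly by this module is one
    # of the optimizer's regularized (PARQ-tracked) parameters. In practice the
    # param_set would be built once and cached rather than rebuilt per call.
    param_set = {
        p.data_ptr()
        for group in self.regularized_param_groups()
        for p in group["params"]
    }
    return any(p.data_ptr() in param_set for p in module.parameters(recurse=False))

# The convert step would then be:
#   quantize_(model, config, optimizer.filter_fn)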

for module in model.children():
    for param in module.parameters(recurse=False):
        if param.data_ptr() in param_set:
            quantize_(module, config)
@andrewor14 (Contributor):

By the way, quantize_ itself is also recursive, so I think this will end up quantizing the same submodules multiple times. Maybe we should express this as part of a filter_fn, something like:

def _filter_fn(module: torch.nn.Module, fqn: str) -> bool:
    for param in module.parameters(recurse=False):
        if param.data_ptr() in param_set:
            return True
    return False

quantize_(model, config, _filter_fn)

@lisjin (Contributor, Author):

Good point! I'll add this change in

@andrewor14 (Contributor) left a review:

Looks good to me other than the recursion comment. @metascroy any other thoughts?

@metascroy (Contributor)

Looks good to me! Thanks @lisjin!

Can we add an end-to-end test_intx_weight_only_e2e for intx (with various x-values), similar to test_int4_weight_only_e2e?

    ) -> None:
        super().__init__(center=False)

        self.mapping_type = (
@metascroy (Contributor):
Nit: pass mapping type directly?
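i.e., something like the following (illustrative only; the SYMMETRIC default is an assumption, not taken from the PR):

def __init__(self, mapping_type: MappingType = MappingType.SYMMETRIC) -> None:
    super().__init__(center=False)
    self.mapping_type = mapping_type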

            _BIT_WIDTH_TO_DTYPE[b]
        ]
        if self.target_dtype is None:
            self.target_dtype = torch.int8
@metascroy (Contributor):
Just make torch.int8 the default in init?
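i.e., roughly (illustrative only; the surrounding constructor arguments are omitted):

def __init__(self, target_dtype: torch.dtype = torch.int8) -> None:
    self.target_dtype = target_dtype  # no later "if self.target_dtype is None" fallback needed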
