
Conversation


@jcaip jcaip commented Dec 8, 2025

This PR hooks up the static quant workflow added in #3442 to the prototype smoothquant API.

You can use the new flow as follows:

from torchao.quantization import PerRow
from torchao.quantization.quant_api import (
    Int8StaticActivationInt8WeightConfig,
    quantize_,
)
from torchao.prototype.smoothquant import (
    SmoothQuantConfig,
    SmoothQuantStep,
)

config = SmoothQuantConfig(
    base_config=Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
    step=SmoothQuantStep.PREPARE,
    alpha=0.5,
)

quantize_(model, config)

# Perform calibration with test data
model(*x)

config.step = SmoothQuantStep.CONVERT
quantize_(model, config)

# The model is now statically quantized using the scales gathered by the SmoothQuant observer.
model(*x)

@pytorch-bot

pytorch-bot bot commented Dec 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3468

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 2586ab6 with merge base f99105a:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 8, 2025
@jcaip jcaip added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Dec 8, 2025
@jcaip jcaip changed the title [wip] enable smoothquant for int8 static tensor enable smoothquant for int8 static tensor Dec 8, 2025
@jcaip jcaip marked this pull request as ready for review December 8, 2025 22:24
@jcaip jcaip requested a review from jerryzh168 December 8, 2025 22:29

jcaip commented Dec 8, 2025

cc @Xia-Weiwen and @cyxlily fyi

qw = quant_mod.weight

# Add smoothing factor metadata
qw = to_weight_tensor_with_linear_activation_scale_metadata(
Contributor

We should not be using this; please check AWQ for how this should be implemented in the new stack:

assert isinstance(qw, SupportsActivationPreScaling), (
"weight must support activation scaling through implementing `SupportsActivationPreScaling`"
)
# since we want to do `act` * `act_pre_scale` during runtime for speed, we'll save the
# reciprocal of the `equalization_scale`
qw.act_pre_scale = 1.0 / equalization_scale
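The comment in the snippet above explains that the reciprocal is stored so the runtime can multiply instead of divide. A minimal illustration (plain tensors, not torchao code) of why the two are equivalent:

```python
import torch

# SmoothQuant divides activations by the equalization scale, but a
# runtime multiply is cheaper, so the weight keeps 1/equalization_scale
# as act_pre_scale and runtime computes act * act_pre_scale instead.
act = torch.randn(4)
equalization_scale = torch.rand(4) + 0.5  # keep scales away from zero

act_pre_scale = 1.0 / equalization_scale
assert torch.allclose(act * act_pre_scale, act / equalization_scale)
```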

"""

scale: torch.Tensor
scale: torch.Tensor = None
Contributor

nit: Optional[torch.Tensor]

Contributor

also maybe static_scale might be more descriptive I feel
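A sketch combining both nits above: the field is typed `Optional` (it is `None` until calibration fills it in) and renamed `static_scale`. The class name here is invented for illustration and is not the actual torchao kwargs class:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class StaticQuantKwargsSketch:  # hypothetical name, for illustration only
    # "torch.Tensor" as a string forward reference so this sketch runs
    # without importing torch; the field stays None until calibration.
    static_scale: Optional["torch.Tensor"] = None
```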

@jcaip jcaip changed the base branch from jcaip/static-quant-rebased to main December 9, 2025 04:34
@meta-codesync

meta-codesync bot commented Dec 15, 2025

@jcaip has imported this pull request. If you are a Meta employee, you can view this in D88784212.


cyxlily commented Dec 17, 2025

@jcaip Our customer needs activation quantization PerTensor and weight quantization PerRow. Will you implement it, or may I create a new PR to do it?


jcaip commented Dec 17, 2025

@cyxlily feel free to open a new PR for activation per-tensor x weight per-row; it's not something I'm planning to do currently.

Thank you for your SmoothQuant PR, by the way; I used it to implement this.

@jcaip jcaip force-pushed the jcaip/enable-smoothquant branch from 0c23589 to f389a94 Compare December 17, 2025 23:56
Comment on lines 309 to 281
sqnr_static_compile
== sqnr_static_eager
== sqnr_dynamic_compile
== sqnr_dynamic_eager
Contributor

are we trying to say static_out_compile == static_out_eager == dynamic_out_compile == dynamic_out_eager here? if so I think it might be clearer just to assert all these are equal to each other
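The suggestion above can be sketched like this: assert the four outputs against one reference directly instead of comparing SQNR values. The variable names mirror the comment; the tensors here are placeholders, not real model outputs:

```python
import torch

def assert_all_same(outputs):
    # compare every output against the first one
    ref = outputs[0]
    for out in outputs[1:]:
        assert torch.equal(out, ref)

static_out_eager = torch.randn(2, 3)
static_out_compile = static_out_eager.clone()
dynamic_out_eager = static_out_eager.clone()
dynamic_out_compile = static_out_eager.clone()
assert_all_same(
    [static_out_eager, static_out_compile, dynamic_out_eager, dynamic_out_compile]
)
```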

else:
raise ValueError(f"Unexpected step: {step}")

if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
Contributor

@jerryzh168 jerryzh168 Dec 18, 2025


I think we shouldn't have a specific config here; maybe change this to a protocol similar to SupportsActivationPreScaling, but for configs?

Contributor Author

I think figuring out how to do this generally will need a bit more design; we'd need to figure out how to map to the appropriate QuantizeTensorToInt/FloatXKwargs object. Agree we should be able to do this though, but can I address it in a later PR?
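For reference, the protocol idea could be sketched with a structural `typing.Protocol`, so the dispatch checks for a capability rather than a concrete config class. Protocol and method names below are invented for illustration, not torchao APIs:

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsStaticActivationScale(Protocol):  # hypothetical protocol
    def set_static_act_scale(self, scale) -> None: ...

class FakeStaticConfig:
    """Stand-in for a config that supports static activation quant."""

    def __init__(self):
        self.act_scale = None

    def set_static_act_scale(self, scale) -> None:
        self.act_scale = scale

config = FakeStaticConfig()
# structural isinstance check: no concrete config class is named
assert isinstance(config, SupportsStaticActivationScale)
config.set_static_act_scale(0.02)
assert config.act_scale == 0.02
```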

@jcaip jcaip force-pushed the jcaip/enable-smoothquant branch from f389a94 to 2586ab6 Compare December 18, 2025 00:02
block_size,
self.dtype,
act_quant_kwargs=self.act_quant_kwargs,
act_scale=self.act_scale,
Contributor

I guess slice didn't work for static quant int8 before; can you add a test for that?
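A minimal illustration of what such a slice test needs to cover: slicing the quantized weight's data must slice the per-row weight scale consistently, while the static per-tensor activation scale passes through unchanged. Plain tensors stand in for the Int8Tensor fields shown in the diff:

```python
import torch

qdata = torch.randint(-128, 128, (8, 16), dtype=torch.int8)
weight_scale = torch.rand(8, 1)      # per-row weight scale
act_scale = torch.tensor(0.02)       # static per-tensor activation scale

# slicing rows 0:4 must keep qdata and weight_scale aligned
sliced_qdata = qdata[0:4]
sliced_weight_scale = weight_scale[0:4]

assert sliced_qdata.shape == (4, 16)
assert sliced_weight_scale.shape == (4, 1)
assert act_scale.ndim == 0  # per-tensor scale is unchanged by slicing
```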

old_int8_tensor.scale[index],
old_int8_tensor.block_size[1:],
old_int8_tensor.dtype,
old_int8_tensor.act_scale,
Contributor

Same for this one; it seems the select op was broken with static quant before.
