Commit f3e549c: Add NVFP4 QAT (#2666)
* [bc-breaking] Generalize FakeQuantizeConfig beyond intx
**Summary:** The existing `FakeQuantizeConfig` performs only
intx quantization, but we plan to extend QAT to other dtypes
such as fp8 and nvfp4 in the near future. This commit is the
necessary refactor before that work. Specifically:
```
# New abstract class
FakeQuantizeConfigBase
# Rename
FakeQuantizeConfig -> IntxFakeQuantizeConfig
```
In the future, we will have other types of `FakeQuantizeConfigBase`
for float dtypes that users can pass in instead of the existing
Intx one.
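For illustration, here is a minimal sketch of the resulting
hierarchy. The class bodies are elided and the internals are
assumptions; the alias at the end reflects the BC note below:
```
# Minimal sketch of the refactored config hierarchy (internals assumed).
import abc

class FakeQuantizeConfigBase(abc.ABC):
    """Abstract base for all fake quantization configs (intx, fp8, nvfp4, ...)."""

class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
    """The renamed intx config, formerly `FakeQuantizeConfig`."""

# For BC, the old name is kept as a reference to the new class.
FakeQuantizeConfig = IntxFakeQuantizeConfig
```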
**BC-breaking notes:** For BC, we keep the old names around as
references to the new ones. However, this commit is still
BC-breaking in that a few APIs now accept the abstract
`FakeQuantizeConfigBase` instead. For the most part, this
abstract class will be hidden from the user.
Before:
```
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
```
After:
```
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
```
**Test Plan:**
python test/quantization/test_qat.py
* New multi-step QAT API
**Summary:** This commit adds a new multi-step QAT API with the
main goal of simplifying the existing UX. The new API uses the
same `QATConfig` for both the prepare and convert steps, and
automatically infers the fake quantization configs based on
a PTQ base config provided by the user:
```
from torchao.quantization import (
quantize_,
Int8DynamicActivationInt4WeightConfig
)
from torchao.quantization.qat import QATConfig
# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)
# train (not shown)
# convert
quantize_(m, QATConfig(base_config, step="convert"))
```
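Conceptually, the prepare step derives the fake quantization
configs from the PTQ base config. Below is a hedged sketch of
that mapping; the helper name is hypothetical, and the
int8/int4 pairing follows the examples in this commit:
```
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import IntxFakeQuantizeConfig

# Hypothetical sketch of the inference QATConfig performs during prepare.
def _infer_fake_quantize_configs(base_config):
    if isinstance(base_config, Int8DynamicActivationInt4WeightConfig):
        act = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
        weight = IntxFakeQuantizeConfig(torch.int4, group_size=base_config.group_size)
        return act, weight
    raise ValueError(f"No QAT mapping for {type(base_config).__name__}")
```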
The main improvements include:
- A single config for both the prepare and convert steps
- A single `quantize_` call for convert (instead of two)
- No chance of incompatible prepare vs convert configs
- Much less boilerplate code for the most common use case
- Simpler config names
For less common use cases such as experimentation, users can
still specify arbitrary fake quantization configs for
activations and/or weights as before. This is still important
since there may not always be a corresponding PTQ base config.
For example:
```
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
activation_config=activation_config,
weight_config=weight_config,
step="prepare",
)
quantize_(model, qat_config)
# train and convert same as above (not shown)
```
**BC-breaking notes:** This change by itself is technically not
BC-breaking since we keep the old path around, but it will
become BC-breaking when we deprecate and remove that path in
the future.
Before:
```
# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)
# train (not shown)
# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```
After: (see above)
**Test Plan:**
```
python test/quantization/test_qat.py
```
* Deprecate old QAT APIs
**Summary:** Deprecates QAT APIs that should no longer be used.
Prints a helpful deprecation warning to help users migrate.
**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_api_deprecation
```
Also manual testing:
```
>>> from torchao.quantization.qat import IntXQuantizationAwareTrainingConfig
>>> IntXQuantizationAwareTrainingConfig()
'IntXQuantizationAwareTrainingConfig' is deprecated and will be removed in a future release. Please use the following API instead:
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))
# train (not shown)
quantize_(model, QATConfig(base_config, step="convert"))
Alternatively, if you prefer to pass in fake quantization configs:
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
activation_config=activation_config,
weight_config=weight_config,
step="prepare",
)
quantize_(model, qat_config)
Please see #2630 for more details.
IntXQuantizationAwareTrainingConfig(activation_config=None, weight_config=None)
```
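For reference, a minimal sketch of how such a construction-time
warning could be wired up; the helper, message text, and use of
`warnings` are assumptions, not the actual torchao implementation:
```
# Hypothetical sketch: warn when a deprecated QAT config is constructed.
import warnings

_DEPRECATION_MSG = (
    "'{name}' is deprecated and will be removed in a future release. "
    "Please use `QATConfig` with a PTQ base config instead. "
    "Please see #2630 for more details."
)

class IntXQuantizationAwareTrainingConfig:
    def __init__(self, activation_config=None, weight_config=None):
        warnings.warn(
            _DEPRECATION_MSG.format(name=type(self).__name__),
            DeprecationWarning,
            stacklevel=2,
        )
        self.activation_config = activation_config
        self.weight_config = weight_config
```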
* Add NVFP4 QAT
**Summary:** This commit adds a QAT flow for NVFP4, closely
following the numerics in `NVFP4Tensor` but without the dtype
casting, swizzling, and packing/unpacking. Users can call this flow as follows:
```
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig
qat_config = QATConfig(
activation_config=NVFP4FakeQuantizeConfig(),
weight_config=NVFP4FakeQuantizeConfig(),
step="prepare",
)
quantize_(model, qat_config)
```
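To make "the numerics without the packing" concrete, here is a
hedged sketch of NVFP4-style fake quantization: per-16-element
blocks, an e2m1 value grid, and e4m3-rounded block scales. This
is an illustration under stated assumptions, not the actual
`NVFP4Tensor` code; a real QAT flow would also need a
straight-through estimator for gradients:
```
# Hedged sketch of NVFP4-style fake quantization (details assumed).
import torch

# Representable fp4 (e2m1) magnitudes.
E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_fake_quantize(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Assumes x.numel() is divisible by block_size; compute in fp32 for safety.
    blocks = x.reshape(-1, block_size).float()
    # Per-block scale so the block max maps to the largest e2m1 value (6.0).
    scale = blocks.abs().amax(dim=-1, keepdim=True) / 6.0
    # NVFP4 stores block scales in fp8 (e4m3); emulate that rounding.
    scale = scale.to(torch.float8_e4m3fn).float().clamp(min=1e-12)
    scaled = blocks / scale
    # Snap each element to the nearest representable e2m1 magnitude.
    grid = E2M1_GRID.to(scaled.device)
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    fq = grid[idx] * scaled.sign()
    return (fq * scale).reshape(x.shape).to(x.dtype)
```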
**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
```
Initial benchmarks on fine-tuning Qwen3-1.7B on oasst1 for 3 epochs:
```
# Without QAT
| Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext| 2|none |None |bits_per_byte |↓ | 0.7927|± | N/A|
| | |none |None |byte_perplexity|↓ | 1.7323|± | N/A|
| | |none |None |word_perplexity|↓ |18.8815|± | N/A|
# With QAT
| Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext| 2|none |None |bits_per_byte |↓ | 0.7921|± | N/A|
| | |none |None |byte_perplexity|↓ | 1.7316|± | N/A|
| | |none |None |word_perplexity|↓ |18.8409|± | N/A|
```
A second run, fine-tuning Qwen3-1.7B on alpaca for 3 epochs:
```
# Without QAT
| Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext| 2|none |None |bits_per_byte |↓ | 0.8322|± | N/A|
| | |none |None |byte_perplexity|↓ | 1.7804|± | N/A|
| | |none |None |word_perplexity|↓ |21.8611|± | N/A|
# With QAT
| Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext| 2|none |None |bits_per_byte |↓ | 0.8271|± | N/A|
| | |none |None |byte_perplexity|↓ | 1.7741|± | N/A|
| | |none |None |word_perplexity|↓ |21.4467|± | N/A|
```
File tree: 7 files changed, +194 -18 lines, under test/quantization, torchao/prototype/mx_formats, torchao/prototype/qat, and torchao/quantization/qat.