17 changes: 6 additions & 11 deletions README.md
@@ -188,11 +188,15 @@ We also support deployment to edge devices through ExecuTorch, for more detail,
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):

```python
-from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+import torch
+from torchao.quantization import quantize_, Int8DynamicActivationIntxWeightConfig, PerGroup
 from torchao.quantization.qat import QATConfig
 
 # prepare
-base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+base_config = Int8DynamicActivationIntxWeightConfig(
+    weight_dtype=torch.int4,
+    weight_granularity=PerGroup(32),
+)
 quantize_(my_model, QATConfig(base_config, step="prepare"))
 
 # train model (not shown)
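# convert (editor's sketch, not part of this PR's diff): after training, QAT is
# typically finalized with the same config and step="convert"; see the QAT
# README linked above for the exact flow
quantize_(my_model, QATConfig(base_config, step="convert"))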
@@ -266,15 +270,6 @@ The best example we have combining the composability of lower bit dtype with com

Our framework makes it straightforward to add tensor parallel support to your custom quantized tensor subclass. Check out our [tensor parallel tutorial](tutorials/developer_api_guide/tensor_parallel.py) to see how a quantized tensor subclass can be extended to support column and row-wise tensor sharding while maintaining compatibility with `torch.compile`.
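For a rough sense of what the tutorial covers, the sketch below shows column- vs row-wise sharding of a linear weight with DTensor. This is a minimal illustration, not the tutorial itself: the plain float weight, CPU mesh, and `torchrun` launch are placeholder assumptions, and a real quantized tensor subclass would shard its packed data in the same way.

```python
# Minimal sketch: column- vs row-wise sharding of a [out_features, in_features]
# weight with DTensor (requires a recent PyTorch with torch.distributed.tensor).
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# assumes the script is launched with torchrun so rank/world-size env vars exist
dist.init_process_group(backend="gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

weight = torch.randn(1024, 1024)
col_sharded = distribute_tensor(weight, mesh, [Shard(0)])  # split the output dim
row_sharded = distribute_tensor(weight, mesh, [Shard(1)])  # split the input dim
print(col_sharded.to_local().shape, row_sharded.to_local().shape)
```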

### Custom Kernels
**Contributor Author** commented: this is no longer relevant


We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow:

1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference

If you believe there are other CUDA kernels we should take a closer look at, please leave a comment on [this issue](https://github.com/pytorch/ao/issues/697) or feel free to contribute directly to the repo.
-->

## 🔗 Integrations
4 changes: 1 addition & 3 deletions test/dtypes/test_affine_quantized_float.py
@@ -274,9 +274,7 @@ def test_fp8_weight_dimension_warning(self):
         model = ToyLinearModel(10, 25).to(_DEVICE) # 10x25 and 25x10 weights
 
         # Set up logging capture
-        with self.assertLogs(
-            "torchao.quantization.quant_api", level="INFO"
-        ) as log_context:
+        with self.assertLogs("torchao.quantization.utils", level="INFO") as log_context:
             quantize_(
                 model,
                 Float8DynamicActivationFloat8WeightConfig(
4 changes: 1 addition & 3 deletions test/quantization/test_quant_api.py
@@ -834,9 +834,7 @@ def test_config_deprecation(self):
             self.assertTrue(len(_warnings) == 1)
             found_deprecated = False
             for w in _warnings:
-                if "will be moving to prototype in a future release" in str(
-                    w.message
-                ):
+                if "will be deleted in a future release" in str(w.message):
                     found_deprecated = True
             self.assertTrue(
                 found_deprecated, f"did not find deprecated warning for {cls}"