
Commit 3322d21

add block-wise scaled int8 quantization based on the QuantizedLayout mechanism

- add more tests comparing against a manual torch implementation
- add perf benchmarks
- fix errors caused by merging
- default no output quant
- fix unittest

1 parent dd41b74 commit 3322d21

File tree

12 files changed: +4703, -36 lines


QUANTIZATION.md

Lines changed: 10 additions & 4 deletions

@@ -124,24 +124,30 @@ We define 4 possible scaling parameters that should cover most recipes in the ne
 | Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
 |--------|---------------|--------------|----------------|-----------------|-------------|
 | float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
+| int8_blockwise | int8 | float32 (per-block) | - | - | - |
+
+For int8_blockwise with block_size=128 and weight shape (N, K):
+- weight_scale shape: (N//128, K//128)
 
 You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
 
 ### Quantization Metadata
 
 The metadata stored alongside the checkpoint contains:
 - **format_version**: String to define a version of the standard
-- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
+- **layers**: A dictionary mapping layer names to their quantization configuration. Each layer's config is a dictionary with:
+  - **format**: Quantization format string that maps to the definitions found in `QUANT_ALGOS`
+  - **group_size** (optional): Block size for block-wise quantization schemes (e.g., int8_blockwise)
 
 Example:
 ```json
 {
   "_quantization_metadata": {
     "format_version": "1.0",
     "layers": {
-      "model.layers.0.mlp.up_proj": "float8_e4m3fn",
-      "model.layers.0.mlp.down_proj": "float8_e4m3fn",
-      "model.layers.1.mlp.up_proj": "float8_e4m3fn"
+      "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
+      "model.layers.0.mlp.down_proj": {"format": "int8_blockwise", "group_size": 128},
+      "model.layers.1.mlp.up_proj": {"format": "int8_blockwise", "group_size": 256}
     }
   }
 }
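
To make the new table row concrete, here is a minimal sketch of the block-wise scheme, assuming symmetric absmax scaling per 128x128 tile. The actual recipe is whatever `QUANT_ALGOS` defines in `comfy/quant_ops.py`; `quantize_int8_blockwise` below is a hypothetical helper, not the repo's API.

```python
import torch

def quantize_int8_blockwise(weight: torch.Tensor, block_size: int = 128):
    # Symmetric absmax quantization per (block_size x block_size) tile.
    # For a weight of shape (N, K) this yields weight_scale of shape
    # (N // block_size, K // block_size), matching the shape documented above.
    # Assumes N and K are divisible by block_size.
    N, K = weight.shape
    tiles = weight.reshape(N // block_size, block_size, K // block_size, block_size)
    # One scale per tile, chosen so the largest magnitude maps to 127.
    scale = tiles.abs().amax(dim=(1, 3)).float().clamp(min=1e-12) / 127.0
    q = (tiles / scale[:, None, :, None]).round().clamp(-128, 127).to(torch.int8)
    return q.reshape(N, K), scale

w = torch.randn(256, 512)
q, weight_scale = quantize_int8_blockwise(w)
print(weight_scale.shape)  # torch.Size([2, 4]) == (256 // 128, 512 // 128)
```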

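Since each `layers` entry is now a config dict rather than a bare format string, a consumer has to unpack it. A hedged sketch of reading the documented layout; `layer_quant_config` is a made-up helper, not part of the codebase:

```python
import json

def layer_quant_config(metadata: dict, layer_name: str):
    # Look up a layer's quantization config in the documented metadata layout.
    layers = metadata["_quantization_metadata"]["layers"]
    cfg = layers.get(layer_name)
    if cfg is None:
        return None  # layer is not quantized
    fmt = cfg["format"]                 # e.g. "int8_blockwise", keys into QUANT_ALGOS
    group_size = cfg.get("group_size")  # e.g. 128; absent for scalar-scaled formats
    return fmt, group_size

metadata = json.loads("""
{
  "_quantization_metadata": {
    "format_version": "1.0",
    "layers": {
      "model.layers.0.mlp.down_proj": {"format": "int8_blockwise", "group_size": 128}
    }
  }
}
""")
print(layer_quant_config(metadata, "model.layers.0.mlp.down_proj"))
# ('int8_blockwise', 128)
```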
comfy/float.py

Lines changed: 2 additions & 0 deletions

@@ -54,6 +54,8 @@ def stochastic_rounding(value, dtype, seed=0):
         return value.to(dtype=torch.float16)
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
+    if dtype == torch.int8:
+        return value.to(dtype=torch.int8)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
         generator = torch.Generator(device=value.device)
         generator.manual_seed(seed)
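
Worth noting: the new int8 branch is a plain cast (PyTorch truncates toward zero on float-to-int conversion), so no stochastic rounding is applied for int8; only the float8 path below it uses the seeded generator. A small usage sketch, assuming the repo root is importable:

```python
import torch
from comfy.float import stochastic_rounding  # defined in comfy/float.py shown above

w = torch.randn(4, 4)
q8 = stochastic_rounding(w, torch.int8)  # hits the new branch: a plain .to(torch.int8) cast
f8 = stochastic_rounding(w, torch.float8_e4m3fn, seed=0)  # seeded float8 path further down
print(q8.dtype)  # torch.int8
```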
