
Commit 667dfe7

Merge remote-tracking branch 'pytorch/main' into parq
2 parents 6130cc2 + cdced21


46 files changed, +671 -1775 lines changed

.github/workflows/build_wheels_linux.yml

+1 -1

@@ -28,7 +28,7 @@ jobs:
       os: linux
       with-cpu: enable
       with-cuda: enable
-      with-rocm: enable
+      with-rocm: disable
       with-xpu: enable
       # Note: if free-threaded python is required add py3.13t here
       python-versions: '["3.9"]'

.github/workflows/regression_test.yml

+15 -15

@@ -59,35 +59,35 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - name: CUDA 2.3
+          - name: CUDA 2.5.1
             runs-on: linux.g5.12xlarge.nvidia.gpu
-            torch-spec: 'torch==2.3.0'
+            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
             gpu-arch-type: "cuda"
-            gpu-arch-version: "12.1"
-          - name: CUDA 2.4
+            gpu-arch-version: "12.6"
+          - name: CUDA 2.6
             runs-on: linux.g5.12xlarge.nvidia.gpu
-            torch-spec: 'torch==2.4.0'
+            torch-spec: 'torch==2.6.0'
             gpu-arch-type: "cuda"
-            gpu-arch-version: "12.1"
-          - name: CUDA 2.5.1
+            gpu-arch-version: "12.6"
+          - name: CUDA 2.7
             runs-on: linux.g5.12xlarge.nvidia.gpu
-            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
+            torch-spec: 'torch==2.7.0'
             gpu-arch-type: "cuda"
-            gpu-arch-version: "12.1"
+            gpu-arch-version: "12.6"

-          - name: CPU 2.3
+          - name: CPU 2.5.1
             runs-on: linux.4xlarge
-            torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
+            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu'
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
-          - name: CPU 2.4
+          - name: CPU 2.6
             runs-on: linux.4xlarge
-            torch-spec: 'torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu'
+            torch-spec: 'torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu'
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
-          - name: CPU 2.5.1
+          - name: CPU 2.7
             runs-on: linux.4xlarge
-            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu'
+            torch-spec: 'torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu'
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
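Each torch-spec entry above is a pip requirement string plus optional index flags. As a rough illustration (an assumption about the install step, which is not part of this diff), the workflow effectively hands that string to pip:

# Sketch only: expanding a matrix torch-spec into a pip install call.
# The real install logic lives in the reusable test workflow, not shown here.
import shlex
import subprocess
import sys

torch_spec = "torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu"
subprocess.run(
    [sys.executable, "-m", "pip", "install", *shlex.split(torch_spec)],
    check=True,
)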

.pre-commit-config.yaml

+1 -1

@@ -11,7 +11,7 @@ repos:

   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.6.8
+    rev: v0.11.6
     hooks:
       # Run the linter.
       - id: ruff
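With the pinned ruff hook bumped to v0.11.6, the updated linter is picked up the next time the hooks run. A minimal sketch, assuming pre-commit is installed in the active environment:

# Re-run all configured pre-commit hooks (including the new ruff rev) over the repo.
import subprocess

subprocess.run(["pre-commit", "run", "--all-files"], check=True)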

benchmarks/float8/utils.py

+3

@@ -83,6 +83,9 @@ def profiler_output_to_filtered_time_by_kernel_name(
             continue
         elif e.key == "Activity Buffer Request":
             continue
+        elif e.key == "Unrecognized":
+            # TODO I think these are nvjet related
+            continue

         kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
     return kernel_name_to_gpu_time_us
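The helper above skips bookkeeping profiler entries (now also the "Unrecognized" key) before recording per-kernel GPU time. A standalone sketch of the same filtering pattern, assuming a torch.profiler profile object; it is an illustration, not the benchmark's actual helper:

# Minimal sketch: aggregate self GPU time per kernel, ignoring non-kernel entries.
def gpu_time_by_kernel_us(prof):
    skipped = {"Activity Buffer Request", "Unrecognized"}
    kernel_name_to_gpu_time_us = {}
    for e in prof.key_averages():
        if e.key in skipped:
            # profiler bookkeeping rows, not real device kernels
            continue
        kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
    return kernel_name_to_gpu_time_us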

test/dtypes/test_affine_quantized.py

+1

@@ -222,6 +222,7 @@ def apply_uint6_weight_only_quant(linear):

         deregister_aqt_quantized_linear_dispatch(dispatch_condition)

+    @skip_if_rocm("ROCm enablement in progress")
     @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_print_quantized_module(self):
         for device in self.GPU_DEVICES:

test/dtypes/test_nf4.py

+3 -3

@@ -39,7 +39,7 @@
     to_nf4,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

 bnb_available = False

@@ -119,7 +119,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):
     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
-        TORCH_VERSION_AT_LEAST_2_8, reason="Failing in CI"
+        TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
     )  # TODO: fix this
     @skip_if_rocm("ROCm enablement in progress")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@@ -146,7 +146,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @skip_if_rocm("ROCm enablement in progress")
     @unittest.skipIf(
-        TORCH_VERSION_AT_LEAST_2_8, reason="Failing in CI"
+        TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
     )  # TODO: fix this
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_nf4_bnb_linear(self, dtype: torch.dtype):
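The version gate here moves down from 2.8 to 2.7, so these bnb-dependent tests are now skipped on PyTorch 2.7+ as well. For reference, a flag like TORCH_VERSION_AT_LEAST_2_7 can be derived roughly as below (torchao.utils ships its own implementation; this sketch assumes the packaging library is available):

# Illustrative only; torchao.utils defines the real TORCH_VERSION_AT_LEAST_* flags.
import torch
from packaging.version import parse

TORCH_VERSION_AT_LEAST_2_7 = parse(torch.__version__).release >= (2, 7)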

test/prototype/mx_formats/test_mx_linear.py

+7 -1

@@ -28,14 +28,15 @@
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_7,
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )

 torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_8:
+if not TORCH_VERSION_AT_LEAST_2_7:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -222,6 +223,8 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

     if recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
+        if not TORCH_VERSION_AT_LEAST_2_8:
+            pytest.skip("torch.compile requires PyTorch 2.8+")
         if not is_sm_at_least_100():
             pytest.skip("CUDA capability >= 10.0 required for MX gemms")

@@ -308,6 +311,9 @@ def test_inference_linear(elem_dtype, bias, input_shape):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_inference_compile_simple(elem_dtype):
     """

test/prototype/mx_formats/test_mx_mm.py

+2 -2

@@ -10,9 +10,9 @@
 from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
 from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100

-if not TORCH_VERSION_AT_LEAST_2_8:
+if not TORCH_VERSION_AT_LEAST_2_7:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

test/prototype/scaled_grouped_mm/test_kernels.py

+3

@@ -28,8 +28,10 @@
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )
+from torchao.testing.utils import skip_if_rocm


+@skip_if_rocm("ROCm enablement in progress")
 @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
 def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
     # tests case where rowwise scales are computed for multiple distinct subtensors,
@@ -57,6 +59,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
     assert not _is_column_major(kernel_fp8_data), "fp8 data is not row major"


+@skip_if_rocm("ROCm enablement in progress")
 @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
 def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: bool):
     # tests case where colwise scales are computed for multiple distinct subtensors,

test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py

+2

@@ -29,8 +29,10 @@
 from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
     _scaled_grouped_mm,
 )
+from torchao.testing.utils import skip_if_rocm


+@skip_if_rocm("ROCm enablement in progress")
 def test_valid_scaled_grouped_mm_2d_3d():
     out_dtype = torch.bfloat16
     device = "cuda"

test/quantization/pt2e/test_duplicate_dq.py

+2 -2

@@ -26,10 +26,10 @@
     Quantizer,
     SharedQuantizationSpec,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import (
+from torchao.testing.pt2e._xnnpack_quantizer_utils import (
     OP_TO_ANNOTATOR,
     QuantizationConfig,
 )

test/quantization/pt2e/test_metadata_porting.py

+2 -2

@@ -16,10 +16,10 @@

 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation, Quantizer
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
+from torchao.testing.pt2e._xnnpack_quantizer_utils import OP_TO_ANNOTATOR
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

test/quantization/pt2e/test_numeric_debugger.py

+2 -2

@@ -24,7 +24,7 @@
 )
 from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
@@ -255,7 +255,7 @@ def test_prepare_for_propagation_comparison(self):
         ref = m(*example_inputs)
         res = m_logger(*example_inputs)

-        from torchao.quantization.pt2e.pt2e._numeric_debugger import OutputLogger
+        from torchao.quantization.pt2e._numeric_debugger import OutputLogger

         loggers = [m for m in m_logger.modules() if isinstance(m, OutputLogger)]
         self.assertEqual(len(loggers), 3)

test/quantization/pt2e/test_quantize_pt2e.py

+43 -10

@@ -57,11 +57,11 @@
 from torchao.quantization.pt2e.quantizer.embedding_quantizer import (  # noqa: F811
     EmbeddingQuantizer,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import (
+from torchao.testing.pt2e._xnnpack_quantizer_utils import (
     OP_TO_ANNOTATOR,
     QuantizationConfig,
 )
@@ -1328,6 +1328,40 @@ def validate(self, model: torch.fx.GraphModule) -> None:
         with self.assertRaises(Exception):
             m = prepare_pt2e(m, BackendAQuantizer())

+    def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
+        # resetting dynamo cache
+        torch._dynamo.reset()
+
+        m = export_for_training(
+            m,
+            example_inputs,
+        ).module()
+        if is_qat:
+            m = prepare_qat_pt2e(m, quantizer)
+        else:
+            m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        return m
+
+    def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=is_per_channel
+        )
+        quantizer.set_global(operator_config)
+        example_inputs = (torch.randn(2, 2),)
+        m = M().eval()
+        return self._quantize(m, quantizer, example_inputs)
+
     def test_fold_quantize(self):
         """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)"""
         m = self._get_pt2e_quantized_linear()
@@ -2493,10 +2527,10 @@ def check_nn_module(node):
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
 class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
     def test_channel_group_quantization(self):
-        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
-        from torchao.quantization.pt2e.pt2e._affine_quantization import (
+        from torchao.quantization.pt2e._affine_quantization import (
             AffineQuantizedMinMaxObserver,
         )
+        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken

         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -2576,14 +2610,14 @@ def forward(self, x):
     def test_dynamic_affine_act_per_channel_weights(self):
         import operator

+        from torchao.quantization.pt2e._affine_quantization import (
+            AffineQuantizedMovingAverageMinMaxObserver,
+        )
         from torchao.quantization.pt2e.observer import (
             MappingType,
             PerChannelMinMaxObserver,
             PerToken,
         )
-        from torchao.quantization.pt2e.pt2e._affine_quantization import (
-            AffineQuantizedMovingAverageMinMaxObserver,
-        )

         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -2667,13 +2701,12 @@ def forward(self, x):
     def test_dynamic_per_tok_act_per_group_weights(self):
         import operator

-        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
-
         # TODO: merge into torchao observer
-        from torchao.quantization.pt2e.pt2e._affine_quantization import (
+        from torchao.quantization.pt2e._affine_quantization import (
             AffineQuantizedMinMaxObserver,
             AffineQuantizedPlaceholderObserver,
         )
+        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken

         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
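The new _quantize and _get_pt2e_quantized_linear helpers capture the export, prepare, calibrate, convert flow used throughout this test file. A standalone sketch of the same flow outside the test class, using the relocated import paths; the toy module and inputs are placeholders, and PyTorch 2.7+ is assumed:

# Sketch of the PT2E flow mirrored by the _quantize helper above.
import torch
from torch.export import export_for_training

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.testing.pt2e._xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(2, 2),)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

m = export_for_training(TinyLinear().eval(), example_inputs).module()
m = prepare_pt2e(m, quantizer)  # insert observers
m(*example_inputs)  # calibrate on sample inputs
m = convert_pt2e(m)  # fold observers into quantized ops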

test/quantization/pt2e/test_quantize_pt2e_qat.py

+1 -1

@@ -47,7 +47,7 @@
     QuantizationSpec,
     Quantizer,
 )
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )

test/quantization/pt2e/test_representation.py

+1 -1

@@ -23,7 +23,7 @@

 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import Quantizer
-from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torchao.testing.pt2e._xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
