From bb81efd4b99e04af604fc415c2ccdbaad61d4552 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 19 Aug 2025 16:50:04 -0700 Subject: [PATCH 1/2] Fix NVFP4 to_copy **Summary:** Fixes https://github.com/pytorch/ao/issues/2811 **Test Plan:** ``` pytest test/prototype/mx_formats/test_nvfp4_tensor.py -k to_copy ``` --- test/prototype/mx_formats/test_nvfp4_tensor.py | 12 ++++++++++++ torchao/prototype/mx_formats/nvfp4_tensor.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py index 3712d8929b..a77265a358 100644 --- a/test/prototype/mx_formats/test_nvfp4_tensor.py +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -523,3 +523,15 @@ def test_nvfp4_matmul_with_amax( assert sqnr >= SQNR_THRESHOLD, ( f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}" ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_to_copy(): + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + torch.ops.aten._to_copy( + NVFP4Tensor.to_nvfp4(torch.randn((32, 128))), dtype=torch.bfloat16 + ) diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index fa4e7dc1c3..e364772f3a 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -310,10 +310,10 @@ def nvfp4_to_copy(func, types, args, kwargs): if dtype is not None: res = NVFP4Tensor( + tensor.qdata, tensor._scale_e4m3, tensor._per_tensor_scale, tensor._act_per_tensor_scale, - tensor._data, tensor._block_size, dtype, tensor._is_swizzled_scales, From fb5e41c4f492ea1e06841172f6231525d8c4c29f Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Thu, 21 Aug 2025 11:59:05 -0400 Subject: [PATCH 2/2] Update test_nvfp4_tensor.py --- test/prototype/mx_formats/test_nvfp4_tensor.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py index a77265a358..219d0fe88e 100644 --- a/test/prototype/mx_formats/test_nvfp4_tensor.py +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -532,6 +532,16 @@ def test_nvfp4_matmul_with_amax( def test_nvfp4_to_copy(): from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor - torch.ops.aten._to_copy( - NVFP4Tensor.to_nvfp4(torch.randn((32, 128))), dtype=torch.bfloat16 - ) + x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda() + y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16) + assert torch.equal(x.qdata, y.qdata) + assert torch.equal(x._scale_e4m3, y._scale_e4m3) + assert x._per_tensor_scale is None + assert y._per_tensor_scale is None + assert x._act_per_tensor_scale is None + assert y._act_per_tensor_scale is None + assert x._block_size == y._block_size + assert x.use_triton_kernel == y.use_triton_kernel + assert x.act_quant_kwargs == y.act_quant_kwargs + assert x.dtype == torch.float32 + assert y.dtype == torch.bfloat16