Skip to content

Commit 1ddb363

Browse files
authored
Update test_nvfp4_tensor.py
1 parent bb81efd commit 1ddb363

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -525,13 +525,23 @@ def test_nvfp4_matmul_with_amax(
525525
)
526526

527527

528-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
529-
@pytest.mark.skipif(
530-
not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
531-
)
532-
def test_nvfp4_to_copy():
533-
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
534-
535-
torch.ops.aten._to_copy(
536-
NVFP4Tensor.to_nvfp4(torch.randn((32, 128))), dtype=torch.bfloat16
537-
)
528+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
529+
@pytest.mark.skipif(
530+
not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
531+
)
532+
def test_nvfp4_to_copy():
533+
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
534+
535+
x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
536+
y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
537+
assert torch.equal(x.qdata, y.qdata)
538+
assert torch.equal(x._scale_e4m3, y._scale_e4m3)
539+
assert x._per_tensor_scale is None
540+
assert y._per_tensor_scale is None
541+
assert x._act_per_tensor_scale is None
542+
assert y._act_per_tensor_scale is None
543+
assert x._block_size == y._block_size
544+
assert x.use_triton_kernel == y.use_triton_kernel
545+
assert x.act_quant_kwargs == y.act_quant_kwargs
546+
assert x.dtype == torch.float32
547+
assert y.dtype == torch.bfloat16

0 commit comments

Comments
 (0)