Skip to content

Commit fb5e41c

Browse files
committed
Update test_nvfp4_tensor.py
1 parent bb81efd commit fb5e41c

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,16 @@ def test_nvfp4_matmul_with_amax(
532532
def test_nvfp4_to_copy():
533533
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
534534

535-
torch.ops.aten._to_copy(
536-
NVFP4Tensor.to_nvfp4(torch.randn((32, 128))), dtype=torch.bfloat16
537-
)
535+
x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
536+
y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
537+
assert torch.equal(x.qdata, y.qdata)
538+
assert torch.equal(x._scale_e4m3, y._scale_e4m3)
539+
assert x._per_tensor_scale is None
540+
assert y._per_tensor_scale is None
541+
assert x._act_per_tensor_scale is None
542+
assert y._act_per_tensor_scale is None
543+
assert x._block_size == y._block_size
544+
assert x.use_triton_kernel == y.use_triton_kernel
545+
assert x.act_quant_kwargs == y.act_quant_kwargs
546+
assert x.dtype == torch.float32
547+
assert y.dtype == torch.bfloat16

0 commit comments

Comments
 (0)