@@ -525,13 +525,23 @@ def test_nvfp4_matmul_with_amax(
525
525
)
526
526
527
527
528
- @pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
529
- @pytest .mark .skipif (
530
- not TORCH_VERSION_AT_LEAST_2_8 , reason = "NVFP4 requires PyTorch 2.8+"
531
- )
532
- def test_nvfp4_to_copy ():
533
- from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
534
-
535
- torch .ops .aten ._to_copy (
536
- NVFP4Tensor .to_nvfp4 (torch .randn ((32 , 128 ))), dtype = torch .bfloat16
537
- )
528
+ @pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
529
+ @pytest .mark .skipif (
530
+ not TORCH_VERSION_AT_LEAST_2_8 , reason = "NVFP4 requires PyTorch 2.8+"
531
+ )
532
+ def test_nvfp4_to_copy ():
533
+ from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
534
+
535
+ x = NVFP4Tensor .to_nvfp4 (torch .randn ((32 , 128 ))).cuda ()
536
+ y = torch .ops .aten ._to_copy (x , dtype = torch .bfloat16 )
537
+ assert torch .equal (x .qdata , y .qdata )
538
+ assert torch .equal (x ._scale_e4m3 , y ._scale_e4m3 )
539
+ assert x ._per_tensor_scale is None
540
+ assert y ._per_tensor_scale is None
541
+ assert x ._act_per_tensor_scale is None
542
+ assert y ._act_per_tensor_scale is None
543
+ assert x ._block_size == y ._block_size
544
+ assert x .use_triton_kernel == y .use_triton_kernel
545
+ assert x .act_quant_kwargs == y .act_quant_kwargs
546
+ assert x .dtype == torch .float32
547
+ assert y .dtype == torch .bfloat16
0 commit comments