File tree Expand file tree Collapse file tree 1 file changed +13
-3
lines changed
test/prototype/mx_formats Expand file tree Collapse file tree 1 file changed +13
-3
lines changed Original file line number Diff line number Diff line change @@ -532,6 +532,16 @@ def test_nvfp4_matmul_with_amax(
532
532
def test_nvfp4_to_copy ():
533
533
from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
534
534
535
- torch .ops .aten ._to_copy (
536
- NVFP4Tensor .to_nvfp4 (torch .randn ((32 , 128 ))), dtype = torch .bfloat16
537
- )
535
+ x = NVFP4Tensor .to_nvfp4 (torch .randn ((32 , 128 ))).cuda ()
536
+ y = torch .ops .aten ._to_copy (x , dtype = torch .bfloat16 )
537
+ assert torch .equal (x .qdata , y .qdata )
538
+ assert torch .equal (x ._scale_e4m3 , y ._scale_e4m3 )
539
+ assert x ._per_tensor_scale is None
540
+ assert y ._per_tensor_scale is None
541
+ assert x ._act_per_tensor_scale is None
542
+ assert y ._act_per_tensor_scale is None
543
+ assert x ._block_size == y ._block_size
544
+ assert x .use_triton_kernel == y .use_triton_kernel
545
+ assert x .act_quant_kwargs == y .act_quant_kwargs
546
+ assert x .dtype == torch .float32
547
+ assert y .dtype == torch .bfloat16
You can’t perform that action at this time.
0 commit comments