Closed
Description
```
>>> from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
>>> import torch
>>> torch.ops.aten._to_copy(NVFP4Tensor.to_nvfp4(torch.randn((32, 128))), dtype=torch.bfloat16)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/andrewor/local/pytorch/torch/_ops.py", line 1254, in __call__
    return self._op(*args, **kwargs)
  File "/home/andrewor/local/ao/torchao/prototype/mx_formats/nvfp4_tensor.py", line 137, in __torch_dispatch__
    return NVFP4_OPS_TABLE[func](func, types, args, kwargs)
  File "/home/andrewor/local/ao/torchao/prototype/mx_formats/nvfp4_tensor.py", line 316, in nvfp4_to_copy
    tensor._data,
AttributeError: 'NVFP4Tensor' object has no attribute '_data'
```
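The traceback shows the error is raised inside a handler looked up from `NVFP4_OPS_TABLE`, so a stale attribute name in a registered handler only surfaces when that specific op is dispatched. Below is a minimal, torch-free sketch of that pattern; `OPS_TABLE`, `implements`, `ToyQuantTensor`, and `stale_to_copy` are hypothetical stand-ins, not torchao code.

```python
# Hypothetical, torch-free sketch of the ops-table dispatch pattern seen in the
# traceback. None of these names are torchao's real code.
OPS_TABLE = {}


def implements(op_name):
    """Register the decorated function as the handler for op_name."""
    def decorator(fn):
        OPS_TABLE[op_name] = fn
        return fn
    return decorator


class ToyQuantTensor:
    def __init__(self, qdata, scale):
        # The payload attribute is named `qdata`; there is no `_data`.
        self.qdata = qdata
        self.scale = scale

    def dispatch(self, op_name, *args, **kwargs):
        # Stand-in for __torch_dispatch__: look up the handler and call it.
        return OPS_TABLE[op_name](op_name, self, *args, **kwargs)


@implements("to_copy")
def stale_to_copy(op_name, tensor, dtype=None):
    # Still reads the old attribute name, so it fails only when dispatched.
    return ToyQuantTensor(tensor._data, tensor.scale)
```

Calling `ToyQuantTensor([1, 2], 0.5).dispatch("to_copy")` raises `AttributeError: 'ToyQuantTensor' object has no attribute '_data'`, the same shape of failure as above. Nothing flags the bad name at import time because the handler body is not executed until the op is actually dispatched.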
Seems like this should be `tensor.qdata`, and shouldn't it also be the first argument?
`ao/torchao/prototype/mx_formats/nvfp4_tensor.py`, lines 311 to 322 in 083361b:
```python
if dtype is not None:
    res = NVFP4Tensor(
        tensor._scale_e4m3,
        tensor._per_tensor_scale,
        tensor._act_per_tensor_scale,
        tensor._data,
        tensor._block_size,
        dtype,
        tensor._is_swizzled_scales,
        tensor.use_triton_kernel,
        tensor.act_quant_kwargs,
    )
```
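A sketch of what the corrected branch could look like, assuming the constructor takes the packed payload (`qdata`) first as suggested above. `ToyNVFP4Tensor` and `toy_to_copy` are hypothetical; their field names and order are assumptions for illustration, not the real `NVFP4Tensor` signature, and returning the input unchanged when `dtype` is `None` is a toy choice not taken from the quoted code.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyNVFP4Tensor:
    # Hypothetical field order: packed payload first, as the issue suggests.
    qdata: Any
    scale_e4m3: Any
    block_size: int
    orig_dtype: str
    per_tensor_scale: Optional[Any] = None
    act_per_tensor_scale: Optional[Any] = None
    is_swizzled_scales: bool = False
    use_triton_kernel: bool = False
    act_quant_kwargs: Optional[dict] = None


def toy_to_copy(tensor: ToyNVFP4Tensor, dtype: Optional[str] = None) -> ToyNVFP4Tensor:
    """Return a copy that differs only in orig_dtype (toy version of the fix)."""
    if dtype is None:
        # Toy behavior: nothing to change, so return the input as-is.
        return tensor
    # Keyword arguments keep the call correct even if the field order changes.
    return ToyNVFP4Tensor(
        qdata=tensor.qdata,
        scale_e4m3=tensor.scale_e4m3,
        block_size=tensor.block_size,
        orig_dtype=dtype,
        per_tensor_scale=tensor.per_tensor_scale,
        act_per_tensor_scale=tensor.act_per_tensor_scale,
        is_swizzled_scales=tensor.is_swizzled_scales,
        use_triton_kernel=tensor.use_triton_kernel,
        act_quant_kwargs=tensor.act_quant_kwargs,
    )
```

Passing keyword arguments (or, for a dataclass, `dataclasses.replace(tensor, orig_dtype=dtype)`) sidesteps the positional-order mistake in the quoted snippet, where `_data` lands in the fourth slot instead of the first.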