Skip to content

Commit f06db0d

Browse files
Update transformer_engine/pytorch/tensor/quantized_tensor.py
Co-authored-by: Tim Moon <[email protected]> Signed-off-by: ZhiyiDanielSu <[email protected]>
1 parent 2d315e6 commit f06db0d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

transformer_engine/pytorch/tensor/quantized_tensor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
427427
dst.quantize_(src)
428428
else:
429429
if isinstance(src, QuantizedTensor):
430-
src = src.dequantize(dtype=dst.dtype)
430+
dtype = dst.dtype
431+
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
432+
dtype = torch.float32
433+
src = src.dequantize(dtype=dtype)
431434
dst.copy_(src)
432435
return None
433436

0 commit comments

Comments
 (0)