TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Lines 1030 to 1037 in a662411
Is there a specific reason why `torch.bfloat16` is not included in the `allowed_casts` set within the `to_copy_dtype_validator` function?
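In addition, this causes graph partitioning whenever a `torch.ops.aten._to_copy` operation casts to `torch.bfloat16`: the validator rejects the node, so it falls back to PyTorch and splits the surrounding graph. I'm wondering whether this could hurt performance.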
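To illustrate why a rejected cast splits the graph, here is a toy model in plain Python. It is not Torch-TensorRT's partitioner. Dtypes are plain strings, the contents of both allowed sets are hypothetical, and `is_cast_supported` and `partition` are made-up helpers. It only shows the mechanism: when the validator rejects a node, that node runs in PyTorch and the TensorRT-supported ops around it end up in separate segments.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Hypothetical allowed sets; the real allowed_casts holds torch.dtype objects
# and its exact contents are not reproduced here.
ALLOWED_CASTS_WITHOUT_BF16 = frozenset({"float32", "float16", "int32", "bool"})
ALLOWED_CASTS_WITH_BF16 = ALLOWED_CASTS_WITHOUT_BF16 | {"bfloat16"}


@dataclass(frozen=True)
class Node:
    op: str
    target_dtype: Optional[str] = None  # only meaningful for "_to_copy"


def is_cast_supported(node: Node, allowed_casts: frozenset) -> bool:
    """Hypothetical validator: accept a _to_copy node only if its target dtype is allowed."""
    if node.op != "_to_copy":
        return True  # in this toy model every other op is assumed convertible
    return node.target_dtype in allowed_casts


def partition(nodes: List[Node], allowed_casts: frozenset) -> List[Tuple[str, List[str]]]:
    """Group consecutive nodes into (backend, [ops]) segments."""
    segments: List[Tuple[str, List[str]]] = []
    for node in nodes:
        backend = "tensorrt" if is_cast_supported(node, allowed_casts) else "pytorch"
        if segments and segments[-1][0] == backend:
            segments[-1][1].append(node.op)  # extend the current segment
        else:
            segments.append((backend, [node.op]))  # backend changed: start a new segment
    return segments


graph = [Node("mm"), Node("_to_copy", "bfloat16"), Node("relu")]
print(partition(graph, ALLOWED_CASTS_WITHOUT_BF16))
print(partition(graph, ALLOWED_CASTS_WITH_BF16))
```

Without `bfloat16` in the set, the toy graph is split into three segments (TensorRT, PyTorch, TensorRT); with it, the whole graph stays in one TensorRT segment. Each extra boundary means a handoff between runtimes and no optimization across it, which is why the performance question seems worth asking.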