Hi torchao developers,
Recently, while experimenting with the w4a8 functionality in torchao, I noticed that the marlin_qqq check function requires `input_tensor.dtype == torch.float16`.
This seems potentially problematic: most modern models use bf16 or fp32 for activations, and forcing a cast to float16 can introduce precision loss or even NaN issues, since bf16 has a much wider dynamic range than float16 and large activations overflow when down-cast.
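To make the concern concrete, here is a small standalone sketch (not torchao code) showing how a bf16 activation that exceeds float16's representable range overflows to inf after the cast and then turns into NaN in later arithmetic:

```python
import torch

# Hypothetical illustration: bf16 shares fp32's 8-bit exponent range,
# while fp16 tops out around 65504, so large activations overflow on down-cast.
x_bf16 = torch.tensor([1.0e5, 3.0e4, 0.5], dtype=torch.bfloat16)

x_fp16 = x_bf16.to(torch.float16)
print(x_fp16)           # the 1.0e5 element overflows to inf in float16

# The inf then becomes NaN as soon as it hits e.g. inf - inf or 0 * inf.
print(x_fp16 - x_fp16)  # NaN in the overflowed position, 0 elsewhere
```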
Could you clarify whether this dtype check is strictly necessary? Are there specific kernel constraints or optimizations that depend on float16 here?
Thank you for your insights!