
Question about dtype check in marlin_qqq validation for w4a8 functionality #2115

Open
@xxw11

Description


Hi torchao developers,

Recently, while experimenting with the w4a8 functionality in torchao, I noticed that the marlin_qqq check function requires

`input_tensor.dtype == torch.float16`

This seems potentially problematic, as most modern models use bf16 or fp32 for activations. Forcing a downcast to float16 can introduce precision loss, or even overflow to inf/NaN, since float16 has a much narrower dynamic range than bf16 or fp32.
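To illustrate the overflow concern: float16's largest finite value is 65504, while bf16 and fp32 can represent far larger magnitudes, so a forced downcast can silently produce inf (and NaN in subsequent arithmetic). A minimal sketch using numpy (which lacks bfloat16, so fp32 stands in for the wider source dtype):

```python
import numpy as np

# A value that is finite in fp32 (and representable in bf16),
# but exceeds float16's maximum finite value of 65504.
x = np.float32(70000.0)

# Forced downcast, as a float16-only kernel path would require.
y = np.float16(x)

print(np.isinf(y))   # the value overflowed to inf
print(y - y)         # inf - inf produces nan in later ops
```

The same effect occurs in PyTorch with `tensor.to(torch.float16)` on bf16/fp32 activations whose magnitudes exceed the float16 range.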

Could you clarify if this dtype check is strictly necessary? Are there specific constraints or optimizations that depend on float16 here?

Thank you for your insights!

