MX-scale discrepancy during quantization and dequantization #1104

@mariosfourn

Description

In the case of very small input numbers, around the subnormal range of torch.float or torch.bfloat16, the scale exponent takes its smallest unbiased value, -127. However, quantization only allows division by a scale of 2**-126 (line 143 of mx_tensor.py) because of an incompatibility with Triton.
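To make the mismatch concrete, here is a toy numerical illustration (plain Python, not torchao's actual code paths): a value near 2**-127 gets quantized with the floored scale but dequantized with the raw one, so the round trip is off by a factor of 2.

```python
x = 2.0 ** -127                          # input near float32's subnormal range
raw_exp = -127                           # unbiased scale exponent derived from x
quant_scale = 2.0 ** max(raw_exp, -126)  # quantization floors the scale at 2**-126
dequant_scale = 2.0 ** raw_exp           # dequantization uses the raw 2**-127
q = x / quant_scale                      # quantized value: 0.5
print(q * quant_scale)                   # 2**-127, a faithful round trip
print(q * dequant_scale)                 # 2**-128, off by a factor of 2
```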

However, during dequantization the smaller scale of 2**-127 is used when calling

s_fp = get_fp_scale(scale_e8m0).reshape(-1, 1).to(target_dtype)

on line 235. Why not clip the exponent to -126 in the get_fp_scale function?
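A minimal sketch of what that clamp could look like, assuming an e8m0 scale with bias 127 (this is the proposed change, not torchao's current implementation of get_fp_scale):

```python
import torch

E8M0_BIAS = 127  # bias of the e8m0 scale format

def get_fp_scale(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Clamp the unbiased exponent to -126 so the dequantization scale
    # matches the 2**-126 floor that quantization already enforces.
    exponent = scale_e8m0.to(torch.int32) - E8M0_BIAS
    exponent = torch.clamp(exponent, min=-126)
    return torch.exp2(exponent.to(torch.float32))
```

With this clamp, both directions of the round trip use the same 2**-126 scale for inputs whose raw exponent would be -127.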
