Description
For testing #1684 (related to #1665) I have been comparing my QLoRA implementation side by side with bitsandbytes `Params4bit` vs. torchao `NF4Tensor` on single-GPU and DDP setups.

I noticed that the torchao script has nearly twice the runtime of the bitsandbytes implementation. Here is the main difference:
```python
import torch.nn as nn
from torchao.dtypes.nf4tensor import to_nf4, linear_nf4

weight = nn.Parameter(
    to_nf4(weight, block_size=64, scaler_block_size=256)
)
linear_nf4(input=x, weight=weight)
```
vs.
```python
import bitsandbytes as bnb
from bitsandbytes.nn import Params4bit

weight = Params4bit(
    data=weight,
    quant_type="nf4",
    blocksize=64,
)._quantize(device="cuda:0")
bnb.matmul_4bit(
    x,
    weight.t(),
    quant_state=weight.quant_state,
)
```
I ran a few simple timings in a quick script, and indeed the forward pass appears to be nearly twice as slow on simple input and weight tensors in `bfloat16`.
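A minimal sketch of the kind of timing comparison I ran (the shapes, iteration counts, and the `time_cuda` helper are illustrative, not my exact script):

```python
# Illustrative timing sketch; shapes and iteration counts are arbitrary.
import torch
import torch.nn as nn
import bitsandbytes as bnb
from bitsandbytes.nn import Params4bit
from torchao.dtypes.nf4tensor import to_nf4, linear_nf4

def time_cuda(fn, iters=100, warmup=10):
    # Average milliseconds per call, measured with CUDA events.
    with torch.no_grad():
        for _ in range(warmup):
            fn()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

ao_weight = nn.Parameter(to_nf4(w, block_size=64, scaler_block_size=256))
bnb_weight = Params4bit(data=w.cpu(), quant_type="nf4", blocksize=64)._quantize(device="cuda:0")

print("torchao forward:", time_cuda(lambda: linear_nf4(input=x, weight=ao_weight)), "ms")
print("bnb forward    :", time_cuda(lambda: bnb.matmul_4bit(x, bnb_weight.t(), quant_state=bnb_weight.quant_state)), "ms")
```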
Digging a bit deeper, the gap seems to come mainly from the dequantization step: in torchao, `weight.to(torch.bfloat16)` appears considerably slower than `bitsandbytes.functional.dequantize_4bit(weight, weight.quant_state, quant_type="nf4")`.
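Isolating the dequantization step with the same `time_cuda` helper and tensors as above (again illustrative, not my exact script):

```python
# Illustrative: time only the dequantization of each quantized weight.
import bitsandbytes.functional as F

print("torchao dequant:", time_cuda(lambda: ao_weight.to(torch.bfloat16)), "ms")
print("bnb dequant    :", time_cuda(lambda: F.dequantize_4bit(bnb_weight, bnb_weight.quant_state, quant_type="nf4")), "ms")
```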
I'm wondering whether I'm missing something needed for a fair comparison, or some functionality, or whether this is expected. Happy for any guidance. Thanks!