Hi, I've been studying this implementation of speculative decoding and I have a question about the fallback branch in the rejection sampling logic.
When a draft token is rejected, the code computes the residual distribution `max(0, target - draft)` and normalizes it. There's an `else` branch that handles the case where the sum `s == 0`.
My understanding is that `s == 0` would require `target_probs[i] <= draft_probs[i]` for every token in the vocabulary simultaneously, which is mathematically impossible if the two distributions differ: both sum to 1, so any token where the target is smaller must be offset by one where it is larger. And if they are identical, the acceptance ratio is always 1, so the token is never rejected in the first place and this branch is never reached.
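To make that argument concrete, here is a minimal sketch in exact arithmetic using Python's `fractions`. The helper names `residual_sum` and `acceptance_prob` and the example distributions are my own, not taken from the implementation, and I'm assuming the usual acceptance rule `min(1, target / draft)`.

```python
from fractions import Fraction

def residual_sum(target_probs, draft_probs):
    """Sum of max(0, target - draft) over the vocabulary (the `s` in my question)."""
    return sum(max(Fraction(0), t - d) for t, d in zip(target_probs, draft_probs))

def acceptance_prob(target_probs, draft_probs, token):
    """Assumed acceptance rule: min(1, target / draft); needs draft_probs[token] > 0."""
    return min(Fraction(1), target_probs[token] / draft_probs[token])

# Hypothetical 3-token vocabulary; both distributions sum to exactly 1.
target = [Fraction(1, 2), Fraction(1, 4), Fraction(1, 4)]
draft = [Fraction(1, 4), Fraction(1, 2), Fraction(1, 4)]

# Different distributions: the residual mass is strictly positive.
s_different = residual_sum(target, draft)

# Identical distributions: the residual mass is zero, but every token is
# accepted with probability 1, so the rejection path is never taken.
s_identical = residual_sum(target, target)
accept_identical = [acceptance_prob(target, target, i) for i in range(len(target))]

# When both distributions sum to 1, s is half the L1 distance between them,
# so s == 0 exactly when the distributions are equal.
half_l1 = sum(abs(t - d) for t, d in zip(target, draft)) / 2
```

In exact arithmetic, `s_different` equals `half_l1`, which is why I believe the sum can only vanish when the two distributions coincide.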
This makes me think the `else` is purely a defensive guard against floating-point numerical errors (e.g. in bfloat16 on GPU) rather than a case that can occur in the mathematical logic of the algorithm. Is that correct? Or is there a scenario I'm missing where this branch would actually be triggered?
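For what it's worth, this is the kind of floating-point situation I have in mind. The values below are hand-picked by me (multiples of 2^-9, the bfloat16 spacing between 0.25 and 0.5) to mimic probabilities that were rounded so that neither vector sums to exactly 1. They are an illustration under that assumption, not output from the actual implementation.

```python
# Hypothetical rounded probabilities; every value is exactly representable
# in binary floating point, so the arithmetic below is exact.
target_probs = [153 / 512, 153 / 512, 204 / 512]  # sums to 510/512, not 1
draft_probs = [154 / 512, 153 / 512, 204 / 512]   # sums to 511/512, not 1

drafted_token = 0

# Assumed acceptance rule min(1, target / draft): below 1 here, so the
# drafted token can be rejected.
accept_ratio = min(1.0, target_probs[drafted_token] / draft_probs[drafted_token])

# Residual distribution before normalization.
residual = [max(0.0, t - d) for t, d in zip(target_probs, draft_probs)]
s = sum(residual)

# target <= draft holds for every token, so s is exactly 0.0 and
# normalizing by s would divide by zero without a fallback.
needs_fallback = (s == 0.0)
```

Here the token can be rejected (`accept_ratio` is 153/154) while the residual has no mass at all, which only happens because the rounded vectors no longer sum to the same total.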
Thanks for putting this together; it's been a really valuable resource for understanding the mechanics of speculative decoding.