Skip to content

Conversation

hoedt
Copy link

@hoedt hoedt commented Jun 10, 2025

When chunk_size_inter is less than chunk_size_intra (e.g. for chunk_size values greater than 128 --- the default chunk-size value in the heuristics), the recurrent state gradients (and therefore also gradients w.r.t. K and V) were computed incorrectly.

It turns out that the backward kernel assumed that all recurrent states would be available. However, in the forward pass, only the states that are necessary for the parallel part (which are fewer if chunk_size_intra is greater than chunk_size_inter) are stored. This fix simply uses the available states correctly to compute the recurrent gradients.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant