Fix issue with recurrent part of chunk-wise backward computations #10
When `chunk_size_inter` is less than `chunk_size_intra` (e.g. for `chunk_size` values greater than 128, the default chunk-size value in the heuristics), the recurrent state gradients (and therefore also the gradients w.r.t. K and V) were computed incorrectly.

It turns out that the backward kernel assumed that all recurrent states would be available. However, the forward pass stores only the states needed for the parallel part, which are fewer when `chunk_size_intra` is greater than `chunk_size_inter`. This fix indexes the available states correctly when computing the recurrent gradients.
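The indexing mismatch can be sketched with a toy recurrence: the forward pass stores states only at `chunk_size_intra` boundaries, so a backward pass that walks chunks of size `chunk_size_inter` must map each chunk start to the nearest stored boundary and replay the few missing steps. This is an illustrative NumPy sketch, not the actual kernel; the function names and the simple `S += k vᵀ` recurrence are assumptions made for illustration.

```python
import numpy as np

def forward_store_states(K, V, chunk_size_intra):
    """Run the toy recurrence S += outer(k, v) over T steps and store S
    only at intra-chunk boundaries (mimicking a forward pass that keeps
    just the states needed for the parallel part)."""
    T, _ = K.shape
    d = K.shape[1]
    S = np.zeros((d, d))
    stored = [S.copy()]  # state before step 0
    for t in range(T):
        S = S + np.outer(K[t], V[t])
        if (t + 1) % chunk_size_intra == 0:
            stored.append(S.copy())
    return stored

def state_before_chunk(stored, K, V, chunk_idx,
                       chunk_size_inter, chunk_size_intra):
    """Reconstruct the recurrent state at the start of inter-chunk
    `chunk_idx` from the coarser stored states: take the last stored
    boundary at or before the chunk start, then replay the steps in
    between. Indexing `stored` directly with `chunk_idx` (the bug this
    sketch illustrates) would read the wrong state whenever
    chunk_size_intra > chunk_size_inter."""
    start = chunk_idx * chunk_size_inter
    boundary = start // chunk_size_intra  # index into `stored`
    S = stored[boundary].copy()
    for t in range(boundary * chunk_size_intra, start):
        S = S + np.outer(K[t], V[t])
    return S
```

With `chunk_size_intra = 4` and `chunk_size_inter = 2`, only every second inter-chunk start coincides with a stored state; the other starts require replaying up to `chunk_size_intra - chunk_size_inter` steps from the preceding boundary.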