Make chunking for Triton kernels closer to cueq #209
Open
GMNGeoffrey wants to merge 1 commit into aqlaboratory:main
Conversation
The max chunk size when using the cuequivariance kernels is double the default (1024 instead of 512). Make the Triton kernels use the same. There are a few places where the already-tuned chunk size is divided by 4, but that subdivision is skipped when using the cuequivariance kernels. Make the Triton kernels skip it too. Rationale: the Triton kernels are very similar to the cuequivariance kernels, and both use a memory-efficient, flash-attention-style algorithm. I think it makes sense to default to having them use the same code paths for chunk size. Some of the chunking subdivisions had explicit comments about being there to work around errors specific to native PyTorch ops. We might want to further tune the chunking based on more detailed benchmarking, but assuming the Triton kernels behave similarly to the cuequivariance kernels seems like a better starting point than assuming they behave like native PyTorch.
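A minimal sketch of the chunk-size policy described above, assuming hypothetical names (`select_chunk_size`, `uses_optimized_kernel`) that are not taken from the codebase; in the real code this logic is spread across several call sites rather than living in one function:

```python
# Illustrative only: the 512/1024 caps and the divide-by-4 subdivision are the
# values described in this PR; the function and backend names are hypothetical.

def uses_optimized_kernel(attention_backend: str) -> bool:
    # Treat the Triton and cuequivariance backends identically, since both are
    # memory-efficient, flash-attention-style kernels.
    return attention_backend in ("triton", "cuequivariance")


def select_chunk_size(tuned_chunk_size: int, attention_backend: str) -> int:
    if uses_optimized_kernel(attention_backend):
        # Larger cap, no extra subdivision.
        return min(tuned_chunk_size, 1024)
    # Native PyTorch path: smaller cap, plus the divide-by-4 workaround that
    # some call sites apply to avoid errors in the native ops.
    return max(min(tuned_chunk_size, 512) // 4, 1)
```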
Contributor
Author
@christinaflo and @jnwei, since you were responding to my other issues about chunking :-) Thanks for taking a look.
Collaborator
I was going to mention this in the other PR: chunking is too aggressive for all these kernels, including DeepSpeed, so we should have a single variable for this as you mentioned, or just make it more configurable in general instead of a hardcoded param. I did notice that when modifying these settings for large sequences, cueq ran out of memory faster than the other kernels, which is why I lowered the max in this commit, but I don't know whether this is optimal. For the purposes of this PR we can just make these two match, but I'd like to tune this better in a future PR.
Summary
Make the chunking values used for Triton kernels the same as those used for cuequivariance kernels.
Rationale: the Triton kernels are very similar to the cuequivariance kernels, and both use a memory-efficient, flash-attention-style algorithm. I think it makes sense to default to having them use the same code paths for chunk size. Some of the chunking subdivisions had explicit comments about being there to work around errors specific to native PyTorch ops (I think I saw those comments when looking through the history; they aren't there now). We might want to further tune the chunking based on more detailed benchmarking, but assuming the Triton kernels behave similarly to the cuequivariance kernels seems like a better starting point than assuming they behave like native PyTorch. This also reduces the number of unique call signatures the attention kernel sees, which makes it easier to tune.
Changes
Related Issues
Sort of related to #203 and #206. It was noticing the weird calling pattern of chunks of 129 ((512+4)/4) that caused me to dive into the chunking logic in the first place.
Testing
Other Notes
We could have a single "optimized attention kernels" chunk size variable rather than having separate ones.
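As a rough illustration of that idea, a single shared setting could replace the per-backend values. This is a hypothetical sketch; the names (`AttentionChunkConfig`, `optimized_kernel_chunk_size`) do not exist in the repo, and the defaults simply reuse the 1024/512 values discussed above:

```python
# Hypothetical sketch of one shared chunk-size knob for all optimized attention
# kernels (Triton, cuequivariance, DeepSpeed); names and defaults are illustrative.
from dataclasses import dataclass


@dataclass
class AttentionChunkConfig:
    # Single value shared by every flash-attention-style backend.
    optimized_kernel_chunk_size: int = 1024
    # Separate, smaller value for the native PyTorch fallback.
    native_chunk_size: int = 512


def max_chunk_size(cfg: AttentionChunkConfig, attention_backend: str) -> int:
    if attention_backend in ("triton", "cuequivariance", "deepspeed"):
        return cfg.optimized_kernel_chunk_size
    return cfg.native_chunk_size
```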