
Make chunking for Triton kernels closer to cueq #209

Open
GMNGeoffrey wants to merge 1 commit into aqlaboratory:main from GMNGeoffrey:no-chunk-subdivision

Conversation

@GMNGeoffrey
Contributor

Summary
Make the chunking values used for the Triton kernels the same as those used for the cuequivariance kernels.

Rationale: the Triton kernels are very similar to the cuequivariance kernels, and both use a memory-efficient flash-attention-style algorithm. I think it makes sense to default to having them use the same code paths for chunk size. Some of the chunking subdivisions had explicit comments about working around errors specific to native PyTorch ops (I think I saw those when looking through the history; they're not there now). We might want to further tune the chunking based on more detailed benchmarking, but assuming the Triton kernels behave like the cuequivariance kernels seems like a better starting point than assuming they behave like native PyTorch. This also reduces the number of unique call signatures the attention kernel sees, which makes it easier to tune.

Changes

  • The max chunk size when using the cuequivariance kernels is double the default (1024 instead of 512). Make the Triton kernels use the same cap.
  • There are a few places where the already tuned chunk size is divided by 4, but that subdivision is skipped when using the cuequivariance kernels. Skip it for the Triton kernels as well (see the sketch after this list).
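
For illustration only, here is a minimal sketch of the chunk-size selection described above; the function and constant names are hypothetical and the exact order of operations in the real code may differ:

```python
# Hypothetical constants mirroring the values discussed in this PR.
TORCH_MAX_CHUNK = 512  # previous cap on the Triton path (the default)
CUEQ_MAX_CHUNK = 1024  # cap already used on the cuequivariance path


def select_chunk_size(tuned_chunk: int, use_optimized_kernel: bool) -> int:
    """Sketch of the before/after behavior for the Triton path.

    Before this PR, the Triton path used the 512 cap and the already tuned
    chunk was further divided by 4. After this PR, it follows the
    cuequivariance path: a 1024 cap and no extra subdivision.
    """
    if use_optimized_kernel:  # cuequivariance, and Triton after this change
        return min(tuned_chunk, CUEQ_MAX_CHUNK)
    # Native PyTorch path: smaller cap plus the extra subdivision by 4.
    return max(1, min(tuned_chunk, TORCH_MAX_CHUNK) // 4)
```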

Related Issues
Sort of related to #203 and #206. Noticing the weird calling pattern of chunks of 129 ((512 + 4) / 4) is what caused me to dive into the chunking logic in the first place.

Testing

  • Did a quick spot check that this doesn't blow up memory usage (the test was based on top of Apply power-of-two chunking consistently #207 to split per sample). Using the largest inference example (the protein-ligand complex, which still has only 737 tokens), peak memory usage went from 7.2 GB to 8.2 GB.

Other Notes
We could have a single "optimized attention kernels" chunk size variable rather than having separate ones.
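
As a sketch of that idea (hypothetical names, not existing OpenFold identifiers), both optimized paths could read one shared constant:

```python
# Hypothetical single cap shared by the memory-efficient attention kernels.
OPTIMIZED_ATTN_MAX_CHUNK = 1024


def max_chunk_for(kernel: str) -> int:
    # Triton and cuequivariance share the same value; only the native
    # PyTorch fallback keeps its own, smaller cap.
    if kernel in ("triton", "cuequivariance"):
        return OPTIMIZED_ATTN_MAX_CHUNK
    return 512
```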

@GMNGeoffrey
Contributor Author

@christinaflo and @jnwei, since you were responding to my other issues about chunking :-) Thanks for taking a look!

@christinaflo
Collaborator

I was going to mention this in the other PR: chunking is too aggressive for all these kernels, including DeepSpeed, so we should have a single variable for this as you mentioned, or just make it more configurable in general instead of a hardcoded param. I did notice that when modifying these settings for large sequences, cueq ran out of memory faster than the other kernels, which is why I lowered the max in this commit, but I don't know if this is optimal. For the purposes of this PR we can just make these two match, but I'd like to tune this better in a future PR.
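
One possible shape for making this configurable rather than hardcoded (again a sketch with hypothetical names, not the existing OpenFold config):

```python
from dataclasses import dataclass


@dataclass
class AttentionChunkConfig:
    # User-tunable caps; the defaults mirror the values discussed here.
    optimized_kernel_max_chunk: int = 1024  # cueq / Triton / DeepSpeed
    torch_max_chunk: int = 512              # native PyTorch fallback


def resolve_max_chunk(cfg: AttentionChunkConfig, use_optimized_kernel: bool) -> int:
    return cfg.optimized_kernel_max_chunk if use_optimized_kernel else cfg.torch_max_chunk
```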

@jandom added the safe-to-test label (internal-only label used to indicate PRs that are ready for automated CI testing) on May 2, 2026