Apply power-of-two chunking consistently #207
Open
GMNGeoffrey wants to merge 1 commit into aqlaboratory:main from
Conversation
- Avoid the odd addition of 4 to power-of-two chunk sizes. This was added in aqlaboratory@a9a12890d without explanation. One hypothesis is that it was related to adding 4 to an input dimension in trace_utils.py (perhaps to make a test case fit in one chunk?), but that file was deleted long ago. This just looks like a bug, and it makes us hit unhappy paths all over the place. Fixes aqlaboratory#203
- Enable chunking for the AuxiliaryHeadsAllAtom pairformer embedding when using optimized kernels. Without chunking, this is the first call to cause OOMs because it has `diffusion_samples * sequence_length` batches. Chunking gets turned off in prediction_heads.py when the batch size is > 1 and optimized kernels are in use, because cross-sample chunking requires expanding out the pair bias, while the optimized kernels all require it to have size 1 in the second dimension, with implicit broadcasting. So we turn on `apply_per_sample` when optimized kernels are in use. This splits the > 1 batch dimension, which avoids the problematic path, and then we can do normal chunking for the rest if it's still too large. We could do something more elaborate (see suggestions in the linked issue), but this is an improvement for now. Fixes aqlaboratory#206
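The power-of-two fix can be sketched roughly like this (a hypothetical illustration only; `candidate_chunk_sizes` and its parameters are made-up names, not the repo's actual API):

```python
def candidate_chunk_sizes(min_size: int, max_size: int) -> list[int]:
    """Hypothetical sketch of chunk-size candidates after this change.

    Before the fix, 4 was added to the largest candidate (e.g. 512 + 4),
    producing a non-power-of-two size that pushed downstream code onto
    unhappy paths. After the fix, every candidate is a clean power of two.
    """
    sizes = []
    size = 1
    while size <= max_size:
        if size >= min_size:
            sizes.append(size)
        size *= 2
    return sizes
```

For example, `candidate_chunk_sizes(4, 512)` yields `[4, 8, 16, 32, 64, 128, 256, 512]`, with every entry a power of two.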
Summary
Currently, chunking of the AuxiliaryHeadsAllAtom pairformer stack is disabled when using optimized attention kernels, to avoid a bug with multi-sample chunking. This is the largest-memory attention call, so it doesn't make much sense to disable chunking here. Instead, we split by sample when using optimized kernels. Additionally, this makes chunk sizes consistently a power of two rather than adding 4 to the largest chunk size. (It was suggested to combine these changes; happy to split them up if desired.)

Changes
- Enable chunking for the AuxiliaryHeadsAllAtom pairformer embedding when using optimized kernels. Without chunking, this is the first call to cause OOMs because it has `diffusion_samples * sequence_length` batches. Chunking gets turned off in prediction_heads.py when the batch size is > 1 and optimized kernels are in use, because cross-sample chunking requires expanding out the pair bias, while the optimized kernels all require it to have size 1 in the second dimension, with implicit broadcasting. So we turn on `apply_per_sample` when optimized kernels are in use. This splits the > 1 batch dimension, which avoids the problematic path, and then we can do normal chunking for the rest if it's still too large. We could do something more elaborate (see suggestions in the linked issue), but this is an improvement for now. Fixes Restoring chunking for batch_size > 1 with optimized kernels #206
- `pair_bias.stride(1)` not even passed into the function.

Related Issues
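The per-sample splitting described above can be sketched as follows (a minimal illustration with plain Python lists standing in for tensors; `run_per_sample` is a hypothetical name, not the repo's actual `apply_per_sample` implementation):

```python
def run_per_sample(fn, batch, chunk_size):
    # Split the leading (diffusion-sample) dimension so each call sees a
    # batch of size 1. The problematic cross-sample path, which requires
    # expanding out the pair bias, is never hit, and normal chunking can
    # still run within each sample via chunk_size if the sequence is
    # still too large.
    return [fn([sample], chunk_size=chunk_size) for sample in batch]
```

Each inner call receives a batch of exactly one sample, so the batch-size-1 restriction of the optimized kernels is satisfied without giving up chunking entirely.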
Testing
Other Notes
I'm happy to make some of the more elaborate changes discussed in #206 either in this PR or a separate one, but wanted to share this relatively simple fix.