
Apply power-of-two chunking consistently #207

Open

GMNGeoffrey wants to merge 1 commit into aqlaboratory:main from GMNGeoffrey:chunking

Conversation

@GMNGeoffrey (Contributor)

Summary
Currently, chunking of the AuxiliaryHeadsAllAtom pairformer stack is disabled when using optimized attention kernels, to avoid a bug with multi-sample chunking. This is the largest-memory attention call, so it doesn't make much sense to disable chunking here of all places. Instead, we split by sample when using optimized kernels. This PR also makes chunk sizes consistently a power of two rather than adding 4 to the largest chunk size. (It was suggested to combine these changes; happy to split them up if desired.)

Changes

  • Avoid the weird addition of 4 to power-of-two chunk sizes (see the first sketch after this list). This was added in a9a12890d without explanation. We can hypothesize that it was related to adding 4 to an input dimension in trace_utils.py (trying to get a test case to fit in one chunk?), but that file was deleted long ago. This just looks like a bug, and it makes us hit unhappy paths all over the place. Fixes #203 ([BUG/QUESTION] Chunk size incremented by 4 off largest power of 2).
  • Enable chunking for the AuxiliaryHeadsAllAtom pairformer embedding when using optimized kernels. Without chunking, this is the first call to cause OOMs because of its `diffusion_samples * sequence_length` batches. Chunking gets turned off in prediction_heads.py when the batch size is > 1 and optimized kernels are in use, because cross-sample chunking requires expanding out the pair bias, and the optimized kernels all require it to have size 1 in the second dimension with implicit broadcasting. So we turn on `apply_per_sample` when optimized kernels are in use (see the second sketch after this list). This splits the > 1 batch dimension, which avoids the problematic path, and then we can do normal chunking on the rest if it's still too large. We could do something more elaborate (see suggestions in the linked issue), but this is an improvement for now. Fixes #206 (Restoring chunking for batch_size > 1 with optimized kernels).
  • Assert in the Triton triangle attention kernel that the second dim of the pair bias has size 1 (see the third sketch after this list). This was implicitly assumed before, with `pair_bias.stride(1)` not even being passed into the function.
  • Update tests for all optimized kernels in test_kernels.py to avoid combining batch_size > 1 with chunking.
  • Document the specific issue with multi-sample chunking and optimized kernels.
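
For the chunk-size change, here's a minimal sketch of the candidate generation. The helper name and the min/max parameters are illustrative rather than the actual code, but the `+= 4` behavior matches what a9a12890d introduced:

```python
import math

def chunk_size_candidates(min_chunk_size: int, max_chunk_size: int) -> list:
    # Powers of two up to max_chunk_size, floored at min_chunk_size.
    candidates = [2 ** l for l in range(int(math.log2(max_chunk_size)) + 1)]
    candidates = [min_chunk_size] + [c for c in candidates if c > min_chunk_size]
    # Before this change, the largest candidate was bumped by 4
    # (e.g. 512 -> 516), breaking power-of-two alignment everywhere
    # downstream:
    #     candidates[-1] += 4
    # After this change, candidates stay exact powers of two.
    return candidates

print(chunk_size_candidates(4, 512))  # [4, 8, ..., 256, 512], not [..., 516]
```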
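
For the per-sample split, a sketch of the idea (the wrapper below is hypothetical, not the actual `apply_per_sample` implementation): run each diffusion sample through the stack independently at batch size 1, so the optimized kernels' implicit pair-bias broadcasting stays valid and ordinary chunking can still happen inside the call:

```python
import torch

def run_per_sample(fn, z: torch.Tensor) -> torch.Tensor:
    # z: [n_samples, n_res, n_res, c_z]. Unbinding the leading sample
    # dimension keeps every inner call at batch size 1, avoiding the
    # cross-sample pair-bias expansion that the optimized kernels
    # can't handle.
    outs = [fn(z_s.unsqueeze(0)) for z_s in z.unbind(dim=0)]
    return torch.cat(outs, dim=0)
```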
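
And the new assertion, sketched in a hypothetical wrapper signature (the real kernel takes more arguments); it just makes the previously implicit broadcast explicit:

```python
def triangle_attention(q, k, v, pair_bias):
    # The kernel never receives pair_bias.stride(1), so it implicitly
    # broadcasts pair_bias over its second dimension. Make that
    # assumption loud instead of silently reading the wrong memory.
    assert pair_bias.shape[1] == 1, (
        f"pair_bias must have size 1 in dim 1, got {pair_bias.shape[1]}"
    )
    ...
```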

Related Issues

  • #203 ([BUG/QUESTION] Chunk size incremented by 4 off largest power of 2)
  • #206 (Restoring chunking for batch_size > 1 with optimized kernels)

Testing

  • Confirmed that when limiting torch GPU memory usage to 12 GB, the multimer example OOMed prior to this change and runs to completion after it.
  • Confirmed that the attention-kernel calling pattern moves from batches of 516 to batches of 512.
  • I didn't find a good way to unit test this in a way that's actually useful (e.g. I can assert that chunk_size is now 512 instead of 516, but I don't think that test carries its weight).

Other Notes
I'm happy to make some of the more elaborate changes discussed in #206 either in this PR or a separate one, but wanted to share this relatively simple fix.



Development

Successfully merging this pull request may close these issues.

  • Restoring chunking for batch_size > 1 with optimized kernels (#206)
  • [BUG/QUESTION] Chunk size incremented by 4 off largest power of 2 (#203)
