bfloat16 everywhere is ~3x faster, but the runs seem to fail:
- run. bf16. ss_llama_simple_mlp-1L (same config as here)
- run. bf16. ss_llama_simple_mlp-1.25M (same config as here)
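For reference, casting everything to bfloat16 is a one-line change; a minimal sketch (the toy MLP here is a stand-in, not the actual ss_llama_simple_mlp config):

```python
import torch
import torch.nn as nn

# Toy stand-in for the target MLP (hypothetical shapes).
model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))

# "bf16 everywhere": cast all parameters and buffers to bfloat16.
model = model.to(torch.bfloat16)

# Inputs must match the parameter dtype.
x = torch.randn(4, 8, dtype=torch.bfloat16)
out = model(x)
```

Because optimizer state and gradients also end up in bf16 with this approach, precision loss in accumulation is a plausible cause of the failed runs above.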
Mixed precision seems to be about 50% faster. Runs are pending to check that it works well:
We should at least use mixed precision if it works well.
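A minimal sketch of the mixed-precision variant using `torch.autocast` (shown on CPU for illustration; on GPU you'd pass `device_type="cuda"`). Unlike the bf16-everywhere approach, parameters and optimizer state stay in fp32, and with bf16 autocast no `GradScaler` is needed (that's an fp16 concern):

```python
import torch
import torch.nn as nn

# Toy stand-in for the target MLP (hypothetical shapes).
model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(4, 8)
# Parameters stay fp32; matmuls inside the autocast region run in bf16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)
    loss = out.pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```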
We should also:
- torch.compile() the target model (compiling other things didn't seem to help or wasn't possible).
- Use regular torch ops rather than einops inside the slow parts (maybe just LinearComponents.forward()).
- Turn off use_delta_components: it's 13% faster without them.
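The einops-to-torch-ops item can be sketched as follows. The shapes and names (`V`, `U`, a rank factorisation) are assumptions for illustration, not LinearComponents' actual signature; the point is that the same contraction expressed with plain matmuls tends to fuse and compile more readily:

```python
import torch

# Hypothetical LinearComponents-style shapes:
# x: (batch, d_in), V: (C, d_in, rank), U: (C, rank, d_out).
batch, C, d_in, rank, d_out = 4, 3, 8, 2, 5
x = torch.randn(batch, d_in)
V = torch.randn(C, d_in, rank)
U = torch.randn(C, rank, d_out)

# einops-style reference contraction: "b i, c i r, c r o -> b c o".
ref = torch.einsum("bi,cir,cro->bco", x, V, U)

# Same computation with plain batched matmuls:
h = x @ V                      # (b, i) @ (C, i, r) -> (C, b, r) via broadcasting
out = (h @ U).transpose(0, 1)  # (C, b, r) @ (C, r, o) -> (C, b, o) -> (b, C, o)
```

The rewritten forward can then be wrapped with `torch.compile()` along with the rest of the target model, per the first item above.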