Simple speedups to decomposition #310

@danbraunai-goodfire

Description

Using bfloat16 everywhere is ~3x faster, but the runs seem to fail:

  • run. bf16. ss_llama_simple_mlp-1L (same config as here)
  • run. bf16. ss_llama_simple_mlp-1.25M (same config as here)
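For reference, "bfloat16 everywhere" amounts to something like the sketch below. The toy model is illustrative only; the real runs use the ss_llama_simple_mlp configs linked above.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the target model (illustrative, not the repo's code).
model = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))

# "bfloat16 everywhere": cast both parameters and inputs to bf16, so every
# forward/backward op runs in bf16. ~3x faster, but training was unstable.
model = model.to(torch.bfloat16)
x = torch.randn(8, 16, dtype=torch.bfloat16)
out = model(x)

assert out.dtype == torch.bfloat16
assert all(p.dtype == torch.bfloat16 for p in model.parameters())
```

The instability is plausibly from bf16's low mantissa precision accumulating into the optimizer state, which mixed precision avoids by keeping the master weights in fp32.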

Mixed precision seems to be about 50% faster. Runs are pending to check that it works reliably:

  • run. 1L mixed_precision
  • run. 1.25M mixed precision

We should at least use mixed precision if it works well.
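A minimal sketch of what mixed precision would look like here (toy model and training step are illustrative assumptions, not the repo's actual loop):

```python
import torch
import torch.nn as nn

# Toy stand-in for the target model (illustrative only).
model = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(8, 16)

# Mixed precision: parameters and optimizer state stay in fp32, while
# matmul-heavy ops inside the autocast region run in bf16. Unlike fp16,
# bf16 autocast does not need a GradScaler.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    loss = model(x).pow(2).mean()

loss.backward()
opt.step()
opt.zero_grad()

# Gradients and parameters remain fp32 outside the autocast region.
assert all(p.dtype == torch.float32 for p in model.parameters())
```

This gets most of the matmul speedup while keeping fp32 master weights, which is why it is the safer default if the pending runs look clean.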

We should also:

  • torch.compile() the target model (compiling other things didn't seem to help or wasn't possible).
  • Use regular torch ops rather than einops inside the slow parts (maybe just LinearComponents.forward()).
  • Turn off use_delta_components; it's 13% faster without them.
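The einops-to-plain-torch idea can be sketched as below. The shapes and names (V, U, C, rank) are assumptions for illustration, not the repo's real LinearComponents internals; the point is that an einsum over component/rank axes can be expressed as two broadcast batched matmuls that dispatch straight to optimized bmm kernels.

```python
import torch

# Illustrative shapes for a per-component low-rank linear map.
B, C, d_in, rank, d_out = 32, 4, 16, 8, 16
x = torch.randn(B, d_in)
V = torch.randn(C, d_in, rank)
U = torch.randn(C, rank, d_out)

# einsum-style path (the kind of op that showed up as slow in profiling):
out_einsum = torch.einsum("bi,cir,cro->bco", x, V, U)

# Plain torch ops: (x @ V) broadcasts x over the component axis, giving
# (C, B, rank); a second batched matmul with U gives (C, B, d_out).
out_matmul = ((x @ V) @ U).transpose(0, 1)  # -> (B, C, d_out)

# torch.compile(target_model) can then fuse/optimize this forward further.
assert torch.allclose(out_einsum, out_matmul, atol=1e-4)
```

Whether the rewrite helps in practice depends on whether einsum's planner was already picking this contraction order, so it's worth profiling before and after.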

Metadata

    Labels

    Priority: IU (Important & Urgent), feature (New feature or request)
