Refactor press implementation #21

Closed
maxjeblick opened this issue Dec 3, 2024 · 2 comments · Fixed by #24
Labels
feature request New feature or request

Comments

@maxjeblick
Collaborator

Feature

Split the press class into two classes:

  • A scorer class that implements the .score method
  • A pruning class that implements the .forward_hook method

The press class then relies on dependency injection. For example, ExpectedAttentionPress could be expressed as:

press = BasePruner(
    compression_ratio=compression_ratio,
    scorer=ExpectedAttentionScorer(
        n_future_positions=n_future_positions, n_sink=n_sink, use_covariance=use_covariance, use_vnorm=use_vnorm
    ),
)
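The scorer/pruner split can be illustrated with a minimal, framework-free sketch. Everything in it is an assumption for illustration: `Scorer`, `ToyKeyNormScorer`, the simplified `score(keys, values)` signature and the `compress` method (standing in for `.forward_hook`) are hypothetical and do not reflect the actual press API. Keys and values are plain Python lists instead of tensors.

```python
import math
from abc import ABC, abstractmethod


class Scorer(ABC):
    """Hypothetical scorer interface: one importance score per sequence position."""

    @abstractmethod
    def score(self, keys, values):
        """Return one float per position; higher means 'more worth keeping'."""


class ToyKeyNormScorer(Scorer):
    """Toy scorer (hypothetical): keys with a smaller L2 norm get a higher score."""

    def score(self, keys, values):
        return [-math.sqrt(sum(x * x for x in k)) for k in keys]


class BasePruner:
    """Hypothetical pruner: owns the pruning step, delegates scoring to the injected scorer."""

    def __init__(self, compression_ratio, scorer):
        self.compression_ratio = compression_ratio
        self.scorer = scorer  # dependency injection: any Scorer works here

    def compress(self, keys, values):
        # Stand-in for the forward hook: score, keep the top positions, preserve their order.
        scores = self.scorer.score(keys, values)
        n_kept = len(keys) - int(len(keys) * self.compression_ratio)
        top = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:n_kept]
        kept = sorted(top)
        return [keys[i] for i in kept], [values[i] for i in kept]


# Usage: the pruning logic stays fixed while the scorer is supplied from outside.
pruner = BasePruner(compression_ratio=0.5, scorer=ToyKeyNormScorer())
keys = [[3.0, 4.0], [0.0, 1.0], [1.0, 1.0], [6.0, 8.0]]  # one key vector per position
values = [[1.0], [2.0], [3.0], [4.0]]
new_keys, new_values = pruner.compress(keys, values)  # keeps positions 1 and 2
```

In this sketch `BasePruner` never needs to know how scores are computed, so swapping in a different `Scorer` requires no change to the pruning code.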

Motivation

The current press code couples the forward hook and the score method, which makes custom workflows harder to implement.
Decoupling pruning from scoring makes it possible to add new pruning methods by subclassing BasePruner rather than by using a wrapper function (e.g. PerLayerCompressionPruner).
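Under the same simplified, hypothetical API, a per-layer variant could be a small subclass that only overrides how the ratio is chosen. The `PerLayerCompressionPruner` below is a sketch: its constructor, the `ratio_for` hook and the `layer_idx` argument are assumptions, not an existing implementation.

```python
import math


class ToyKeyNormScorer:
    """Toy scorer (hypothetical): keys with a smaller L2 norm get a higher score."""

    def score(self, keys, values):
        return [-math.sqrt(sum(x * x for x in k)) for k in keys]


class BasePruner:
    """Hypothetical base pruner that keeps the highest-scoring positions."""

    def __init__(self, compression_ratio, scorer):
        self.compression_ratio = compression_ratio
        self.scorer = scorer

    def ratio_for(self, layer_idx):
        # Default: every layer uses the same compression ratio.
        return self.compression_ratio

    def compress(self, keys, values, layer_idx=0):
        scores = self.scorer.score(keys, values)
        n_kept = len(keys) - int(len(keys) * self.ratio_for(layer_idx))
        top = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:n_kept]
        kept = sorted(top)
        return [keys[i] for i in kept], [values[i] for i in kept]


class PerLayerCompressionPruner(BasePruner):
    """Hypothetical subclass: one compression ratio per layer, no wrapper needed."""

    def __init__(self, compression_ratios, scorer):
        # Report the average ratio as the overall compression_ratio.
        super().__init__(sum(compression_ratios) / len(compression_ratios), scorer)
        self.compression_ratios = compression_ratios

    def ratio_for(self, layer_idx):
        return self.compression_ratios[layer_idx]


pruner = PerLayerCompressionPruner(compression_ratios=[0.0, 0.5], scorer=ToyKeyNormScorer())
keys = [[3.0, 4.0], [0.0, 1.0], [1.0, 1.0], [6.0, 8.0]]
values = [[1.0], [2.0], [3.0], [4.0]]
layer0 = pruner.compress(keys, values, layer_idx=0)  # nothing pruned
layer1 = pruner.compress(keys, values, layer_idx=1)  # half of the positions pruned
```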

maxjeblick added the feature request label on Dec 3, 2024
maxjeblick self-assigned this on Dec 3, 2024
@SimJeg
Collaborator

SimJeg commented Dec 3, 2024

Other features I can think of for v0.1.0:

  • a utils.py module centralizing functions that are repeatedly used across different presses (e.g. compute_queries)
  • a distinction between SequenceScorer and ChannelScorer if more presses like ThinKPress appear (see Add ThinKPress #20)
  • a refactor of the README that moves away from centering everything on the compression_ratio parameter. For some methods, the compression depends on the prompt, so the compression ratio is an output of the press rather than an input.
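To illustrate the last point, a pruner can expose the compression ratio as a result rather than take it as an argument. In this hypothetical sketch (`ThresholdPruner`, its `threshold` parameter and the toy scorer are assumptions, not existing classes), positions scoring below a fixed threshold are dropped, so the achieved ratio varies with the prompt.

```python
import math


class ToyKeyNormScorer:
    """Toy scorer (hypothetical): keys with a smaller L2 norm get a higher score."""

    def score(self, keys, values):
        return [-math.sqrt(sum(x * x for x in k)) for k in keys]


class ThresholdPruner:
    """Hypothetical pruner whose compression ratio is an output, not an input."""

    def __init__(self, scorer, threshold):
        self.scorer = scorer
        self.threshold = threshold
        self.compression_ratio = None  # set after each call to compress()

    def compress(self, keys, values):
        scores = self.scorer.score(keys, values)
        kept = [i for i, s in enumerate(scores) if s >= self.threshold]
        # The achieved ratio depends on the scores, i.e. on the prompt.
        self.compression_ratio = 1 - len(kept) / len(keys)
        return [keys[i] for i in kept], [values[i] for i in kept]


pruner = ThresholdPruner(scorer=ToyKeyNormScorer(), threshold=-2.0)
pruner.compress([[3.0, 4.0], [0.0, 1.0], [1.0, 1.0], [6.0, 8.0]], [[1.0], [2.0], [3.0], [4.0]])
ratio_a = pruner.compression_ratio  # 2 of 4 positions dropped
pruner.compress([[1.0, 0.0], [0.0, 1.0], [3.0, 4.0]], [[1.0], [2.0], [3.0]])
ratio_b = pruner.compression_ratio  # 1 of 3 positions dropped
```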

@maxjeblick
Collaborator Author

> a distinction between SequenceScorer and ChannelScorer

This could be implemented as follows:

sequence_press = ...  # a normal (sequence) press
press = ChannelPruner(
    compression_ratio=...,
    sequence_press=sequence_press,
    channel_scorer=ThinkChannelScorer(...),
)

where ThinkChannelScorer implements a score_channel method instead of score.
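The proposed composition can be sketched in the same simplified style. `ChannelPruner`'s internals, `SequencePruner` and the `score_channel(keys)` signature are assumptions. `ToyChannelScorer` uses a deliberately simple stand-in criterion (sum of squared key entries per channel) and is not ThinK's scoring rule.

```python
import math


class ToyKeyNormScorer:
    """Toy sequence scorer (hypothetical): keys with a smaller L2 norm score higher."""

    def score(self, keys, values):
        return [-math.sqrt(sum(x * x for x in k)) for k in keys]


class SequencePruner:
    """Stand-in for a 'normal' sequence press (hypothetical): drops whole positions."""

    def __init__(self, compression_ratio, scorer):
        self.compression_ratio = compression_ratio
        self.scorer = scorer

    def compress(self, keys, values):
        scores = self.scorer.score(keys, values)
        n_kept = len(keys) - int(len(keys) * self.compression_ratio)
        top = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:n_kept]
        kept = sorted(top)
        return [keys[i] for i in kept], [values[i] for i in kept]


class ToyChannelScorer:
    """Toy channel scorer (not ThinK's criterion): sum of squared key entries per channel."""

    def score_channel(self, keys):
        return [sum(k[c] ** 2 for k in keys) for c in range(len(keys[0]))]


class ChannelPruner:
    """Hypothetical composition: prune positions with sequence_press, then key channels."""

    def __init__(self, compression_ratio, sequence_press, channel_scorer):
        self.compression_ratio = compression_ratio
        self.sequence_press = sequence_press
        self.channel_scorer = channel_scorer

    def compress(self, keys, values):
        keys, values = self.sequence_press.compress(keys, values)
        scores = self.channel_scorer.score_channel(keys)
        n_kept = len(scores) - int(len(scores) * self.compression_ratio)
        top = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:n_kept]
        channels = sorted(top)
        # Only key channels are dropped; queries would have to be sliced with the
        # same channel indices for query-key dot products to stay consistent.
        keys = [[k[c] for c in channels] for k in keys]
        return keys, values, channels


pruner = ChannelPruner(
    compression_ratio=0.5,
    sequence_press=SequencePruner(compression_ratio=0.5, scorer=ToyKeyNormScorer()),
    channel_scorer=ToyChannelScorer(),
)
keys = [[3.0, 0.0, 4.0], [0.0, 1.0, 0.5], [2.0, 0.0, 1.0], [6.0, 0.0, 8.0]]
values = [[1.0], [2.0], [3.0], [4.0]]
new_keys, new_values, channels = pruner.compress(keys, values)
```

Because ChannelPruner reaches the sequence press only through compress() and the channel scorer only through score_channel(), either part can be swapped independently.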


2 participants