Wrap ForwardContext around full model forward #789

Open

wants to merge 2 commits into main

Conversation

@calpt (Member) commented Feb 2, 2025

This PR adapts the ForwardContext so that it wraps the full model forward pass, including the prediction head. The original base model forward wrapper is moved to wrap_base to make sure no second ForwardContext is created for a single forward pass.
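
For intuition, here is a minimal, self-contained sketch of how wrap and wrap_base can cooperate (a simplified stand-in written for this description; the attribute names and threading.local storage are illustrative assumptions, not the library's actual implementation):

import functools
import threading

class ForwardContext:
    # names of kwargs that are routed into the context instead of the forward call
    context_args = []
    _local = threading.local()

    def __init__(self, model, *args, **kwargs):
        # store the registered context args as attributes on the context
        for arg in self.context_args:
            setattr(self, arg, kwargs.get(arg, None))

    def __enter__(self):
        ForwardContext._local.ctx = self
        return self

    def __exit__(self, *exc):
        ForwardContext._local.ctx = None

    @classmethod
    def get_context(cls):
        return getattr(cls._local, "ctx", None)

    @classmethod
    def wrap(cls, f):
        # wraps the *full* model forward (including head): always opens a context
        @functools.wraps(f)
        def wrapper(self, *args, **kwargs):
            with cls(self, *args, **kwargs):
                for arg in cls.context_args:
                    kwargs.pop(arg, None)  # strip context args before the real forward
                return f(self, *args, **kwargs)
        return wrapper

    @classmethod
    def wrap_base(cls, f):
        # wraps the *base* model forward: if the full-model wrapper already opened
        # a context, reuse it, so one forward pass never creates two contexts
        @functools.wraps(f)
        def wrapper(self, *args, **kwargs):
            if cls.get_context() is None:
                return cls.wrap(f)(self, *args, **kwargs)
            return f(self, *args, **kwargs)
        return wrapper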

This enables passing custom args that are defined in the ForwardContext definition to the top-level model call, as discussed in #783, e.g.:

import adapters
from adapters import ForwardContext
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # example checkpoint; any supported causal LM works
model = AutoModelForCausalLM.from_pretrained(model_name)
adapters.init(model)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer(["This is a test text"], return_tensors="pt")

# Register the new forward arg globally
ForwardContext.context_args += ["task_ids"]

# Now the new arg name can be used without modifying the model's forward method
output = model(**inputs, task_ids=["id_0", "id_1"])

In the example above, the forward context will automatically add the passed context args as attributes, i.e. they can be accessed within the forward pass like this:

task_ids = ForwardContext.get_context().task_ids
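
For example, a custom module can read the task ids from the active context during its forward pass. Below is a hypothetical sketch (TaskGate and its per-task gating are made up for illustration; only ForwardContext.get_context() comes from the library):

import torch

from adapters import ForwardContext

class TaskGate(torch.nn.Module):
    # scales hidden states with a learned per-task gate selected via the context
    def __init__(self, tasks):
        super().__init__()
        self.gates = torch.nn.ParameterDict(
            {task: torch.nn.Parameter(torch.ones(1)) for task in tasks}
        )

    def forward(self, hidden_states):
        ctx = ForwardContext.get_context()
        task_ids = getattr(ctx, "task_ids", None) if ctx is not None else None
        if task_ids is not None:
            # simplification: apply the gate of the first task id to the whole batch
            hidden_states = hidden_states * self.gates[task_ids[0]]
        return hidden_states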

@FrLdy commented Feb 2, 2025

Hello @calpt,
Thanks for the PR!
Just tested it out and it works great!
This definitely makes passing custom args much easier; there is no need to redefine a new forward method for each specific HF model.
I'll rebase my fork once the changes are merged and propose my additions for multi-task support.

@calpt marked this pull request as ready for review February 4, 2025 14:11
@calpt linked an issue Feb 4, 2025 that may be closed by this pull request
@FrLdy commented Feb 8, 2025

Hi,
I tried to register a new ForwardContext argument in my test for multi-task composition:
https://github.com/FrLdy/adapters/blob/6f63c4df35d8deed973f4c791f6b779f8ba4f668/tests/test_misc/test_adapter_composition.py#L177

However, unless I check whether the new ForwardContext argument (task_ids) is already present in context_args before adding it, the tests only pass when each TestCase class is executed independently. When they are run together, the first one passes, while the subsequent ones fail.

Maybe use a set to store ForwardContext.context_args?
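
Until then, duplicate registration can be avoided with a small guard in the test setup (a sketch of the workaround described above; register_context_arg is a hypothetical helper, not part of the library):

from adapters import ForwardContext

def register_context_arg(name):
    # guard against duplicate registration when several TestCase classes
    # append to the shared class-level list within one test run
    if name not in ForwardContext.context_args:
        ForwardContext.context_args.append(name)

register_context_arg("task_ids")
register_context_arg("task_ids")  # second call is a no-op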

@FrLdy mentioned this pull request Feb 8, 2025
Development

Successfully merging this pull request may close these issues.

Support for Passing task_ids to forward_context for Multi-Task Learning