
Add functionality to get and set state_dict of a particular adapter #756

Open

itsskofficial opened this issue on Oct 29, 2024 · 1 comment
Labels: enhancement (New feature or request)

@itsskofficial

I am trying to do federated training of adapters using the Flower framework, but I am unable to find a way to get and set the state_dict of a particular adapter, similar to PEFT's get_peft_model_state_dict / set_peft_model_state_dict. Here is the standard Flower code for getting and setting parameters:

from collections import OrderedDict

import torch
from flwr.common import NDArrays
from peft import get_peft_model_state_dict, set_peft_model_state_dict


def set_parameters(model, parameters: NDArrays) -> None:
    """Change the parameters of the model using the given ones."""
    peft_state_dict_keys = get_peft_model_state_dict(model).keys()
    params_dict = zip(peft_state_dict_keys, parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    set_peft_model_state_dict(model, state_dict)


def get_parameters(model) -> NDArrays:
    """Return the parameters of the current net."""
    state_dict = get_peft_model_state_dict(model)
    return [val.cpu().numpy() for _, val in state_dict.items()]

How do I go about doing this with adapters instead of peft?

Any help is appreciated.

itsskofficial added the enhancement label on Oct 29, 2024
@calpt (Member) commented on Dec 22, 2024

Hey,

while there is currently no option to set the state dict exactly as shown, you can manipulate adapter weights via the get_adapter() method, i.e.:

model.add_adapter("test", config="lora")
adapter = model.get_adapter("test")

This returns a dictionary containing all adapter modules in the format {<layer id>: {<module location>: <nn.Module>}}, which can be manipulated directly and/or copied. E.g., to set the weights of the LoRA modules in the last transformer layer:

adapter[11]["selfattn_lora"][0].lora_A.data = torch.zeros_like(adapter[11]["selfattn_lora"][0].lora_A.data)
adapter[11]["selfattn_lora"][1].lora_A.data = torch.zeros_like(adapter[11]["selfattn_lora"][1].lora_A.data)

Hope this helps!

calpt self-assigned this on Dec 22, 2024
calpt removed their assignment on Jan 7, 2025