diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py
index 1271cc215..242f190f1 100644
--- a/src/adapters/methods/bottleneck.py
+++ b/src/adapters/methods/bottleneck.py
@@ -395,46 +395,135 @@ def forward(self, hidden_states, residual_input, layer_norm):
         return self.bottleneck_layer_forward(hidden_states, residual_input, layer_norm)
 
 
-def hook_fn(adapter_layer, ln_get_fn, module, args, output):
-    # Retrieve residual from previous hook, if existing
+def hook_fn(i, adapter_name, ln_name, module, args, output):
+    """
+    Hook function to process the output of a module through an adapter layer.
+
+    Args:
+        i (int): Layer index.
+        adapter_name (str): Name of the adapter.
+        ln_name (str): Name of the layer normalization module.
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+        output (torch.Tensor or tuple): Output of the module.
+
+    Returns:
+        torch.Tensor or tuple: Processed output through the adapter layer.
+    """
     context = ForwardContext.get_context()
+    adapter_layer = getattr(context, f"{adapter_name}", None)
+    layer = getattr(context, "layer", None)
+    layer_norm = multigetattr(layer, ln_name, None)
     residual_input = getattr(context, f"{adapter_layer.location_key}_residual_input", None)
-    # Retrieve layer norm from getter fn
-    if ln_get_fn is not None:
-        layer_norm = ln_get_fn()
-    else:
-        layer_norm = None
-    # Call adapter layer
     if isinstance(output, torch.Tensor):
         return adapter_layer(output, residual_input, layer_norm)
     else:
         return (adapter_layer(output[0], residual_input, layer_norm),) + output[1:]
 
 
+def _attention_adapters_hook_forward_pre_hook(module, args):
+    """
+    Pre-forward hook to set the multi-head attention adapters in the context.
+
+    Args:
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, "mh_adapter", module.attention_adapters)
+
+
+def _output_adapter_hook_forward_pre_hook(module, args):
+    """
+    Pre-forward hook to set the output adapters in the context.
+
+    Args:
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, "output_adapter", module.output_adapters)
+
+
+def _cross_attention_adapters_hook_forward_pre_hook(module, args):
+    """
+    Pre-forward hook to set the cross-attention adapters in the context.
+
+    Args:
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, "crossattn_adapter", module.cross_attention_adapters)
+
+
+def _layer_hook_forward_pre_hook(module, args):
+    """
+    Pre-forward hook to set the current layer in the context.
+
+    Args:
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, "layer", module)
+
+
 def _residual_hook_fn(location_key, module, args):
+    """
+    Hook function to set the residual input in the context.
+
+    Args:
+        location_key (str): Location key of the adapter.
+        module (torch.nn.Module): The module being hooked.
+        args (tuple): Arguments passed to the module.
+    """
     context = ForwardContext.get_context()
     if context is not None:
         setattr(context, f"{location_key}_residual_input", args[0])
 
 
 def init_bottleneck(model):
+    """
+    Initialize bottleneck adapters for the given model.
+
+    Args:
+        model (torch.nn.Module): The model to initialize bottleneck adapters for.
+    """
     model = model.base_model
-    for _, layer in model.iter_layers():
+
+    for i, layer in model.iter_layers():
+        if not hasattr(layer, "has_layer_hook_forward_pre_hook"):
+            layer.register_forward_pre_hook(_layer_hook_forward_pre_hook)
+            layer.has_layer_hook_forward_pre_hook = True
         if self_attn := multigetattr(layer, model.adapter_interface.layer_self_attn, None):
             if o_proj := multigetattr(self_attn, model.adapter_interface.attn_o_proj, None):
                 if not hasattr(layer, "attention_adapters"):
                     layer.attention_adapters = BottleneckLayer("mh_adapter", is_layer_hooked=True)
-                ln_1_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_1, None)
-                o_proj.register_forward_hook(partial(hook_fn, layer.attention_adapters, ln_1_get_fn))
+                if not hasattr(layer, "has_attention_adapters_hook_forward_pre_hook"):
+                    layer.register_forward_pre_hook(_attention_adapters_hook_forward_pre_hook)
+                    layer.has_attention_adapters_hook_forward_pre_hook = True
+                o_proj.register_forward_hook(partial(hook_fn, i, "mh_adapter", model.adapter_interface.layer_ln_1))
         if layer_output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj, None):
             if not hasattr(layer, "output_adapters"):
                 layer.output_adapters = BottleneckLayer("output_adapter", is_layer_hooked=True)
-            ln_2_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_2, None)
-            layer_output_proj.register_forward_hook(partial(hook_fn, layer.output_adapters, ln_2_get_fn))
+            if not hasattr(layer, "has_output_adapters_hook_forward_pre_hook"):
+                layer.register_forward_pre_hook(_output_adapter_hook_forward_pre_hook)
+                layer.has_output_adapters_hook_forward_pre_hook = True
+            layer_output_proj.register_forward_hook(
+                partial(hook_fn, i, "output_adapter", model.adapter_interface.layer_ln_2)
+            )
 
         if cross_attn := multigetattr(layer, model.adapter_interface.layer_cross_attn, None):
             if not hasattr(cross_attn, "cross_attention_adapters"):
-                layer.attention_adapters = BottleneckLayer("cross_adapter", is_layer_hooked=True)
-                cross_attn.register_forward_hook(partial(hook_fn, layer.attention_adapters, None))
+                layer.cross_attention_adapters = BottleneckLayer("cross_adapter", is_layer_hooked=True)
+            if not hasattr(layer, "has_cross_attention_adapters_hook_forward_pre_hook"):
+                layer.register_forward_pre_hook(_cross_attention_adapters_hook_forward_pre_hook)
+                layer.has_cross_attention_adapters_hook_forward_pre_hook = True
+            cross_attn.register_forward_hook(partial(hook_fn, i, "crossattn_adapter", None))
         if model.adapter_interface.layer_pre_self_attn is not None:
             if pre_self_attn := multigetattr(layer, model.adapter_interface.layer_pre_self_attn, None):