Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

renaming #298

Merged
merged 3 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/nnsight/intervention/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def _run(self, *inputs, **kwargs):
def __init__(
self,
model: torch.nn.Module,
rename: Optional[Dict[str,str]] = None
) -> None:

self._model: torch.nn.Module = model
self._envoy: Envoy[InterventionProxy, InterventionNode] = Envoy(self._model)
self._envoy: Envoy[InterventionProxy, InterventionNode] = Envoy(self._model, rename=rename)

self._session: Optional[Session] = None
self._default_graph: Optional[InterventionGraph] = None
Expand Down
45 changes: 34 additions & 11 deletions src/nnsight/intervention/envoy.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from __future__ import annotations

import inspect
import re
import warnings
from contextlib import AbstractContextManager
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator, List,
Optional, Tuple, Union)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator,
List, Optional, Tuple, Union)

import torch
from typing_extensions import Self

from . import protocols
from .contexts import InterventionTracer
from .graph import InterventionProxyType, InterventionNodeType
from .backends import EditingBackend
from .contexts import InterventionTracer
from .graph import InterventionNodeType, InterventionProxyType


class Envoy(Generic[InterventionProxyType, InterventionNodeType]):
"""Envoy objects act as proxies for torch modules themselves within a model's module tree in order to add nnsight functionality.
Expand All @@ -28,15 +30,20 @@ class Envoy(Generic[InterventionProxyType, InterventionNodeType]):
_children (List[Envoy]): Immediate Envoy children of this Envoy.
_fake_outputs (List[torch.Tensor]): List of 'meta' tensors built from the outputs of the most recent _scan. It is a list because there can be multiple shapes for a module called more than once.
_fake_inputs (List[torch.Tensor]): List of 'meta' tensors built from the inputs of the most recent _scan. It is a list because there can be multiple shapes for a module called more than once.

_rename (Optional[Dict[str,str]]): Optional mapping of (old name -> new name).
For example, to rename all gpt 'attn' modules to 'attention' you would pass: rename={r"attn": "attention"}
Note this does not actually change the underlying module names, just how you access their envoys. Renaming will replace Envoy.path, but Envoy._path still represents the pre-renamed true attribute path.
_tracer (nnsight.context.Tracer.Tracer): Object which adds this Envoy's module's output and input proxies to an intervention graph. Must be set on Envoys objects manually by the Tracer.
"""

def __init__(self, module: torch.nn.Module, module_path: str = ""):
def __init__(self, module: torch.nn.Module, module_path: str = "", alias_path:Optional[str] = None, rename: Optional[Dict[str,str]] = None):

self.path = module_path
self.path = alias_path or module_path
self._path = module_path

self._module = module

self._rename = rename

self._iteration_stack = [0]

Expand Down Expand Up @@ -111,7 +118,7 @@ def output(self) -> InterventionProxyType:
else:
fake_output = self._fake_outputs[iteration]

module_path = f"{self.path}.output"
module_path = f"{self._path}.output"

output = protocols.InterventionProtocol.add(
self._tracer.graph,
Expand Down Expand Up @@ -169,7 +176,7 @@ def inputs(self) -> InterventionProxyType:
else:
fake_input = self._fake_inputs[iteration]

module_path = f"{self.path}.input"
module_path = f"{self._path}.input"

input = protocols.InterventionProtocol.add(
self._tracer.graph,
Expand Down Expand Up @@ -330,8 +337,24 @@ def _add_envoy(self, module: torch.nn.Module, name: str) -> None:
module (torch.nn.Module): Module to create Envoy for.
name (str): name of envoy/attribute.
"""

envoy = Envoy(module, module_path=f"{self.path}.{name}")

alias_path = None

module_path = f"{self.path}.{name}"

if self._rename is not None:

for key, value in self._rename.items():

if name == key:

name = value

alias_path = f"{self.path}.{name}"

break

envoy = Envoy(module, module_path=module_path, alias_path=alias_path, rename=self._rename)

self._children.append(envoy)

Expand Down
6 changes: 4 additions & 2 deletions src/nnsight/modeling/mixins/loadable.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Dict, Optional

import torch

from ...intervention import NNsight


class LoadableMixin(NNsight):

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args, rename: Optional[Dict[str,str]] = None, **kwargs) -> None:

if not isinstance(args[0], torch.nn.Module):

Expand All @@ -15,7 +17,7 @@ def __init__(self, *args, **kwargs) -> None:

model = args[0]

super().__init__(model)
super().__init__(model, rename=rename)

def _load(self, *args, **kwargs) -> torch.nn.Module:

Expand Down
6 changes: 4 additions & 2 deletions src/nnsight/modeling/mixins/meta.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict, Optional

import torch
from accelerate import init_empty_weights

Expand All @@ -8,7 +10,7 @@
class MetaMixin(LoadableMixin):

def __init__(
self, *args, dispatch: bool = False, meta_buffers: bool = True, **kwargs
self, *args, dispatch: bool = False, meta_buffers: bool = True, rename: Optional[Dict[str,str]] = None, **kwargs
) -> None:

self.dispatched = dispatch
Expand All @@ -23,7 +25,7 @@ def __init__(

model = self._load_meta(*args, **kwargs)

NNsight.__init__(self, model)
NNsight.__init__(self, model, rename=rename)

self.args = args
self.kwargs = kwargs
Expand Down
Loading