Skip to content

Commit

Permalink
Merge pull request #298 from ndif-team/aliasing
Browse files Browse the repository at this point in the history
renaming
  • Loading branch information
JadenFiotto-Kaufman authored Dec 7, 2024
2 parents f03101a + 42500f5 commit 8e4f8e9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 16 deletions.
3 changes: 2 additions & 1 deletion src/nnsight/intervention/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def _run(self, *inputs, **kwargs):
def __init__(
self,
model: torch.nn.Module,
rename: Optional[Dict[str,str]] = None
) -> None:

self._model: torch.nn.Module = model
self._envoy: Envoy[InterventionProxy, InterventionNode] = Envoy(self._model)
self._envoy: Envoy[InterventionProxy, InterventionNode] = Envoy(self._model, rename=rename)

self._session: Optional[Session] = None
self._default_graph: Optional[InterventionGraph] = None
Expand Down
45 changes: 34 additions & 11 deletions src/nnsight/intervention/envoy.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from __future__ import annotations

import inspect
import re
import warnings
from contextlib import AbstractContextManager
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator, List,
Optional, Tuple, Union)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator,
List, Optional, Tuple, Union)

import torch
from typing_extensions import Self

from . import protocols
from .contexts import InterventionTracer
from .graph import InterventionProxyType, InterventionNodeType
from .backends import EditingBackend
from .contexts import InterventionTracer
from .graph import InterventionNodeType, InterventionProxyType


class Envoy(Generic[InterventionProxyType, InterventionNodeType]):
"""Envoy objects act as proxies for torch modules themselves within a model's module tree in order to add nnsight functionality.
Expand All @@ -28,15 +30,20 @@ class Envoy(Generic[InterventionProxyType, InterventionNodeType]):
_children (List[Envoy]): Immediate Envoy children of this Envoy.
_fake_outputs (List[torch.Tensor]): List of 'meta' tensors built from the outputs of the most recent _scan. Is a list as there can be multiple shapes for a module called more than once.
_fake_inputs (List[torch.Tensor]): List of 'meta' tensors built from the inputs of the most recent _scan. Is a list as there can be multiple shapes for a module called more than once.
_rename (Optional[Dict[str,str]]): Optional mapping of (old name -> new name).
For example to rename all gpt 'attn' modules to 'attention' you would: rename={r"attn": "attention"}
Note this does not actually change the underlying module names, just how you access its envoy. Renaming will replace Envoy.path but Envoy._path represents the pre-renamed true attribute path.
_tracer (nnsight.context.Tracer.Tracer): Object which adds this Envoy's module's output and input proxies to an intervention graph. Must be set on Envoys objects manually by the Tracer.
"""

def __init__(self, module: torch.nn.Module, module_path: str = ""):
def __init__(self, module: torch.nn.Module, module_path: str = "", alias_path:Optional[str] = None, rename: Optional[Dict[str,str]] = None):

self.path = module_path
self.path = alias_path or module_path
self._path = module_path

self._module = module

self._rename = rename

self._iteration_stack = [0]

Expand Down Expand Up @@ -111,7 +118,7 @@ def output(self) -> InterventionProxyType:
else:
fake_output = self._fake_outputs[iteration]

module_path = f"{self.path}.output"
module_path = f"{self._path}.output"

output = protocols.InterventionProtocol.add(
self._tracer.graph,
Expand Down Expand Up @@ -169,7 +176,7 @@ def inputs(self) -> InterventionProxyType:
else:
fake_input = self._fake_inputs[iteration]

module_path = f"{self.path}.input"
module_path = f"{self._path}.input"

input = protocols.InterventionProtocol.add(
self._tracer.graph,
Expand Down Expand Up @@ -330,8 +337,24 @@ def _add_envoy(self, module: torch.nn.Module, name: str) -> None:
module (torch.nn.Module): Module to create Envoy for.
name (str): name of envoy/attribute.
"""

envoy = Envoy(module, module_path=f"{self.path}.{name}")

alias_path = None

module_path = f"{self.path}.{name}"

if self._rename is not None:

for key, value in self._rename.items():

if name == key:

name = value

alias_path = f"{self.path}.{name}"

break

envoy = Envoy(module, module_path=module_path, alias_path=alias_path, rename=self._rename)

self._children.append(envoy)

Expand Down
6 changes: 4 additions & 2 deletions src/nnsight/modeling/mixins/loadable.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Dict, Optional

import torch

from ...intervention import NNsight


class LoadableMixin(NNsight):

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args, rename: Optional[Dict[str,str]] = None, **kwargs) -> None:

if not isinstance(args[0], torch.nn.Module):

Expand All @@ -15,7 +17,7 @@ def __init__(self, *args, **kwargs) -> None:

model = args[0]

super().__init__(model)
super().__init__(model, rename=rename)

def _load(self, *args, **kwargs) -> torch.nn.Module:

Expand Down
6 changes: 4 additions & 2 deletions src/nnsight/modeling/mixins/meta.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict, Optional

import torch
from accelerate import init_empty_weights

Expand All @@ -8,7 +10,7 @@
class MetaMixin(LoadableMixin):

def __init__(
self, *args, dispatch: bool = False, meta_buffers: bool = True, **kwargs
self, *args, dispatch: bool = False, meta_buffers: bool = True, rename: Optional[Dict[str,str]] = None, **kwargs
) -> None:

self.dispatched = dispatch
Expand All @@ -23,7 +25,7 @@ def __init__(

model = self._load_meta(*args, **kwargs)

NNsight.__init__(self, model)
NNsight.__init__(self, model, rename=rename)

self.args = args
self.kwargs = kwargs
Expand Down

0 comments on commit 8e4f8e9

Please sign in to comment.