From e50782c6ade14de27c526047ccdaf4e9e3d576e4 Mon Sep 17 00:00:00 2001
From: "jadenfk@outlook.com" <jadenfk@outlook.com>
Date: Sun, 1 Dec 2024 19:49:09 -0500
Subject: [PATCH 1/2] renaming

---
 src/nnsight/intervention/base.py        |  3 +-
 src/nnsight/intervention/envoy.py       | 47 +++++++++++++++++++------
 src/nnsight/modeling/mixins/loadable.py |  6 ++--
 src/nnsight/modeling/mixins/meta.py     |  6 ++--
 4 files changed, 46 insertions(+), 16 deletions(-)

diff --git a/src/nnsight/intervention/base.py b/src/nnsight/intervention/base.py
index 53a0bc4f..87c278d7 100755
--- a/src/nnsight/intervention/base.py
+++ b/src/nnsight/intervention/base.py
@@ -59,10 +59,11 @@ def _run(self, *inputs, **kwargs):
     def __init__(
         self,
         model: torch.nn.Module,
+        rename: Optional[Dict[str,str]] = None
     ) -> None:

         self._model: torch.nn.Module = model
-        self._envoy: Envoy[InterventionProxy, InterventionNode] = Envoy(self._model)
+        self._envoy: Envoy[InterventionProxy, InterventionNode] = Envoy(self._model, rename=rename)

         self._session: Optional[Session] = None
         self._default_graph: Optional[InterventionGraph] = None
diff --git a/src/nnsight/intervention/envoy.py b/src/nnsight/intervention/envoy.py
index c652ce9c..616f1e2f 100755
--- a/src/nnsight/intervention/envoy.py
+++ b/src/nnsight/intervention/envoy.py
@@ -1,18 +1,20 @@
 from __future__ import annotations

 import inspect
+import re
 import warnings
 from contextlib import AbstractContextManager
-from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator, List,
-                    Optional, Tuple, Union)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, Iterator,
+                    List, Optional, Tuple, Union)

 import torch
 from typing_extensions import Self

 from . import protocols
-from .contexts import InterventionTracer
-from .graph import InterventionProxyType, InterventionNodeType
 from .backends import EditingBackend
+from .contexts import InterventionTracer
+from .graph import InterventionNodeType, InterventionProxyType
+

 class Envoy(Generic[InterventionProxyType, InterventionNodeType]):
     """Envoy objects act as proxies for torch modules themselves within a model's module tree in order to add nnsight functionality.
@@ -28,15 +30,20 @@ class Envoy(Generic[InterventionProxyType, InterventionNodeType]):
         _children (List[Envoy]): Immediate Envoy children of this Envoy.
         _fake_outputs (List[torch.Tensor]): List of 'meta' tensors built from the outputs most recent _scan. Is list as there can be multiple shapes for a module called more than once.
         _fake_inputs (List[torch.Tensor]): List of 'meta' tensors built from the inputs most recent _scan. Is list as there can be multiple shapes for a module called more than once.
-
+        _rename (Optional[Dict[str,str]]): Optional mapping of (regex string -> new name).
+            For example, to rename all GPT 'attn' modules to 'attention' you would pass: rename={r"\.transformer\.h\.\d+\.attn": "attention"}
+            Note that this does not actually change the underlying module names, just how you access their Envoys. Renaming replaces Envoy.path, while Envoy._path retains the true, pre-renamed attribute path.
         _tracer (nnsight.context.Tracer.Tracer): Object which adds this Envoy's module's output and input proxies to an intervention graph. Must be set on Envoys objects manually by the Tracer.
""" - def __init__(self, module: torch.nn.Module, module_path: str = ""): + def __init__(self, module: torch.nn.Module, module_path: str = "", alias_path:Optional[str] = None, rename: Optional[Dict[str,str]] = None): - self.path = module_path + self.path = alias_path or module_path + self._path = module_path self._module = module + + self._rename = rename self._iteration_stack = [0] @@ -111,7 +118,7 @@ def output(self) -> InterventionProxyType: else: fake_output = self._fake_outputs[iteration] - module_path = f"{self.path}.output" + module_path = f"{self._path}.output" output = protocols.InterventionProtocol.add( self._tracer.graph, @@ -169,7 +176,7 @@ def inputs(self) -> InterventionProxyType: else: fake_input = self._fake_inputs[iteration] - module_path = f"{self.path}.input" + module_path = f"{self._path}.input" input = protocols.InterventionProtocol.add( self._tracer.graph, @@ -330,8 +337,26 @@ def _add_envoy(self, module: torch.nn.Module, name: str) -> None: module (torch.nn.Module): Module to create Envoy for. name (str): name of envoy/attribute. """ - - envoy = Envoy(module, module_path=f"{self.path}.{name}") + + alias_path = None + + module_path = f"{self.path}.{name}" + + if self._rename is not None: + + for key, value in self._rename.items(): + + match = re.match(key, module_path) + + if match is not None: + + name = value + + alias_path = f"{self.path}.{name}" + + break + + envoy = Envoy(module, module_path=module_path, alias_path=alias_path, rename=self._rename) self._children.append(envoy) diff --git a/src/nnsight/modeling/mixins/loadable.py b/src/nnsight/modeling/mixins/loadable.py index 7c160942..914af587 100755 --- a/src/nnsight/modeling/mixins/loadable.py +++ b/src/nnsight/modeling/mixins/loadable.py @@ -1,3 +1,5 @@ +from typing import Dict, Optional + import torch from ...intervention import NNsight @@ -5,7 +7,7 @@ class LoadableMixin(NNsight): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, rename: Optional[Dict[str,str]] = None, **kwargs) -> None: if not isinstance(args[0], torch.nn.Module): @@ -15,7 +17,7 @@ def __init__(self, *args, **kwargs) -> None: model = args[0] - super().__init__(model) + super().__init__(model, rename=rename) def _load(self, *args, **kwargs) -> torch.nn.Module: diff --git a/src/nnsight/modeling/mixins/meta.py b/src/nnsight/modeling/mixins/meta.py index 839dc8ef..7ae550d3 100755 --- a/src/nnsight/modeling/mixins/meta.py +++ b/src/nnsight/modeling/mixins/meta.py @@ -1,3 +1,5 @@ +from typing import Dict, Optional + import torch from accelerate import init_empty_weights @@ -8,7 +10,7 @@ class MetaMixin(LoadableMixin): def __init__( - self, *args, dispatch: bool = False, meta_buffers: bool = True, **kwargs + self, *args, dispatch: bool = False, meta_buffers: bool = True, rename: Optional[Dict[str,str]] = None, **kwargs ) -> None: self.dispatched = dispatch @@ -23,7 +25,7 @@ def __init__( model = self._load_meta(*args, **kwargs) - NNsight.__init__(self, model) + NNsight.__init__(self, model, rename=rename) self.args = args self.kwargs = kwargs From 42500f5bb1b4408a7bebaa2b46902db12ed0fa28 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Fri, 6 Dec 2024 21:09:15 -0500 Subject: [PATCH 2/2] Updates @Clement --- src/nnsight/intervention/envoy.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/nnsight/intervention/envoy.py b/src/nnsight/intervention/envoy.py index 616f1e2f..9f7d8d53 100755 --- a/src/nnsight/intervention/envoy.py +++ b/src/nnsight/intervention/envoy.py @@ -30,8 +30,8 @@ 
         _children (List[Envoy]): Immediate Envoy children of this Envoy.
         _fake_outputs (List[torch.Tensor]): List of 'meta' tensors built from the outputs most recent _scan. Is list as there can be multiple shapes for a module called more than once.
         _fake_inputs (List[torch.Tensor]): List of 'meta' tensors built from the inputs most recent _scan. Is list as there can be multiple shapes for a module called more than once.
-        _rename (Optional[Dict[str,str]]): Optional mapping of (regex string -> new name).
-            For example, to rename all GPT 'attn' modules to 'attention' you would pass: rename={r"\.transformer\.h\.\d+\.attn": "attention"}
+        _rename (Optional[Dict[str,str]]): Optional mapping of (old name -> new name).
+            For example, to rename all GPT 'attn' modules to 'attention' you would pass: rename={r"attn": "attention"}
             Note that this does not actually change the underlying module names, just how you access their Envoys. Renaming replaces Envoy.path, while Envoy._path retains the true, pre-renamed attribute path.
         _tracer (nnsight.context.Tracer.Tracer): Object which adds this Envoy's module's output and input proxies to an intervention graph. Must be set on Envoys objects manually by the Tracer.
     """
@@ -346,10 +346,8 @@ def _add_envoy(self, module: torch.nn.Module, name: str) -> None:

             for key, value in self._rename.items():

-                match = re.match(key, module_path)
-
-                if match is not None:
-
+                if name == key:
+
                     name = value

                     alias_path = f"{self.path}.{name}"
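Usage sketch for the rename option added above, using the exact-name matching from PATCH 2/2. This is illustrative only and not part of the patches: the model identifier and GPT-2 style module tree are assumptions, as is the expectation that LanguageModel forwards the rename keyword through the MetaMixin/LoadableMixin constructors patched here.

    from nnsight import LanguageModel

    # Alias every child module named 'attn' to 'attention' on its Envoy.
    # The underlying torch module names are unchanged; only the Envoy access
    # path (Envoy.path) is replaced, while Envoy._path keeps the true path.
    model = LanguageModel("openai-community/gpt2", rename={"attn": "attention"})

    with model.trace("The quick brown fox"):
        # Access the renamed path; interventions still resolve against the
        # pre-renamed module path such as '.transformer.h.0.attn'.
        attn_out = model.transformer.h[0].attention.output.save()

    print(attn_out)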