From 03d653967161b82e2d1a324287aa0d015d5d7991 Mon Sep 17 00:00:00 2001
From: "jadenfk@outlook.com"
Date: Sat, 12 Oct 2024 23:36:43 -0400
Subject: [PATCH 1/2] Envoy.iter!

---
 src/nnsight/contexts/Tracer.py      |   7 +-
 src/nnsight/envoy.py                | 640 ++++++++++++++++------------
 src/nnsight/intervention.py         | 171 ++++----
 src/nnsight/models/LanguageModel.py |  14 +-
 src/nnsight/models/NNsightModel.py  |  33 +-
 src/nnsight/models/vllm/__init__.py |  14 +-
 src/nnsight/tracing/Node.py         |  53 +--
 src/nnsight/util.py                 |  11 -
 8 files changed, 538 insertions(+), 405 deletions(-)

diff --git a/src/nnsight/contexts/Tracer.py b/src/nnsight/contexts/Tracer.py
index 7115fd45..39cfefed 100755
--- a/src/nnsight/contexts/Tracer.py
+++ b/src/nnsight/contexts/Tracer.py
@@ -138,9 +138,14 @@ def batch(
 
         if batched_input is None:
 
-            batched_input = tuple(tuple(), dict())
+            batched_input = (((0, -1),), dict())
 
         return batched_input, batch_groups
 
+    @property
+    def _invoker_group(self):
+
+        return len(self._invoker_inputs) - 1
+
     ##### BACKENDS ###############################
 
diff --git a/src/nnsight/envoy.py b/src/nnsight/envoy.py
index 1d3b12e1..27c92da5 100755
--- a/src/nnsight/envoy.py
+++ b/src/nnsight/envoy.py
@@ -2,56 +2,324 @@
 
 import inspect
 import warnings
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
+import weakref
 
 import torch
+from typing_extensions import Self
+from contextlib import AbstractContextManager
 
 from .contexts.backends import EditBackend
 from .contexts.Tracer import Tracer
 from .intervention import InterventionProtocol, InterventionProxy
 from .tracing import protocols
 
-
 class Envoy:
-    """Envoy object act as proxies for torch modules within a model's module tree in order to add nnsight functionality.
+    """Envoy objects act as proxies for torch modules within a model's module tree in order to add nnsight functionality.
 
     Proxies of the underlying module's output and input are accessed by `.output` and `.input` respectively.
 
     Attributes:
-        path (str): String representing the attribute path of this Envoy's module relative the the root model. Separated by '.' e.x ('transformer.h.0.mlp'). Set by NNsight on initialization of meta model.
+        path (str): String representing the attribute path of this Envoy's module relative to the root model. Separated by '.', e.g. '.transformer.h.0.mlp'.
+        output (nnsight.intervention.InterventionProxy): Proxy object representing the output of this Envoy's module. Reset on forward pass.
+        inputs (nnsight.intervention.InterventionProxy): Proxy object representing the inputs of this Envoy's module. Proxy is in the form of (Tuple[Tuple[Any], Dict[str, Any]]). Reset on forward pass.
+        input (nnsight.intervention.InterventionProxy): Alias for the first positional Proxy input, i.e. Envoy.inputs[0][0].
+        iter (nnsight.envoy.IterationEnvoy): Iterator object allowing selection of specific .input and .output iterations of this Envoy.
+        _module (torch.nn.Module): Underlying torch module.
+        _children (List[Envoy]): Immediate Envoy children of this Envoy.
         _fake_outputs (List[torch.Tensor]): List of 'meta' tensors built from the outputs most recent _scan. Is list as there can be multiple shapes for a module called more than once.
        _fake_inputs (List[torch.Tensor]): List of 'meta' tensors built from the inputs most recent _scan. Is list as there can be multiple shapes for a module called more than once.
-        output (nnsight.intervention.InterventionProxy): Proxy object representing the output of this Envoy's module. Reset on forward pass.
-        input (nnsight.intervention.InterventionProxy): Proxy object representing the input of this Envoy's module. Reset on forward pass.
-        _call_iter (int): Integer representing the current iteration of this Envoy's module's inputs/outputs.
+        _tracer (nnsight.contexts.Tracer.Tracer): Object which adds this Envoy's module's output and input proxies to an intervention graph. Must be set on Envoy objects manually by the Tracer.
     """
 
     def __init__(self, module: torch.nn.Module, module_path: str = ""):
 
         self.path = module_path
+
+        self._module = module
+
+        self._iteration_stack = [0]
 
         self._fake_outputs: List[torch.Tensor] = []
         self._fake_inputs: List[torch.Tensor] = []
 
-        self._output: Optional[InterventionProxy] = None
-        self._input: Optional[InterventionProxy] = None
-
-        self._call_iter = 0
+        self._output_stack: List[Optional[InterventionProxy]] = [None]
+        self._input_stack: List[Optional[InterventionProxy]] = [None]
 
         self._tracer: Tracer = None
 
-        self._module = module
-
-        self._sub_envoys: List[Envoy] = []
+        self._children: List[Envoy] = []
 
         # Register hook on underlying module to update the _fake_outputs and _fake_inputs on forward pass.
         self._hook_handle = self._module.register_forward_hook(
             self._hook, with_kwargs=True
         )
 
+        # Recurse into the PyTorch module tree.
         for name, module in self._module.named_children():
 
             setattr(self, name, module)
 
+    # Public API ################
+
+    def __call__(
+        self, *args: List[Any], hook=False, **kwargs: Dict[str, Any]
+    ) -> InterventionProxy:
+        """Creates a proxy to call the underlying module's forward method with some inputs.
+
+        Returns:
+            InterventionProxy: Module call proxy.
+        """
+
+        if not self._tracing() or self._scanning():
+            return self._module(*args, **kwargs)
+
+        if isinstance(self._tracer.backend, EditBackend):
+            hook = True
+
+        return protocols.ApplyModuleProtocol.add(
+            self._tracer.graph, self.path, *args, hook=hook, **kwargs
+        )
+
+    @property
+    def output(self) -> InterventionProxy:
+        """
+        Calling denotes the user wishes to get the output of the underlying module, so we create a Proxy of that request.
+        Only generates a proxy the first time it is referenced; otherwise returns the already created one.
+
+        Returns:
+            InterventionProxy: Output proxy.
+        """
+        output = self._output_stack.pop()
+
+        if output is None:
+
+            if isinstance(self._module, torch.nn.ModuleList):
+
+                output = [envoy.output for envoy in self._children]
+
+            else:
+
+                iteration = self._iteration_stack[-1]
+
+                if len(self._fake_outputs) == 0:
+                    fake_output = inspect._empty
+                elif iteration >= len(self._fake_outputs):
+                    # TODO warning?
+                    fake_output = self._fake_outputs[-1]
+                else:
+                    fake_output = self._fake_outputs[iteration]
+
+                module_path = f"{self.path}.output"
+
+                output = InterventionProtocol.add(
+                    self._tracer.graph,
+                    fake_output,
+                    args=[
+                        module_path,
+                        self._tracer._invoker_group,
+                        iteration,
+                    ],
+                )
+
+        self._output_stack.append(output)
+
+        return output
+
+    @output.setter
+    def output(self, value: Union[InterventionProxy, Any]) -> None:
+        """
+        Calling denotes the user wishes to set the output of the underlying module, so we create a Proxy of that request.
+
+        Args:
+            value (Union[InterventionProxy, Any]): Value to set output to.
+        """
+
+        protocols.SwapProtocol.add(self.output.node, value)
+
+        self._output_stack[-1] = None
+
+    @property
+    def inputs(self) -> InterventionProxy:
+        """
+        Calling denotes the user wishes to get the input of the underlying module, so we create a Proxy of that request.
+        Only generates a proxy the first time it is referenced; otherwise returns the already created one.
+
+        Returns:
+            InterventionProxy: Input proxy.
+        """
+
+        input = self._input_stack.pop()
+
+        if input is None:
+
+            if isinstance(self._module, torch.nn.ModuleList):
+
+                input = [envoy.input for envoy in self._children]
+
+            else:
+
+                iteration = self._iteration_stack[-1]
+
+                if len(self._fake_inputs) == 0:
+                    fake_input = inspect._empty
+                elif iteration >= len(self._fake_inputs):
+                    # TODO warning?
+                    fake_input = self._fake_inputs[-1]
+                else:
+                    fake_input = self._fake_inputs[iteration]
+
+                module_path = f"{self.path}.input"
+
+                input = InterventionProtocol.add(
+                    self._tracer.graph,
+                    fake_input,
+                    args=[
+                        module_path,
+                        self._tracer._invoker_group,
+                        iteration,
+                    ],
+                )
+
+        self._input_stack.append(input)
+
+        return input
+
+    @inputs.setter
+    def inputs(self, value: Union[InterventionProxy, Any]) -> None:
+        """
+        Calling denotes the user wishes to set the input of the underlying module, so we create a Proxy of that request.
+
+        Args:
+            value (Union[InterventionProxy, Any]): Value to set input to.
+        """
+
+        protocols.SwapProtocol.add(self.inputs.node, value)
+
+        self._input_stack[-1] = None
+
+    @property
+    def input(self) -> InterventionProxy:
+        """Gets the first positional argument of the module's input.
+
+        Returns:
+            InterventionProxy: Input proxy.
+        """
+
+        return self.inputs[0][0]
+
+    @input.setter
+    def input(self, value: Union[InterventionProxy, Any]) -> None:
+        """Sets the value of the input's first positional argument in the model's module.
+
+        Args:
+            value (Union[InterventionProxy, Any]): Value to set the input to.
+        """
+
+        self.inputs = ((value,) + self.inputs[0][1:],) + (self.inputs[1:])
+
+    @property
+    def iter(self) -> IterationEnvoy:
+
+        return IterationEnvoy(self)
+
+    @iter.setter
+    def iter(self, iteration: Union[int, List[int], slice]) -> None:
+        self._iteration_stack.append(iteration)
+
+    def next(self, increment: int = 1) -> Envoy:
+        """By default, this module's inputs and outputs only refer to the first time it is called. Use `.next()` to select which iteration .input and .output refer to.
+
+        Args:
+            increment (int, optional): How many iterations to jump. Defaults to 1.
+
+        Returns:
+            Envoy: Self.
+        """
+
+        return self.iter[self._iteration_stack[-1] + increment].__enter__()
+
+    def all(self, propagate: bool = True) -> Envoy:
+        """By default, this module's inputs and outputs only refer to the first time it is called. Use `.all()` to have .input and .output refer to all iterations.
+
+        Returns:
+            Envoy: Self.
+        """
+
+        return self.iter[:].__enter__()
+
+    def to(self, *args, **kwargs) -> Self:
+        """Override torch.nn.Module.to so this returns the Envoy, not the underlying module, when doing: model = model.to(...)
+
+        Returns:
+            Envoy: Envoy.
+        """
+
+        self._module = self._module.to(*args, **kwargs)
+
+        return self
+
+    def modules(
+        self,
+        include_fn: Callable[[Envoy], bool] = None,
+        names: bool = False,
+        envoys: List = None,
+    ) -> List[Envoy]:
+        """Returns all Envoys in the Envoy tree.
+
+        Args:
+            include_fn (Callable, optional): Optional function to be run against all Envoys to check if they should be included in the final collection of Envoys. Defaults to None.
+            names (bool, optional): Whether to include the name/module_path of returned Envoys along with the Envoy itself. Defaults to False.
+
+        Returns:
+            List[Envoy]: Included Envoys.
+        """
+
+        if envoys is None:
+            envoys = list()
+
+        included = True
+
+        if include_fn is not None:
+            included = include_fn(self)
+
+        if included:
+            if names:
+                envoys.append((self.path, self))
+            else:
+                envoys.append(self)
+
+        for sub_envoy in self._children:
+            sub_envoy.modules(include_fn=include_fn, names=names, envoys=envoys)
+
+        return envoys
+
+    def named_modules(self, *args, **kwargs) -> List[Tuple[str, Envoy]]:
+        """Returns all Envoys in the Envoy tree along with their name/module_path.
+
+        Args:
+            include_fn (Callable, optional): Optional function to be run against all Envoys to check if they should be included in the final collection of Envoys. Defaults to None.
+
+        Returns:
+            List[Tuple[str, Envoy]]: Included Envoys and their names/module_paths.
+        """
+
+        return self.modules(*args, **kwargs, names=True)
+
+    # Private API ###############################
+
     def _update(self, module: torch.nn.Module) -> None:
         """Updates the ._model attribute using a new model of the same architecture.
         Used when loading the real weights (dispatching) and need to replace the underlying modules.
@@ -67,7 +335,7 @@ def _update(self, module: torch.nn.Module) -> None:
 
         for i, module in enumerate(self._module.children()):
 
-            self._sub_envoys[i]._update(module)
+            self._children[i]._update(module)
 
     def _add_envoy(self, module: torch.nn.Module, name: str) -> None:
         """Adds a new Envoy for a given torch module under this Envoy.
@@ -79,7 +347,7 @@ def _add_envoy(self, module: torch.nn.Module, name: str) -> None:
 
         envoy = Envoy(module, module_path=f"{self.path}.{name}")
 
-        self._sub_envoys.append(envoy)
+        self._children.append(envoy)
 
         # If the module already has a sub-module named 'input' or 'output',
         # mount the proxy access to 'nns_input' or 'nns_output' instead.
@@ -140,10 +408,9 @@ def _set_tracer(self, tracer: Tracer, propagate=True):
         self._tracer = tracer
 
         if propagate:
-            for envoy in self._sub_envoys:
+            for envoy in self._children:
                 envoy._set_tracer(tracer, propagate=True)
-
-
+
     def _tracing(self) -> bool:
         """Whether or not tracing.
@@ -174,6 +441,21 @@ def _scanning(self) -> bool:
 
         return False
 
+    def _set_iteration(self, iteration: Optional[int] = None, propagate: bool = True) -> None:
+
+        if iteration is not None:
+            self._iteration_stack.append(iteration)
+            self._output_stack.append(None)
+            self._input_stack.append(None)
+        else:
+            self._iteration_stack.pop()
+            self._output_stack.pop()
+            self._input_stack.pop()
+
+        if propagate:
+            for envoy in self._children:
+                envoy._set_iteration(iteration, propagate=True)
+
     def _reset_proxies(self, propagate: bool = True) -> None:
         """Sets proxies to None.
 
         Args:
             propagate (bool, optional): If to propagate to all sub-modules. Defaults to True.
""" - self._output: InterventionProxy = None - self._input: InterventionProxy = None + self._output_stack = [] + self._input_stack = [] if propagate: - for envoy in self._sub_envoys: + for envoy in self._children: envoy._reset_proxies(propagate=True) def _reset(self, propagate: bool = True) -> None: @@ -197,10 +479,10 @@ def _reset(self, propagate: bool = True) -> None: self._reset_proxies(propagate=False) - self._call_iter = 0 + self._set_iteration(0, propagate=False) if propagate: - for envoy in self._sub_envoys: + for envoy in self._children: envoy._reset(propagate=True) def _clear(self, propagate: bool = True) -> None: @@ -216,7 +498,7 @@ def _clear(self, propagate: bool = True) -> None: self._fake_inputs = [] if propagate: - for envoy in self._sub_envoys: + for envoy in self._children: envoy._clear(propagate=True) def _hook( @@ -229,113 +511,14 @@ def _hook( if self._scanning(): - self._reset_proxies(propagate=False) - input = (input, input_kwargs) self._fake_outputs.append(output) self._fake_inputs.append(input) - def next(self, increment: int = 1, propagate: bool = True) -> Envoy: - """By default, this modules inputs and outputs only refer to the first time its called. Use `.next()`to select which iteration .input an .output refer to. - - Args: - increment (int, optional): How many iterations to jump. Defaults to 1. - propagate (bool, optional): If to also call `.next()` on all sub envoys/modules.. Defaults to True. - - Returns: - Envoy: Self. - """ - - self._call_iter += increment - - self._reset_proxies(propagate=False) - - if propagate: - for envoy in self._sub_envoys: - envoy.next(increment=increment, propagate=True) - - return self - - def all(self, propagate: bool = True) -> Envoy: - """By default, this modules inputs and outputs only refer to the first time its called. Use `.all()`to have .input and .output refer to all iterations. - - Args: - propagate (bool, optional): If to also call `.all()` on all sub envoys/modules.. Defaults to True. - - Returns: - Envoy: Self. - """ - - self._call_iter = -1 - - if propagate: - for envoy in self._sub_envoys: - envoy.all(propagate=True) - - return self - - def to(self, *args, **kwargs) -> Envoy: - """Override torch.nn.Module.to so this returns the Envoy, not the underlying module when doing: model = model.to(...) - - Returns: - Envoy: Envoy. - """ - - self._module = self._module.to(*args, **kwargs) - - return self - - def modules( - self, - include_fn: Callable[[Envoy], bool] = None, - names: bool = False, - envoys: List = None, - ) -> List[Envoy]: - """Returns all Envoys in the Envoy tree. - - Args: - include_fn (Callable, optional): Optional function to be ran against all Envoys to check if they should be included in the final collection of Envoys. Defaults to None. - names (bool, optional): If to include the name/module_path of returned Envoys along with the Envoy itself. Defaults to False. - - Returns: - List[Envoy]: Included Envoys - """ - - if envoys is None: - envoys = list() - - included = True - - if include_fn is not None: - included = include_fn(self) - - if included: - if names: - envoys.append((self.path, self)) - else: - envoys.append(self) - - for sub_envoy in self._sub_envoys: - sub_envoy.modules(include_fn=include_fn, names=names, envoys=envoys) - - return envoys - - def named_modules(self, *args, **kwargs) -> List[Tuple[str, Envoy]]: - """Returns all Envoys in the Envoy tree along with their name/module_path. 
- - Args: - include_fn (Callable, optional): Optional function to be ran against all Envoys to check if they should be included in the final collection of Envoys. Defaults to None. - - Returns: - List[Tuple[str, Envoy]]: Included Envoys and their names/module_paths. - """ - - return self.modules(*args, **kwargs, names=True) - def _repr_module_list(self): - list_of_reprs = [repr(item) for item in self._sub_envoys] + list_of_reprs = [repr(item) for item in self._children] if len(list_of_reprs) == 0: return self._module._get_name() + "()" @@ -414,7 +597,7 @@ def __iter__(self) -> Iterator[Envoy]: Iterator[Envoy]: Iterator. """ - return iter(self._sub_envoys) + return iter(self._children) def __getitem__(self, key: int) -> Envoy: """Wrapper method for underlying ModuleList getitem. @@ -426,7 +609,7 @@ def __getitem__(self, key: int) -> Envoy: Envoy: Envoy. """ - return self._sub_envoys[key] + return self._children[key] def __len__(self) -> int: """Wrapper method for underlying ModuleList len. @@ -437,7 +620,10 @@ def __len__(self) -> int: return len(self._module) - def __getattr__(self, key: str) -> Union[Envoy, Any]: + def __getattribute__(self, name: str) -> Union[Tracer]: + return super().__getattribute__(name) + + def __getattr__(self, key: str) -> Union[Tracer]: """Wrapper method for underlying module's attributes. Args: @@ -461,146 +647,70 @@ def __setattr__(self, key: Any, value: Any) -> None: else: super().__setattr__(key, value) - - def __call__( - self, *args: List[Any], hook=False, **kwargs: Dict[str, Any] - ) -> InterventionProxy: - """Creates a proxy to call the underlying module's forward method with some inputs. - - Returns: - InterventionProxy: Module call proxy. - """ + +class IterationEnvoy(Envoy, AbstractContextManager): + + def __init__(self, envoy:Envoy) -> None: + + self.__dict__.update(envoy.__dict__) + + self._iteration = self._iteration_stack[-1] + + self._open_context = False - if not self._tracing() or self._scanning(): - return self._module(*args, **kwargs) - - if isinstance(self._tracer.backend, EditBackend): - hook = True - - return protocols.ApplyModuleProtocol.add( - self._tracer.graph, self.path, *args, hook=hook, **kwargs - ) - @property def output(self) -> InterventionProxy: - """ - Calling denotes the user wishes to get the output of the underlying module and therefore we create a Proxy of that request. - Only generates a proxy the first time it is references otherwise return the already set one. - - Returns: - InterventionProxy: Output proxy. - """ - if self._output is None: - - if isinstance(self._module, torch.nn.ModuleList): - - self._output = [envoy.output for envoy in self._sub_envoys] - - return self._output - - if len(self._fake_outputs) == 0: - fake_output = inspect._empty - elif self._call_iter >= len(self._fake_outputs): - # TODO warning? - fake_output = self._fake_outputs[-1] - else: - fake_output = self._fake_outputs[self._call_iter] - - module_path = f"{self.path}.output" - - self._output = InterventionProtocol.add( - self._tracer.graph, - fake_output, - args=[ - module_path, - len(self._tracer._invoker_inputs) - 1, - self._call_iter, - ], - ) - - return self._output - - @output.setter - def output(self, value: Union[InterventionProxy, Any]) -> None: - """ - Calling denotes the user wishes to set the output of the underlying module and therefore we create a Proxy of that request. - - Args: - value (Union[InterventionProxy, Any]): Value to set output to. 
- """ - - protocols.SwapProtocol.add(self.output.node, value) - - self._output = None - - @property - def inputs(self) -> InterventionProxy: - """ - Calling denotes the user wishes to get the input of the underlying module and therefore we create a Proxy of that request. - Only generates a proxy the first time it is references otherwise return the already set one. - - Returns: - InterventionProxy: Input proxy. - """ - if self._input is None: - - if isinstance(self._module, torch.nn.ModuleList): - - self._input = [envoy.input for envoy in self._sub_envoys] - - return self._input - - if len(self._fake_inputs) == 0: - fake_input = inspect._empty - elif self._call_iter >= len(self._fake_inputs): - # TODO warning? - fake_input = self._fake_inputs[-1] - else: - fake_input = self._fake_inputs[self._call_iter] - - module_path = f"{self.path}.input" - - self._input = InterventionProtocol.add( - self._tracer.graph, - fake_input, - args=[ - module_path, - len(self._tracer._invoker_inputs) - 1, - self._call_iter, - ], - ) - - return self._input - - @inputs.setter - def inputs(self, value: Union[InterventionProxy, Any]) -> None: - """ - Calling denotes the user wishes to set the input of the underlying module and therefore we create a Proxy of that request. - - Args: - value (Union[InterventionProxy, Any]): Value to set input to. - """ - - protocols.SwapProtocol.add(self.inputs.node, value) - - self._input = None - + + self._output_stack.append(None) + self._iteration_stack.append(self._iteration) + + output = super().output + + self._output_stack.pop() + self._iteration_stack.pop() + + return output + @property def input(self) -> InterventionProxy: - """Getting the first positional argument input of the model's module. - - Returns: - InterventionProxy: Input proxy. - """ - - return self.inputs[0][0] - - @input.setter - def input(self, value: Union[InterventionProxy, Any]) -> None: - """Setting the value of the input's first positionl argument in the model's module. - - Args; - value (Union[InterventionProxy, Any]): Value to set the input to. 
- """ - - self.inputs = ((value,) + self.inputs[0][1:],) + (self.inputs[1:]) + + self._input_stack.append(None) + self._iteration_stack.append(self._iteration) + + input = super().input + + self._input_stack.pop() + self._iteration_stack.pop() + + return input + + def __getitem__(self, key: Union[int, List[int], slice]) -> Self: + + # TODO: Error if not valid key type + + if isinstance(key, tuple): + + key = list(key) + + self._iteration = key + + return self + + def __enter__(self) -> IterationEnvoy: + + if not self._open_context: + + self._set_iteration(self._iteration) + + self._open_context = True + + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + + self._set_iteration() + + self._open_context = False + + if isinstance(exc_val, BaseException): + raise exc_val diff --git a/src/nnsight/intervention.py b/src/nnsight/intervention.py index 0e63cb81..6ab0144b 100755 --- a/src/nnsight/intervention.py +++ b/src/nnsight/intervention.py @@ -12,7 +12,17 @@ import inspect from collections import defaultdict from contextlib import AbstractContextManager -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Collection, + Dict, + List, + Optional, + Tuple, + Union, +) import torch from torch.utils.hooks import RemovableHandle @@ -249,7 +259,6 @@ class InterventionProtocol(Protocol): attachment_name = "nnsight_module_nodes" attachment_flag_name = "nnsight_compiled" - condition: bool = False @classmethod def add( @@ -478,19 +487,30 @@ def intervene( node = intervention_handler.graph.nodes[intervention_node_name] - # Args for intervention nodes are (module_path, batch_group_idx, call_iter). - _, batch_group_idx, call_iter = node.args - - # If this node will be executed for multiple iterations, we need to reset the sub-graph to b executed once more - if call_iter == -1: - node.reset(propagate=True) - + # Args for intervention nodes are (module_path, batch_group, iteration). + _, batch_group, iteration = node.args + # Updates the count of intervention node calls. - # If count matches call_iter, time to inject value into node. - elif call_iter != intervention_handler.count(intervention_node_name): + # If count matches the Node's iteration, its ready to be executed. + ready, defer = intervention_handler.count(node.name, iteration) + # Dont execute if the node isnt ready (call count / iteration) or its noy fulfilled (conditional) + if not ready or (not node.fulfilled() and not node.executed()): continue + # If this execution is possibly not the last time it will be executed, + # we need to defer destruction of dependencies outside the sub-graph. + if defer: + cls.defer(node) + + # If this node will be executed for multiple iterations, we need to reset the sub-graph to be executed once more. + if node.executed() or defer: + + node.reset(propagate=True) + + # Make the node "executed" + node.remaining_dependencies -= 1 + value = activations narrowed = False @@ -498,7 +518,7 @@ def intervene( if len(intervention_handler.batch_groups) > 1: batch_start, batch_size = intervention_handler.batch_groups[ - batch_group_idx + batch_group ] def narrow(acts: torch.Tensor): @@ -519,10 +539,8 @@ def narrow(acts: torch.Tensor): torch.Tensor, ) - # If this Node may be executed more than once, we need to defer any Node destruction until after interleaving. - with intervention_handler.defer_destruction(call_iter == -1): - # Value injection. - node.set_value(value) + # Value injection. 
+                node.set_value(value)
 
                 # Check if through the previous value injection, there was a 'swap' intervention.
                 # This would mean we want to replace activations for this batch with some other ones.
@@ -547,6 +565,19 @@ def narrow(acts: torch.Tensor):
 
         return activations
 
+    @classmethod
+    def execute(cls, node: Node):
+        # To prevent the node from looking like it's executed when calling Graph.execute.
+        node.remaining_dependencies += 1
+
+    @classmethod
+    def defer(cls, node: Node) -> None:
+
+        for listener in node.listeners:
+            for dependency in listener.arg_dependencies:
+                dependency.remaining_listeners += 1
+            cls.defer(listener)
+
     @classmethod
     def style(cls) -> Dict[str, Any]:
         """Visualization style for this protocol node.
@@ -671,23 +702,62 @@ def __init__(
             defaultdict(lambda: 0) if call_counter is None else call_counter
         )
         self.batch_size = sum(self.batch_groups[-1])
-        self.destruction_deferrer = InterventionHandler.DestructionDeferrer()
+        self.deferred = set()
 
-    def count(self, name: str) -> int:
-        """Increments the count of times a given Intervention Node has tried to be executed and returns the count.
+    def count(self, name: str, iteration: Union[int, List[int], slice]) -> Tuple[bool, bool]:
+        """Increments the count of times a given Intervention Node has tried to be executed and returns whether the Node is ready and whether it needs to be deferred.
 
         Args:
             name (str): Name of intervention node to return count for.
+            iteration (Union[int, List[int], slice]): Which iteration(s) this Node should be executed for.
 
         Returns:
-            int: Count.
+            bool: Whether this Node should be executed on this iteration.
+            bool: Whether this Node and its recursive listeners should have the updating of their remaining listeners (and therefore their destruction) deferred.
         """
 
+        ready = False
+        defer = False
+
         count = self.call_counter[name]
 
+        if isinstance(iteration, int):
+            ready = count == iteration
+        elif isinstance(iteration, list):
+            iteration.sort()
+
+            ready = count in iteration
+            defer = count != iteration[-1]
+
+        elif isinstance(iteration, slice):
+
+            start = iteration.start or 0
+            stop = iteration.stop
+
+            ready = count >= start and (stop is None or count < stop)
+
+            defer = stop is None or count < stop - 1
+
+        if defer:
+            self.deferred.add(name)
+        else:
+            self.deferred.discard(name)
+
         self.call_counter[name] += 1
 
-        return count
+        return ready, defer
+
+    def cleanup(self) -> None:
+
+        def inner(node: Node):
+
+            for listener in node.listeners:
+                listener.update_dependencies()
+                inner(listener)
+
+        for name in self.deferred:
+
+            inner(self.graph.nodes[name])
 
     class MultiCounter(dict):
 
     def __getstate__(self):
 
         state["graph_id_to_call_counter"] = self.graph_id_to_call_counter
 
         return state
-    
+
     def __setstate__(self, state: Dict) -> None:
 
         self.__dict__.update(state)
-    
+
     @classmethod
     def persistent_call_counter(cls, graph: Graph) -> dict:
 
         def inner(graph: Graph):
 
             if cls.attachment_name not in graph.attachments:
-                
+
                 graph.attachments[cls.attachment_name] = defaultdict(int)
-                
+
             return graph.attachments[cls.attachment_name]
 
         if isinstance(graph, MultiGraph):
 
...
 
         return inner(graph)
 
-    def defer_destruction(self, defer: bool):
-
-        self.destruction_deferrer.defer = defer
-
-        return self.destruction_deferrer
-
-    def destroy(self):
-
-        self.destruction_deferrer.destroy()
-
-    class DestructionDeferrer:
-
-        def __init__(self, defer: bool = False) -> None:
-
-            self.defer = defer
-            self.deferred_nodes: Dict[str, Node] = {}
-            self.patcher = Patcher(
-                [
-                    Patch(
-                        Node,
-                        lambda node:
self.deferred_nodes.update({node.name: node}), - "destroy", - ) - ] - ) - - def __enter__(self): - - if self.defer: - - self.patcher.__enter__() - - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - - if self.defer: - - self.patcher.__exit__(None, None, None) - - if isinstance(exc_val, BaseException): - raise exc_val - - def destroy(self): - - for node in self.deferred_nodes.values(): - node.destroy() +if TYPE_CHECKING: - self.deferred_nodes = {} + class InterventionProxy(InterventionProxy, torch.Tensor): + pass diff --git a/src/nnsight/models/LanguageModel.py b/src/nnsight/models/LanguageModel.py index 39cdcc59..1227e15b 100755 --- a/src/nnsight/models/LanguageModel.py +++ b/src/nnsight/models/LanguageModel.py @@ -3,6 +3,7 @@ import json import warnings from typing import ( + TYPE_CHECKING, Any, Dict, Generic, @@ -91,11 +92,10 @@ def t(self) -> LanguageModel.TokenIndexer: return self.token -from ..util import TypeHint, hint -@hint -class LanguageModel(RemoteableMixin, TypeHint[Union[PreTrainedModel]]): + +class LanguageModel(RemoteableMixin): """LanguageModels are NNsight wrappers around transformers language models. Inputs can be in the form of: @@ -421,3 +421,11 @@ def __setitem__( key = self.convert_idx(key) self.proxy[:, key] = value + +if TYPE_CHECKING: + + class LanguageModel(LanguageModel, PreTrainedModel): + + def generate(self, *args, **kwargs) -> Tracer: + pass + \ No newline at end of file diff --git a/src/nnsight/models/NNsightModel.py b/src/nnsight/models/NNsightModel.py index 10c9e802..95a770a6 100755 --- a/src/nnsight/models/NNsightModel.py +++ b/src/nnsight/models/NNsightModel.py @@ -1,7 +1,7 @@ from __future__ import annotations import weakref -from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union) import torch @@ -31,6 +31,7 @@ class NNsight: _model (torch.nn.Module): Underlying torch module. _envoy (Envoy): Envoy for underlying model. _session (Session): Session object if in a Session. + _default_graph (Graph): """ __methods__: Dict[str, str] = dict() @@ -50,9 +51,6 @@ def __init__( self._session: Session = None self._default_graph: Graph = None - def __new__(cls, *args, **kwargs) -> Self | Envoy: - return super().__new__(cls) - #### Public API ############## def trace( @@ -405,7 +403,7 @@ def interleave( if not node.executed(): node.clean() finally: - intervention_handler.destroy() + intervention_handler.cleanup() def to(self, *args, **kwargs) -> Self: """Override torch.nn.Module.to so this returns the NNSight model, not the underlying module when doing: model = model.to(...) @@ -414,9 +412,18 @@ def to(self, *args, **kwargs) -> Self: Envoy: Envoy. 
""" - self._model = self._model.to(*args, **kwargs) + self._envoy.to(*args, **kwargs) return self + + @property + def device(self) -> Optional[torch.device]: + + try: + return next(self._model.parameters()).device + except: + return None + def clear_edits(self) -> None: """Resets the default graph of this model.""" @@ -450,13 +457,6 @@ def _shallow_copy(self) -> Self: return copy - @property - def device(self) -> Optional[torch.device]: - - try: - return next(self._model.parameters()).device - except: - return None def to_device(self, data: Any) -> Any: @@ -487,7 +487,7 @@ def __setattr__(self, key: Any, value: Any) -> None: object.__setattr__(self, key, value) - def __getattr__(self, key: Any) -> Union[Any, InterventionProxy, Envoy, Tracer]: + def __getattr__(self, key: Any): """Wrapper of ._envoy's attributes to access module's inputs and outputs. Returns: @@ -557,3 +557,8 @@ def _batch( ) return args, kwargs + + +if TYPE_CHECKING: + class NNsight(NNsight, Envoy): + pass diff --git a/src/nnsight/models/vllm/__init__.py b/src/nnsight/models/vllm/__init__.py index 86fe9f37..c928463c 100755 --- a/src/nnsight/models/vllm/__init__.py +++ b/src/nnsight/models/vllm/__init__.py @@ -1,13 +1,12 @@ import weakref -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from ...envoy import Envoy from ...tracing import protocols from ...tracing.Graph import Graph -from ...util import TypeHint, WrapperModule, hint +from ...util import WrapperModule from ..mixins import RemoteableMixin from .executors.GPUExecutor import NNsightGPUExecutor from .executors.RayGPUExecutor import NNsightRayGPUExecutor @@ -31,8 +30,7 @@ ) from e -@hint -class VLLM(RemoteableMixin, TypeHint[Union[LLM, Envoy]]): +class VLLM(RemoteableMixin): """NNsight wrapper to conduct interventions on a vLLM inference engine. .. code-block:: python @@ -216,3 +214,9 @@ def _execute( ) -> Any: self.vllm_entrypoint.generate(prompts, sampling_params=params) + +if TYPE_CHECKING: + + class VLLM(VLLM,LLM): + pass + \ No newline at end of file diff --git a/src/nnsight/tracing/Node.py b/src/nnsight/tracing/Node.py index 5ce112fd..78f1b65d 100755 --- a/src/nnsight/tracing/Node.py +++ b/src/nnsight/tracing/Node.py @@ -54,15 +54,17 @@ def __getstate__(self) -> Dict: state["graph"] = util.weakref_to_obj(self.graph) state["listeners"] = [ util.weakref_to_obj(listener) for listener in self.listeners - ] + ] return state def __setstate__(self, state: Dict) -> None: state["graph"] = weakref.proxy(state["graph"]) - state["listeners"] = [weakref.proxy(listener) for listener in state["listeners"]] - + state["listeners"] = [ + weakref.proxy(listener) for listener in state["listeners"] + ] + self.__dict__.update(state) def __init__( @@ -74,7 +76,6 @@ def __init__( kwargs: Dict[str, Any] = None, name: str = None, ) -> None: - super().__init__() if args is None: args = list() @@ -103,9 +104,7 @@ def __init__( self.preprocess() # Node.graph is a weak reference to avoid reference loops. 
- self.graph = ( - weakref.proxy(self.graph) if self.graph is not None else None - ) + self.graph = weakref.proxy(self.graph) if self.graph is not None else None self.name: str = name @@ -175,9 +174,7 @@ def preprocess_node(node: Union[Node, Proxy]): if conditional_node: if all( [ - not protocols.ConditionalProtocol.is_node_conditioned( - arg - ) + not protocols.ConditionalProtocol.is_node_conditioned(arg) for arg in self.arg_dependencies ] ): @@ -282,16 +279,19 @@ def create( def reset(self, propagate: bool = False) -> None: """Resets this Nodes remaining_listeners and remaining_dependencies.""" - + self.remaining_listeners = len(self.listeners) self.remaining_dependencies = sum( [not node.executed() for node in self.arg_dependencies] ) + int(not (self.cond_dependency is None)) - + if propagate: for node in self.listeners: - if node.executed(): - node.reset(propagate=True) + node.reset(propagate=True) + + + + def done(self) -> bool: """Returns true if the value of this node has been set. @@ -482,9 +482,7 @@ def visualize( styles = { "node": {"color": "black", "shape": "ellipse"}, "label": ( - self.target - if isinstance(self.target, str) - else self.target.__name__ + self.target if isinstance(self.target, str) else self.target.__name__ ), "arg": defaultdict(lambda: {"color": "gray", "shape": "box"}), "arg_kname": defaultdict(lambda: None), @@ -498,13 +496,8 @@ def visualize( ): styles = self.target.style() - viz_graph.add_node( - node_name, label=styles["label"], **styles["node"] - ) - if ( - recursive - and self.target == protocols.LocalBackendExecuteProtocol - ): + viz_graph.add_node(node_name, label=styles["label"], **styles["node"]) + if recursive and self.target == protocols.LocalBackendExecuteProtocol: # recursively draw all sub-graphs for sub_node in self.args[0].graph.nodes.values(): @@ -528,9 +521,7 @@ def visualize( viz_graph, recursive, node_name + "_" ) else: - viz_graph.add_node( - node_name, label=styles["label"], **styles["node"] - ) + viz_graph.add_node(node_name, label=styles["label"], **styles["node"]) def visualize_args(arg_collection): """Recursively visualizes the arguments of this node. 
@@ -573,9 +564,7 @@ def visualize_args(arg_collection): viz_graph.add_node(name, label=label, **styles["arg"][key]) for dep_name in iter_val_dependencies: - viz_graph.add_edge( - dep_name, name, style="dashed", color="gray" - ) + viz_graph.add_edge(dep_name, name, style="dashed", color="gray") viz_graph.add_edge(name, node_name, style=styles["edge"][key]) @@ -584,9 +573,7 @@ def visualize_args(arg_collection): visualize_args(self.kwargs.items()) if isinstance(self.cond_dependency, Node): - name = self.cond_dependency.visualize( - viz_graph, recursive, backend_name - ) + name = self.cond_dependency.visualize(viz_graph, recursive, backend_name) viz_graph.add_edge( name, node_name, style=styles["edge"][None], color="#FF8C00" ) diff --git a/src/nnsight/util.py b/src/nnsight/util.py index b701ea7d..28c76020 100755 --- a/src/nnsight/util.py +++ b/src/nnsight/util.py @@ -157,14 +157,3 @@ def forward(self, *args, **kwargs): return args - -H = TypeVar("H") -P = TypeVar("P") - - -class TypeHint(Generic[H]): - pass - - -def hint(cls: Type[P | TypeHint[H]]) -> Union[Type[H], Type[P]]: - return cls From cad701511e444309c82d914f08146856227cda39 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Sun, 13 Oct 2024 00:01:01 -0400 Subject: [PATCH 2/2] Type hinting --- src/nnsight/models/NNsightModel.py | 44 ++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/src/nnsight/models/NNsightModel.py b/src/nnsight/models/NNsightModel.py index 95a770a6..e12d60cc 100755 --- a/src/nnsight/models/NNsightModel.py +++ b/src/nnsight/models/NNsightModel.py @@ -1,20 +1,40 @@ from __future__ import annotations import weakref -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, - Union) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import torch from typing_extensions import Self from .. import util -from ..contexts.backends import (Backend, BridgeBackend, EditBackend, - LocalBackend, NoopBackend, RemoteBackend) +from ..contexts.backends import ( + Backend, + BridgeBackend, + EditBackend, + LocalBackend, + NoopBackend, + RemoteBackend, +) from ..contexts.session.Session import Session from ..contexts.Tracer import Tracer from ..envoy import Envoy -from ..intervention import (HookHandler, InterventionHandler, - InterventionProtocol, InterventionProxy) +from ..intervention import ( + HookHandler, + InterventionHandler, + InterventionProtocol, + InterventionProxy, +) from ..tracing import protocols from ..tracing.Graph import Graph @@ -31,7 +51,7 @@ class NNsight: _model (torch.nn.Module): Underlying torch module. _envoy (Envoy): Envoy for underlying model. _session (Session): Session object if in a Session. 
-        _default_graph (Graph):
+        _default_graph (Graph): Graph of edits applied by default to every trace of this model. Set via .edit().
     """
 
     __methods__: Dict[str, str] = dict()
@@ -415,7 +435,7 @@ def to(self, *args, **kwargs) -> Self:
 
         self._envoy.to(*args, **kwargs)
 
         return self
-    
+
     @property
     def device(self) -> Optional[torch.device]:
 
@@ -424,7 +444,6 @@ def device(self) -> Optional[torch.device]:
         except:
             return None
 
-
     def clear_edits(self) -> None:
         """Resets the default graph of this model."""
         self._default_graph = None
@@ -457,7 +476,6 @@ def _shallow_copy(self) -> Self:
 
         return copy
 
-
     def to_device(self, data: Any) -> Any:
 
         device = self.device
@@ -560,5 +578,9 @@ def _batch(
 
 
 if TYPE_CHECKING:
+
     class NNsight(NNsight, Envoy):
-        pass
+        def __getattribute__(
+            self, name: str
+        ) -> Union[Tracer, Session, Envoy, InterventionProxy, Any]:
+            pass
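
Usage sketch (not part of the patch): the new Envoy.iter API in PATCH 1/2 selects which invocation of a module `.input` and `.output` refer to when a trace runs the module multiple times, such as multi-token generation. The model name and module paths below are illustrative assumptions taken from the usual GPT-2 examples; only `.iter`, `.next()`, and `.all()` come from this patch.

    from nnsight import LanguageModel

    model = LanguageModel("openai-community/gpt2", device_map="auto")

    with model.generate("The Eiffel Tower is in", max_new_tokens=3):
        layer = model.transformer.h[0]

        # By default, .output refers to iteration 0 (the first call).
        first = layer.output.save()

        # Select a specific later iteration as a context.
        with layer.iter[1]:
            second = layer.output.save()

        # A slice targets every iteration; .all() is shorthand for .iter[:].
        with layer.iter[:]:
            layer.mlp.output[:] = 0  # ablate the MLP on every generated token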
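
The ready/defer decision in InterventionHandler.count drives those iteration semantics: an int fires once, a list fires on each listed iteration and defers until the last, and a slice fires over its range, deferring while more iterations may follow. A standalone sketch mirroring the branches above (illustrative only, not the shipped implementation):

    from typing import List, Tuple, Union

    def ready_and_defer(count: int, iteration: Union[int, List[int], slice]) -> Tuple[bool, bool]:
        """Return (ready, defer) for a node given its current call count."""
        ready, defer = False, False
        if isinstance(iteration, int):
            ready = count == iteration                 # fire exactly once
        elif isinstance(iteration, list):
            iteration = sorted(iteration)
            ready = count in iteration                 # fire on the listed iterations
            defer = count != iteration[-1]             # keep the sub-graph alive until the last one
        elif isinstance(iteration, slice):
            start, stop = iteration.start or 0, iteration.stop
            ready = count >= start and (stop is None or count < stop)
            defer = stop is None or count < stop - 1   # open-ended slices always defer
        return ready, defer

    assert ready_and_defer(0, 0) == (True, False)
    assert ready_and_defer(1, [0, 2]) == (False, True)
    assert ready_and_defer(2, slice(None)) == (True, True)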
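
The `if TYPE_CHECKING:` redefinitions (in intervention.py, LanguageModel.py, vllm/__init__.py, and NNsightModel.py) replace the removed TypeHint/hint helpers from util.py. The same pattern in isolation, with illustrative names, giving static checkers a merged interface without any runtime inheritance:

    from typing import TYPE_CHECKING

    class Wrapper:
        """Runtime class; attribute access is forwarded to a wrapped object."""

        def __init__(self, wrapped):
            self._wrapped = wrapped

        def __getattr__(self, name):
            return getattr(self._wrapped, name)

    if TYPE_CHECKING:
        # Never executed at runtime: type checkers treat Wrapper as also being
        # a list, so forwarded attributes type-check without the runtime cost
        # or MRO conflicts of real inheritance.
        class Wrapper(Wrapper, list):
            pass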