Commit f8aee6a: Big refactor
JadenFiotto-Kaufman committed Oct 11, 2024
1 parent 84c757c
Showing 18 changed files with 507 additions and 478 deletions.
16 changes: 5 additions & 11 deletions src/nnsight/contexts/GraphBasedContext.py
@@ -220,13 +220,16 @@ def __enter__(self) -> Self:
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
 
         GlobalTracingContext.try_deregister(self)
 
         if isinstance(exc_val, BaseException):
+            self.graph.alive = False
+            self.graph = None
             raise exc_val
 
         self.backend(self)
+
+        if not isinstance(self.graph, weakref.ProxyType):
+            self.graph = weakref.proxy(self.graph)
 
 
 ### BACKENDS ########
 
@@ -237,22 +240,13 @@ def local_backend_execute(self) -> None:
             self.graph.execute()
         except protocols.EarlyStopProtocol.EarlyStopException as e:
             raise e
-        finally:
-            graph = self.graph
-            graph.alive = False
-
-            if not isinstance(graph, weakref.ProxyType):
-                self.graph = weakref.proxy(graph)
 
     def bridge_backend_handle(self, bridge: Bridge) -> None:
 
         bridge.pop_graph()
 
         protocols.LocalBackendExecuteProtocol.add(self, bridge.peek_graph())
 
-        self.graph = weakref.proxy(self.graph)
 
 
 from inspect import getmembers, isclass
 
 from torch.utils import data
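As an aside, the alive-flag plus weakref.proxy teardown in __exit__ above can be shown in isolation. A minimal sketch, assuming a hypothetical FakeGraph and Context rather than nnsight's real Graph and GraphBasedContext:

import weakref

class FakeGraph:  # hypothetical stand-in for nnsight's Graph
    def __init__(self):
        self.alive = True

class Context:  # hypothetical stand-in for the context manager above
    def __init__(self):
        self.graph = FakeGraph()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Mark the graph dead, then drop the strong reference so the
        # graph can be garbage-collected once nothing else holds it.
        self.graph.alive = False
        if not isinstance(self.graph, weakref.ProxyType):
            self.graph = weakref.proxy(self.graph)

with Context() as ctx:
    pass
# ctx.graph is now a weak proxy; touching it after the underlying
# object is collected raises ReferenceError instead of leaking state.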
30 changes: 16 additions & 14 deletions src/nnsight/contexts/Invoker.py
@@ -37,15 +37,15 @@ class Invoker(AbstractContextManager):
     def __init__(
         self,
         tracer: "Tracer",
-        *inputs: Any,
+        *args,
         scan: bool = False,
         **kwargs,
     ) -> None:
 
         self.tracer = tracer
-        self.inputs = inputs
+        self.input = (args, kwargs)
 
         self.scan = scan
-        self.kwargs = kwargs
 
         self.scanning = False
 
@@ -71,28 +71,28 @@ def __enter__(self) -> Invoker:
         # Set self.inputs to be the proxy_value so we can prepare_inputs, get the batch size, and scan.
         if self.tracer.model._session is not None:
 
-            self.inputs, has_proxies_in_inputs = check_for_dependencies(
-                self.inputs
+            self.input, has_proxies_in_inputs = check_for_dependencies(
+                self.input
             )
 
         with GlobalTracingContext.exit_global_tracing_context():
 
             if not has_proxies_in_inputs:
 
-                self.inputs, batch_size = self.tracer.model._prepare_inputs(
-                    *self.inputs, **self.kwargs
+                self.input, batch_size = self.tracer.model._prepare_input(
+                    *self.input[0], **self.input[1]
                 )
 
             if self.scan:
 
-                inputs = self.inputs
+                input = self.input
 
                 if has_proxies_in_inputs:
 
-                    inputs = util.apply(inputs, lambda x: x.proxy_value, Node)
+                    input = util.apply(input, lambda x: x.proxy_value, Node)
 
-                inputs, batch_size = self.tracer.model._prepare_inputs(
-                    *inputs, **self.kwargs
+                input, batch_size = self.tracer.model._prepare_input(
+                    *input[0], **input[1]
                 )
 
                 self.tracer.model._envoy._clear()
@@ -112,8 +112,10 @@ def __enter__(self) -> Invoker:
                     shape_env=ShapeEnv(assume_static_by_default=True),
                 ) as fake_mode:
                     with FakeCopyMode(fake_mode):
-                        self.tracer.model._execute(
-                            *copy.deepcopy(inputs),
+                        fn = self.tracer.model._execute if self.tracer.method is None else getattr(self.tracer.model, self.tracer.method)
+                        fn(
+                            *copy.deepcopy(input[0]),
+                            **copy.deepcopy(input[1]),
                             **copy.deepcopy(self.tracer._kwargs),
                         )
 
@@ -122,7 +124,7 @@ def __enter__(self) -> Invoker:
         else:
             self.tracer.model._envoy._reset()
 
-        self.tracer._invoker_inputs.append(self.inputs)
+        self.tracer._invoker_inputs.append(self.input)
 
         return self
 
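The packed-input convention the refactored Invoker adopts, one (args, kwargs) tuple unpacked as *input[0], **input[1], can be sketched standalone. pack, call_packed, and fake_prepare below are hypothetical helpers, not nnsight API:

from typing import Any, Callable, Dict, Tuple

PackedInput = Tuple[Tuple[Any, ...], Dict[str, Any]]

def pack(*args: Any, **kwargs: Any) -> PackedInput:
    # Positional and keyword arguments travel together as one value.
    return (args, kwargs)

def call_packed(fn: Callable, packed: PackedInput):
    args, kwargs = packed
    return fn(*args, **kwargs)  # mirrors fn(*input[0], **input[1])

def fake_prepare(text: str, max_length: int = 8) -> str:
    return text[:max_length]

packed = pack("hello world", max_length=5)
assert call_packed(fake_prepare, packed) == "hello"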
77 changes: 44 additions & 33 deletions src/nnsight/contexts/Tracer.py
@@ -1,18 +1,18 @@
 from __future__ import annotations
 
-import weakref
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 from typing_extensions import Self
 
 from ..tracing import protocols
 from ..tracing.Bridge import Bridge
 from ..tracing.Graph import Graph
 from . import resolve_dependencies
-from .backends import Backend, EditBackend, BridgeMixin, EditMixin, RemoteMixin
+from .backends import Backend, BridgeMixin, EditBackend, EditMixin, RemoteMixin
 from .GraphBasedContext import GraphBasedContext
 from .Invoker import Invoker
 
+from ..intervention import InterventionHandler
 if TYPE_CHECKING:
     from ..models.mixins import RemoteableMixin
     from ..models.NNsightModel import NNsight
@@ -35,14 +35,16 @@ def __init__(
         self,
         backend: Backend,
         model: "NNsight",
-        graph: Optional[Graph] = None,
-        bridge: Optional[Bridge] = None,
+        method: Optional[str] = None,
         validate: bool = False,
+        graph: Graph = None,
+        bridge: Bridge = None,
         return_context: bool = False,
         **kwargs,
     ) -> None:
 
         self.model = model
+        self.method = method
 
         self.return_context = return_context
 
@@ -86,7 +88,7 @@ def __enter__(self) -> Union[Self, "NNsight", Tuple["NNsight", Self]]:
         if isinstance(self.backend, EditBackend):
             if self.return_context:
                 return self.model, self
 
             return self.model
 
         return tracer
@@ -99,11 +101,10 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
 
         self.model._envoy._reset()
 
-
         super().__exit__(exc_type, exc_val, exc_tb)
 
     def invoke(self, *inputs: Any, **kwargs) -> Invoker:
-        """Create an Invoker context dor a given input.
+        """Create an Invoker context for a given input.
 
         Raises:
             Exception: If an Invoker context is already open
@@ -118,14 +119,28 @@ def invoke(self, *inputs: Any, **kwargs) -> Invoker:
 
         return Invoker(self, *inputs, **kwargs)
 
-    def next(self, increment: int = 1) -> None:
-        """Increments call_iter of all module Envoys. Useful when doing iterative/generative runs.
-
-        Args:
-            increment (int): How many call_iter to increment at once. Defaults to 1.
-        """
-
-        self.model._envoy.next(increment=increment, propagate=True)
+    def batch(
+        self, invoker_inputs: Tuple[Tuple[Tuple[Any], Dict[str, Any]]]
+    ) -> Tuple[Tuple[Tuple[Any], Dict[str, Any]], List[Tuple[int, int]]]:
+
+        batch_groups = []
+        batch_start = 0
+        batched_input = None
+
+        for args, kwargs in invoker_inputs:
+            (args, kwargs), batch_size = self.model._prepare_input(*args, **kwargs)
+
+            batch_groups.append((batch_start, batch_size))
+
+            batched_input = self.model._batch(batched_input, *args, **kwargs)
+
+            batch_start += batch_size
+
+        if batched_input is None:
+            batched_input = (tuple(), dict())
+
+        return batched_input, batch_groups
 
     ##### BACKENDS ###############################
 
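The batch-group bookkeeping in batch above can be traced with a toy prepare/batch pair standing in for the model's real _prepare_input and _batch hooks. toy_prepare and toy_batch are hypothetical; each group records the (start offset, size) of one invocation inside the combined batch:

def toy_prepare(*args, **kwargs):
    texts = list(args[0])
    return ((texts,), {}), len(texts)  # (args, kwargs), batch_size

def toy_batch(batched, texts):
    # Fold one prepared input into the running batch.
    if batched is None:
        return ((texts,), {})
    return ((batched[0][0] + texts,), {})

invoker_inputs = [((["a", "b"],), {}), ((["c", "d", "e"],), {})]

batch_groups, batch_start, batched = [], 0, None
for args, kwargs in invoker_inputs:
    (args, kwargs), size = toy_prepare(*args, **kwargs)
    batch_groups.append((batch_start, size))
    batched = toy_batch(batched, *args, **kwargs)
    batch_start += size

assert batch_groups == [(0, 2), (2, 3)]
assert batched == ((["a", "b", "c", "d", "e"],), {})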
@@ -141,24 +156,28 @@ def local_backend_execute(self) -> Graph:
         if protocols.BridgeProtocol.has_bridge(self.graph):
 
             invoker_inputs = resolve_dependencies(invoker_inputs)
 
+        (args, kwargs), batch_groups = self.batch(invoker_inputs)
+
         self.graph.execute()
 
+        fn = (
+            self.model._execute
+            if self.method is None
+            else getattr(self.model, self.method)
+        )
+
+        intervention_handler = InterventionHandler(batch_groups=batch_groups)
+
         self.model.interleave(
-            self.model._execute,
+            fn,
             self.graph,
-            *invoker_inputs,
+            *args,
+            intervention_handler=intervention_handler,
+            **kwargs,
             **self._kwargs,
         )
 
-        graph = self.graph
-        graph.alive = False
-
-        if not isinstance(graph, weakref.ProxyType):
-            self.graph = weakref.proxy(graph)
-
-        return graph
 
     def edit_backend_execute(self) -> Graph:
 
         self.model._default_graph = self.graph
@@ -180,17 +199,9 @@ def remote_backend_handle_result_value(self, value: Dict[str, Any]) -> None:
 
         # TODO : graph mismatch handle. hash json ?
         for node_name, node_value in value.items():
             self.graph.nodes[node_name]._value = node_value
 
     def remote_backend_get_stream_node(self, name: str, graph_id: str) -> "Node":
         return self.graph.nodes[name]
 
-    def remote_backend_cleanup(self):
-
-        graph = self.graph
-        graph.alive = False
-
-        if not isinstance(graph, weakref.ProxyType):
-            self.graph = weakref.proxy(graph)
 
     def __repr__(self) -> str:
         return f"<{self.__class__.__name__} at {hex(id(self))}>"
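The fn selection in local_backend_execute above reduces to a getattr dispatch: run the default entry point unless a method name was given. A minimal sketch; ToyModel and its generate method are illustrative names only:

class ToyModel:
    def _execute(self, x):
        return f"execute({x})"

    def generate(self, x):
        return f"generate({x})"

def select_fn(model, method=None):
    # Fall back to the default callable when no method name is given.
    return model._execute if method is None else getattr(model, method)

model = ToyModel()
assert select_fn(model)("in") == "execute(in)"
assert select_fn(model, "generate")("in") == "generate(in)"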
4 changes: 2 additions & 2 deletions src/nnsight/contexts/session/Session.py
@@ -50,9 +50,9 @@ def __init__(
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
 
-        self.model._session = None
-
         super().__exit__(exc_type, exc_val, exc_tb)
+
+        self.model._session = None
 
     def iter(self, iterable: Iterable, **kwargs) -> Iterator:
         """Creates an Iterator context to iteratively execute an intervention graph, with an update item at each iteration.
2 changes: 1 addition & 1 deletion src/nnsight/envoy.py
@@ -471,7 +471,7 @@ def __call__(
             InterventionProxy: Module call proxy.
         """
 
-        if not self._tracing():
+        if not self._tracing() or self._scanning():
            return self._module(*args, **kwargs)
 
         if isinstance(self._tracer.backend, EditBackend):
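The widened guard clause above follows a common passthrough-wrapper shape: when not tracing, or while scanning with fake tensors, forward straight to the wrapped module. ToyEnvoy is a hypothetical stand-in for Envoy:

class ToyEnvoy:
    def __init__(self, module):
        self._module = module
        self._is_tracing = False
        self._is_scanning = False

    def __call__(self, *args, **kwargs):
        if not self._is_tracing or self._is_scanning:
            return self._module(*args, **kwargs)  # plain passthrough
        return ("proxy", args, kwargs)  # stand-in for proxy creation

envoy = ToyEnvoy(lambda x: x + 1)
assert envoy(1) == 2            # not tracing: direct call
envoy._is_tracing = True
assert envoy(1)[0] == "proxy"   # tracing: call is intercepted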
8 changes: 4 additions & 4 deletions src/nnsight/intervention.py
@@ -480,11 +480,11 @@ def intervene(
 
         # Args for intervention nodes are (module_path, batch_group_idx, call_iter).
         _, batch_group_idx, call_iter = node.args
 
         # If this node will be executed for multiple iterations, we need to reset the sub-graph to be executed once more
         if call_iter == -1:
             node.reset(propagate=True)
 
         # Updates the count of intervention node calls.
         # If count matches call_iter, time to inject value into node.
         elif call_iter != intervention_handler.count(intervention_node_name):
@@ -660,13 +660,13 @@ class InterventionHandler:
 
     def __init__(
         self,
-        batch_groups: List[Tuple[int, int]],
+        batch_groups: List[Tuple[int, int]] = None,
         call_counter: Dict[str, int] = None,
         graph: Graph = None,
     ) -> None:
 
         self.graph = graph
-        self.batch_groups = batch_groups
+        self.batch_groups = [] if batch_groups is None else batch_groups
         self.call_counter: Dict[str, int] = (
             defaultdict(lambda: 0) if call_counter is None else call_counter
         )
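The batch_groups default above is the standard None-sentinel idiom: a mutable default such as [] must not sit in the signature, because Python evaluates defaults once, at function definition time. A self-contained demonstration:

def bad(groups=[]):      # one shared list across all calls: a classic bug
    groups.append(1)
    return groups

def good(groups=None):   # fresh list per call
    groups = [] if groups is None else groups
    groups.append(1)
    return groups

assert bad() == [1] and bad() == [1, 1]   # state leaks between calls
assert good() == [1] and good() == [1]    # no leak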
4 changes: 2 additions & 2 deletions src/nnsight/models/DiffusionModel.py
@@ -54,7 +54,7 @@ def _load(self, repo_id: str, device_map=None, **kwargs) -> Diffuser:
 
         return model
 
-    def _prepare_inputs(
+    def _prepare_input(
         self,
         inputs: Union[str, List[str]],
     ) -> Any:
@@ -64,7 +64,7 @@ def _prepare_inputs(
 
         return (inputs,), len(inputs)
 
-    def _batch_inputs(
+    def _batch(
         self,
         batched_inputs: Optional[Dict[str, Any]],
         prepared_inputs: BatchEncoding,
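The renamed _prepare_input/_batch pair implements a prepare-then-fold contract: normalize one invocation and report its batch size, then fold successive prepared inputs into a running batch. A toy version for plain prompt lists, not the real DiffusionModel methods:

from typing import List, Optional, Union

def prepare_input(inputs: Union[str, List[str]]):
    if isinstance(inputs, str):
        inputs = [inputs]  # normalize a lone prompt to a list
    return (inputs,), len(inputs)

def batch(batched: Optional[List[str]], prepared: List[str]) -> List[str]:
    return prepared if batched is None else batched + prepared

(first,), n1 = prepare_input("a cat")
(second,), n2 = prepare_input(["a dog", "a fish"])
combined = batch(batch(None, first), second)
assert combined == ["a cat", "a dog", "a fish"] and (n1, n2) == (1, 2)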
(11 more changed files not shown)
