Skip to content

Commit

Permalink
merge 0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamBelfki3 committed Dec 5, 2024
2 parents d27f3d0 + 95854a0 commit 93495f5
Show file tree
Hide file tree
Showing 26 changed files with 738 additions and 176 deletions.
21 changes: 16 additions & 5 deletions src/nnsight/intervention/contexts/tracer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import inspect
import weakref
from functools import wraps
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
TypeVar, Union)
from typing import Any, Callable, Dict, Optional, TypeVar, Union

from ...tracing.contexts import Tracer
from ...tracing.graph import Proxy
from ..graph import (InterventionNodeType, InterventionProxy,
InterventionProxyType)
from . import LocalContext
from ... import CONFIG


class InterventionTracer(Tracer[InterventionNodeType, InterventionProxyType]):
"""Extension of base Tracer to add additional intervention functionality and type hinting for intervention proxies.
"""
Expand Down Expand Up @@ -48,3 +44,18 @@ def inner(*args, **kwargs):

# TODO: error
pass

@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.
Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "purple", "shape": "polygon", "sides": 6}
default_style["arg_kname"][1] = "method"

return default_style
41 changes: 26 additions & 15 deletions src/nnsight/intervention/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def context_dependency(
self,
context_node: InterventionNode,
intervention_subgraphs: List[SubGraph],
):
) -> None:

context_graph: SubGraph = context_node.args[0]

Expand All @@ -85,25 +85,26 @@ def context_dependency(

for intervention_subgraph in intervention_subgraphs:

if intervention_subgraph.subset[-1] < start:
continue

if intervention_subgraph.subset[0] > end:
# continue if the subgraph does not overlap with the context's graph
if intervention_subgraph.subset[-1] < start or end < intervention_subgraph.subset[0]:
continue

for intervention_index in intervention_subgraph.subset:

if intervention_index >= start and intervention_index <= end:
# if there's an overlapping node, make the context depend on the intervention node in the subgraph
if start <= intervention_index and intervention_index <= end:

# the first node in the subgraph is an InterventionProtocol node
intervention_node = intervention_subgraph[0]

context_node._dependencies.add(intervention_node.index)
intervention_node._listeners.add(context_node.index)
# TODO: maybe we don't need this
intervention_subgraph.subset.append(context_node.index)

break

def compile(self) -> None:
def compile(self) -> Optional[Dict[str, List[InterventionNode]]]:

if self.compiled:
return self.interventions
Expand All @@ -114,10 +115,14 @@ def compile(self) -> None:
intervention_subgraphs: List[SubGraph] = []

start = self[0].index

if isinstance(self[0].target, type) and issubclass(self[0].target, Context):
# is the first node corresponding to an executable graph?
# occurs when a Conditional or Iterator context is explicitly entered by a user
if isinstance(self[0].target, type) and issubclass(
self[0].target, Context
):
graph = self[0].args[0]

# handle emtpy if statments or for loops
if len(graph) > 0:
start = graph[0].index

Expand All @@ -127,10 +132,12 @@ def compile(self) -> None:
defer_start: int = None
context_node: InterventionNode = None

# looping over all the nodes created within this graph's context
for index in range(start, end):

node: InterventionNodeType = self.nodes[index]

# is this node part of an inner context's subgraph?
if context_node is None and node.graph is not self:

context_node = self.nodes[node.graph[-1].index + 1]
Expand All @@ -142,7 +149,8 @@ def compile(self) -> None:
self.context_dependency(context_node, intervention_subgraphs)

if node.target is InterventionProtocol:


# build intervention subgraph
subgraph = SubGraph(self, subset=sorted(list(node.subgraph())))

module_path, *_ = node.args
Expand All @@ -151,10 +159,13 @@ def compile(self) -> None:

intervention_subgraphs.append(subgraph)

# if the InterventionProtocol is defined within a sub-context
if context_node is not None:


# make the current context node dependent on this intervention node
context_node._dependencies.add(node.index)
node._listeners.add(context_node.index)
# TODO: maybe we don't need this
self.subset.append(node.index)

graph: SubGraph = node.graph
Expand All @@ -164,13 +175,13 @@ def compile(self) -> None:
node.kwargs["start"] = context_start
node.kwargs["defer_start"] = defer_start

node.graph = self

else:

node.kwargs["start"] = self.subset.index(subgraph.subset[0])
node.kwargs["defer_start"] = node.kwargs["start"]

node.graph = self

elif node.target is GradProtocol:

subgraph = SubGraph(self, subset=sorted(list(node.subgraph())))
Expand All @@ -191,12 +202,12 @@ def compile(self) -> None:

node.kwargs["start"] = context_start

node.graph = self

else:

node.kwargs["start"] = self.subset.index(subgraph.subset[1])

node.graph = self

elif node.target is ApplyModuleProtocol:

node.graph = self
Expand Down
29 changes: 21 additions & 8 deletions src/nnsight/intervention/protocols/grad.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict

import torch

from ...tracing.protocols import Protocol
from ...tracing.graph import GraphType

if TYPE_CHECKING:
from ..graph import InterventionNode, InterventionNodeType

Expand Down Expand Up @@ -36,11 +37,7 @@ def execute(cls, node: "InterventionNode") -> None:
hook = None

def grad(value):


# print(backwards_iteration, node)



# Set the value of the Node.
node.set_value(value)

Expand All @@ -57,4 +54,20 @@ def grad(value):
return value

# Register hook.
hook = tensor.register_hook(grad)
hook = tensor.register_hook(grad)

@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.
Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "green4", "shape": "box"}

return default_style


22 changes: 19 additions & 3 deletions src/nnsight/intervention/protocols/intervention.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import TYPE_CHECKING, Any, List

from typing import TYPE_CHECKING, Any, Dict

import torch
from ... import util
from .entrypoint import EntryPoint

if TYPE_CHECKING:
from ..graph import InterventionNodeType
from ..interleaver import Interleaver
from ..graph import InterventionNodeType, InterventionProxyType, InterventionGraph, InterventionProxy, InterventionNode

class InterventionProtocol(EntryPoint):

Expand Down Expand Up @@ -197,3 +196,20 @@ def narrow(acts: torch.Tensor):
def execute(cls, node: "InterventionNodeType"):
# To prevent the node from looking like its executed when calling Graph.execute
node.executed = False

@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.
Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "green4", "shape": "box"}
default_style["arg_kname"][0] = "module_path"
default_style["arg_kname"][1] = "batch_group"
default_style["arg_kname"][2] = "call_counter"

return default_style
24 changes: 18 additions & 6 deletions src/nnsight/intervention/protocols/module.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import TYPE_CHECKING, Any, Dict

import inspect
from typing import TYPE_CHECKING
import torch
from ...tracing.protocols import Protocol
from typing_extensions import Self
from ...tracing.graph import SubGraph

from ... import util
from ...tracing.protocols import Protocol

if TYPE_CHECKING:

from ..graph import InterventionProxyType, InterventionNode, InterventionGraph
from ..graph import InterventionGraph, InterventionNode

class ApplyModuleProtocol(Protocol):
"""Protocol that references some root model, and calls its .forward() method given some input.
Expand Down Expand Up @@ -92,3 +91,16 @@ def execute(cls, node: "InterventionNode") -> None:

node.set_value(output)

@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.
Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "green4", "shape": "polygon", "sides": 6}

return default_style
21 changes: 16 additions & 5 deletions src/nnsight/intervention/protocols/swap.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict

import torch
from ...tracing.protocols import Protocol
from ... import util

if TYPE_CHECKING:
from ..graph import InterventionNodeType, InterventionGraph, InterventionProxyType
from ..graph import InterventionNodeType


class SwapProtocol(Protocol):
Expand All @@ -22,4 +21,16 @@ def execute(cls, node: "InterventionNodeType") -> None:

intervention_node.kwargs['swap'] = value


@classmethod
def style(cls) -> Dict[str, Any]:
"""Visualization style for this protocol node.
Returns:
- Dict: dictionary style.
"""

default_style = super().style()

default_style["node"] = {"color": "green4", "shape": "ellipse"}

return default_style
9 changes: 7 additions & 2 deletions src/nnsight/tracing/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ..graph import Graph, Proxy
from ..protocols import StopProtocol


class Backend:

def __call__(self, graph: Graph) -> None:
Expand All @@ -25,7 +24,13 @@ def __call__(self, graph: Graph) -> None:
graph.nodes[-1].execute()

if self.injection:
frame = inspect.currentframe().f_back.f_back.f_back.f_back

from ..contexts import Context

frame = inspect.currentframe().f_back
while frame.f_back is not None and 'self' in frame.f_locals and isinstance(frame.f_locals['self'], Context):
frame = frame.f_back

for key, value in frame.f_locals.items():
if isinstance(value, Proxy) and value.node.done:
frame.f_locals[key] = value.value
Expand Down
9 changes: 6 additions & 3 deletions src/nnsight/tracing/contexts/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

from contextlib import AbstractContextManager
from typing import Any, Callable, Generic, Optional, Type, Union
from typing import Generic, Optional, Type

from typing_extensions import Self

from ... import CONFIG
from ...tracing.graph import Node, NodeType, Proxy, ProxyType
from ..backends import Backend, ExecutionBackend
from ..graph import Graph, GraphType, SubGraph
from ..graph import Graph, GraphType, SubGraph, viz_graph
from ..protocols import Protocol
from ... import CONFIG

class Context(Protocol, AbstractContextManager, Generic[GraphType]):
"""A `Context` represents a scope (or slice) of a computation graph with specific logic for adding and executing nodes defined within it.
Expand Down Expand Up @@ -81,6 +81,9 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:

self.backend(graph)

def vis(self, *args, **kwargs):
viz_graph(self.graph, *args, **kwargs)

@classmethod
def execute(cls, node: NodeType):

Expand Down
Loading

0 comments on commit 93495f5

Please sign in to comment.