Skip to content

Commit

Permalink
Merge pull request #293 from ndif-team/vllm-tp-2
Browse files Browse the repository at this point in the history
VLLM w/ Tensor Parallelism
  • Loading branch information
JadenFiotto-Kaufman authored Dec 5, 2024
2 parents 95854a0 + 93495f5 commit 9290dd0
Show file tree
Hide file tree
Showing 11 changed files with 494 additions and 272 deletions.
2 changes: 1 addition & 1 deletion src/nnsight/intervention/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union)

import torch
from typing_extensions import Self
Expand Down
107 changes: 82 additions & 25 deletions src/nnsight/intervention/graph/graph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import copy
import sys
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

from typing_extensions import Self

from ...tracing.contexts import Context
from ...tracing.graph import SubGraph
from ...util import NNsightError
from ..protocols import ApplyModuleProtocol, GradProtocol, InterventionProtocol
from . import InterventionNode, InterventionNodeType, InterventionProxyType
from ...util import NNsightError

if TYPE_CHECKING:
from .. import NNsight
from ..tracing.graph.graph import GraphType, NodeType


class InterventionGraph(SubGraph[InterventionNode, InterventionProxyType]):
Expand All @@ -33,15 +37,32 @@ def __init__(
) -> None:

super().__init__(*args, **kwargs)

self.model = model

self.interventions: Dict[str, List[InterventionNode]] = defaultdict(list)
self.grad_subgraph: Set[int] = set()

self.compiled = False
self.call_counter: Dict[int, int] = defaultdict(int)
self.deferred:Dict[int, List[int]] = defaultdict(list)
self.deferred: Dict[int, List[int]] = defaultdict(list)

def __getstate__(self) -> Dict:

return {
"subset": self.subset,
"nodes": self.nodes,
"interventions": self.interventions,
"compiled": self.compiled,
"call_counter": self.call_counter,
"deferred": self.deferred,
"grad_subgraph": self.grad_subgraph,
"defer_stack": self.defer_stack,
}

def __setstate__(self, state: Dict) -> None:

self.__dict__.update(state)

def reset(self) -> None:
self.call_counter = defaultdict(int)
Expand Down Expand Up @@ -89,12 +110,11 @@ def compile(self) -> Optional[Dict[str, List[InterventionNode]]]:
return self.interventions

if len(self.nodes) == 1:
return
return

intervention_subgraphs: List[SubGraph] = []

start = self[0].index

# is the first node corresponding to an executable graph?
# occurs when a Conditional or Iterator context is explicitly entered by a user
if isinstance(self[0].target, type) and issubclass(
Expand All @@ -119,11 +139,11 @@ def compile(self) -> Optional[Dict[str, List[InterventionNode]]]:

# is this node part of an inner context's subgraph?
if context_node is None and node.graph is not self:

context_node = self.nodes[node.graph[-1].index + 1]

context_start = self.subset.index(context_node.index)

defer_start = node.index

self.context_dependency(context_node, intervention_subgraphs)
Expand Down Expand Up @@ -197,18 +217,24 @@ def compile(self) -> Optional[Dict[str, List[InterventionNode]]]:

self.compiled = True

def execute(self, start: int = 0, grad: bool = False, defer:bool=False, defer_start:int=0) -> None:

def execute(
self,
start: int = 0,
grad: bool = False,
defer: bool = False,
defer_start: int = 0,
) -> None:

err: Tuple[int, NNsightError] = None

if defer_start in self.deferred:

for index in self.deferred[defer_start]:

self.nodes[index].reset()

del self.deferred[defer_start]

if defer:

self.defer_stack.append(defer_start)
Expand All @@ -217,7 +243,9 @@ def execute(self, start: int = 0, grad: bool = False, defer:bool=False, defer_st

if node.executed:
continue
elif node.index != self[start].index and node.target is InterventionProtocol:
elif (
node.index != self[start].index and node.target is InterventionProtocol
):
break
elif node.fulfilled:
try:
Expand All @@ -231,7 +259,7 @@ def execute(self, start: int = 0, grad: bool = False, defer:bool=False, defer_st
continue
else:
break

if defer:
self.defer_stack.pop()

Expand All @@ -242,9 +270,7 @@ def execute(self, start: int = 0, grad: bool = False, defer:bool=False, defer_st
self.defer_stack = defer_stack
raise err[1]

def count(
self, index: int, iteration: Union[int, List[int], slice]
) -> bool:
def count(self, index: int, iteration: Union[int, List[int], slice]) -> bool:
"""Increments the count of times a given Intervention Node has tried to be executed and returns if the Node is ready and if it needs to be deferred.
Args:
Expand Down Expand Up @@ -286,7 +312,7 @@ def count(
self.call_counter[index] += 1

return ready, defer

def clean(self, start: Optional[int] = None):

if start is None:
Expand All @@ -296,15 +322,14 @@ def clean(self, start: Optional[int] = None):

# Loop over ALL nodes within the span of this graph.
for index in range(start, end):

node = self.nodes[index]

if node.executed:
break

node.update_dependencies()



def cleanup(self) -> None:
"""Because some modules may be executed more than once, and to accommodate memory management just like a loop,
intervention graph sections defer updating the remaining listeners of Nodes if this is not the last time this section will be executed.
Expand All @@ -330,6 +355,38 @@ def cleanup(self) -> None:
if dependency.redundant:
dependency.destroy()

def copy(
self,
new_graph: Self = None,
parent: Optional["GraphType"] = None,
memo: Optional[Dict[int, "NodeType"]] = None,
) -> Self:

if memo is None:
memo = {}

new_graph = super().copy(new_graph, parent=parent, memo=memo)

new_graph.compiled = self.compiled

for key, value in self.call_counter.items():
self.call_counter[memo[key]] = value

if new_graph.compiled:

for module_path, list_of_nodes in self.interventions.items():

new_graph.interventions[module_path] = [
new_graph.nodes[memo[node.index]] for node in list_of_nodes
]

for key, values in self.deferred.items():

new_graph[memo[key]] = [memo[index] for index in values]

new_graph.grad_subgraph = [memo[index] for index in self.grad_subgraph]

return new_graph

# @classmethod
# def shift(cls, mgraph: MultiGraph) -> MultiGraph:
Expand Down
12 changes: 6 additions & 6 deletions src/nnsight/intervention/interleaver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ def __init__(

if input_hook is None:
input_hook = (
lambda activations, module_path: InterventionProtocol.intervene(
activations, module_path, "input", self
lambda activations, module_path, module: InterventionProtocol.intervene(
activations, module_path, module, "input", self
)
)

if output_hook is None:
output_hook = (
lambda activations, module_path: InterventionProtocol.intervene(
activations, module_path, "output", self
lambda activations, module_path, module: InterventionProtocol.intervene(
activations, module_path, module, "output", self
)
)

Expand Down Expand Up @@ -96,7 +96,7 @@ def __enter__(self) -> Interleaver:
# Input hook activations are a tuple of (positional args, key-word arguments)
# Include the module_path not the module
def input_hook(module, input, kwargs, module_path=module_path):
return self.input_hook((input, kwargs), module_path)
return self.input_hook((input, kwargs), module_path, module)

self.handles.append(
module.register_forward_pre_hook(
Expand All @@ -107,7 +107,7 @@ def input_hook(module, input, kwargs, module_path=module_path):
elif hook_type == "output":

def output_hook(module, input, output, module_path=module_path):
return self.output_hook(output, module_path)
return self.output_hook(output, module_path, module)

self.handles.append(
module.register_forward_hook(output_hook, prepend=True)
Expand Down
1 change: 1 addition & 0 deletions src/nnsight/intervention/protocols/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def intervene(
cls,
activations: Any,
module_path: str,
module: torch.nn.Module,
key: str,
interleaver: "Interleaver",
):
Expand Down
Loading

0 comments on commit 9290dd0

Please sign in to comment.