Skip to content

Commit

Permalink
Merge pull request #18 from JadenFiotto-Kaufman/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
JadenFiotto-Kaufman authored Dec 16, 2023
2 parents 4bf35b2 + e2d0ce0 commit 9f9f86c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
21 changes: 13 additions & 8 deletions src/nnsight/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def value(self) -> Any:


def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int):
# If swap is populated due to a 'swp' intervention.
if graph.swap is not None:

def concat(values):
Expand All @@ -154,27 +155,31 @@ def concat(values):
for key in values[0].keys()
}

# As interventions are scoped only to their relevant batch, if we want to swap in values for this batch
# we need to concatenate the batches before and after the relevant batch with the new values.
# Getting batch data before.
pre = util.apply(
activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor
)
post_batch_start = batch_start + batch_size
# Getting batch data after.
post = util.apply(
activations,
lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
torch.Tensor,
)

def get_value(node: Node):
value = node.value

node.set_value(True)

return value

value = util.apply(graph.swap, get_value, Node)
# Second argument of 'swp' interventions is the new value.
# Convert all Nodes in the value to their value.
value = util.apply(graph.swap.args[1], lambda x: x.value, Node)

# Concatenate
activations = concat([pre, value, post])

# Set value of 'swp' node so it destroys itself and listeners.
graph.swap.set_value(True)

# Un-set swap.
graph.swap = None

return activations
Expand Down
4 changes: 2 additions & 2 deletions src/nnsight/tracing/Graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Graph:
module_proxy (Proxy): Proxy for given root meta module.
argument_node_names (Dict[str, List[str]]): Map of name of argument to name of nodes that depend on it.
generation_idx (int): Current generation index.
swap (Any): Attribute to store swap values from 'swp' nodes.
swap (Node): Attribute to store swap values from 'swp' nodes.
"""

@staticmethod
Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(

self.generation_idx = 0

self.swap: Any = None
self.swap: Node = None

def increment(self) -> None:
"""Increments the generation_idx by one. Should be called by a forward hook on the model being used for generation."""
Expand Down
4 changes: 2 additions & 2 deletions src/nnsight/tracing/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def execute(self) -> None:
if self.target == "null":
return
elif self.target == "swp":
self.graph.swap = self.args[1]
self.graph.swap = self

return

Expand Down Expand Up @@ -230,7 +230,7 @@ def set_value(self, value: Any):
if dependency.redundant():
dependency.destroy()

if self.value is not None and self.redundant():
if self.value is not inspect._empty and self.redundant():
self.destroy()

def destroy(self) -> None:
Expand Down

0 comments on commit 9f9f86c

Please sign in to comment.