Commit e33b2e6
We need to always concatenate the intervened/narrowed values in order to have the narrowed value included in the backward computation graph, therefore allowing retain_grad() to save the grad.
JadenFiotto-Kaufman committed Dec 21, 2023
1 parent 31143b9 commit e33b2e6
Showing 1 changed file with 50 additions and 58 deletions.
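
The reasoning, restated as a minimal PyTorch sketch (made-up tensors, not nnsight code): a narrowed slice of an activation only receives a gradient for retain_grad() to save if the slice is spliced back into the tensor the rest of the forward pass actually consumes.

import torch

activations = torch.randn(3, 2, requires_grad=True)

# intervene() hands out a narrowed, per-batch slice of the module output.
narrowed = activations.narrow(0, 1, 1)
narrowed.retain_grad()

# If the forward pass kept consuming `activations` directly, `narrowed` would
# be a dead branch: backward would never reach it and narrowed.grad would stay
# None. Concatenating the slice back in puts it on the path to the loss:
pre = activations.narrow(0, 0, 1)
post = activations.narrow(0, 2, 1)
spliced = torch.concatenate([pre, narrowed, post])

spliced.sum().backward()
print(narrowed.grad)  # tensor([[1., 1.]]) -- the slice participated in backward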
108 changes: 50 additions & 58 deletions src/nnsight/intervention.py
@@ -77,13 +77,6 @@ def save(self) -> InterventionProxy:
 
         return self
 
-    def retain_grad(self):
-        self.node.graph.add(target=torch.Tensor.retain_grad, args=[self.node])
-
-        # We need to set the values of self to values of self to add this into the computation graph so grad flows through it
-        # This is because in intervene(), we call .narrow on activations which removes it from the grad path
-        self[:] = self
-
     @property
     def token(self) -> TokenIndexer:
         """Property used to do token based indexing on a proxy.
@@ -137,52 +130,42 @@ def value(self) -> Any:
         return self.node.value
 
 
-def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int):
-    # If swap is populated due to a 'swp' intervention.
-    if graph.swap is not None:
-
-        def concat(values):
-            if isinstance(values[0], torch.Tensor):
-                return torch.concatenate(values)
-            elif isinstance(values[0], list) or isinstance(values[0], tuple):
-                return [
-                    concat([value[value_idx] for value in values])
-                    for value_idx in range(len(values[0]))
-                ]
-            elif isinstance(values[0], dict):
-                return {
-                    key: concat([value[key] for value in values])
-                    for key in values[0].keys()
-                }
-
-        # As interventions are scoped only to their relevant batch, if we want to swap in values for this batch
-        # we need to concatenate the batches before and after the relevant batch with the new values.
-        # Getting batch data before.
-        pre = util.apply(
-            activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor
-        )
-        post_batch_start = batch_start + batch_size
-        # Getting batch data after.
-        post = util.apply(
-            activations,
-            lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
-            torch.Tensor,
-        )
-
-        # Second argument of 'swp' interventions is the new value.
-        # Convert all Nodes in the value to their value.
-        value = util.apply(graph.swap.args[1], lambda x: x.value, Node)
-
-        # Concatenate
-        activations = concat([pre, value, post])
-
-        # Set value of 'swp' node so it destroys itself and listeners.
-        graph.swap.set_value(True)
-
-        # Un-set swap.
-        graph.swap = None
-
-    return activations
+def concat(activations: Any, value: Any, batch_start: int, batch_size: int):
+    def _concat(values):
+        if isinstance(values[0], torch.Tensor):
+            return torch.concatenate(values)
+        elif isinstance(values[0], list):
+            return [
+                _concat([value[value_idx] for value in values])
+                for value_idx in range(len(values[0]))
+            ]
+        elif isinstance(values[0], tuple):
+            return tuple(
+                [
+                    _concat([value[value_idx] for value in values])
+                    for value_idx in range(len(values[0]))
+                ]
+            )
+        elif isinstance(values[0], dict):
+            return {
+                key: _concat([value[key] for value in values])
+                for key in values[0].keys()
+            }
+
+    # As interventions are scoped only to their relevant batch, if we want to swap in values for this batch
+    # we need to concatenate the batches before and after the relevant batch with the new values.
+    # Getting batch data before.
+    pre = util.apply(activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor)
+    post_batch_start = batch_start + batch_size
+    # Getting batch data after.
+    post = util.apply(
+        activations,
+        lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
+        torch.Tensor,
+    )
+
+    # Concatenate
+    return _concat([pre, value, post])
 
 
 def intervene(activations: Any, module_path: str, graph: Graph, key: str):
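
Both versions lean on nnsight's util.apply. Here is a minimal sketch of the behavior assumed above, inferred from the call sites: a structure-preserving map that applies a function to every instance of a type inside nested lists, tuples, and dicts. This is a hypothetical reimplementation, not nnsight's actual code:

from typing import Any, Callable, Type

import torch

def apply(data: Any, fn: Callable, cls: Type) -> Any:
    # Apply fn to every instance of cls found in a nested structure,
    # rebuilding lists, tuples, and dicts around the results.
    if isinstance(data, cls):
        return fn(data)
    if isinstance(data, list):
        return [apply(d, fn, cls) for d in data]
    if isinstance(data, tuple):
        return tuple(apply(d, fn, cls) for d in data)
    if isinstance(data, dict):
        return {key: apply(d, fn, cls) for key, d in data.items()}
    return data

# e.g. narrowing every tensor in a (hidden, cache) output to batch rows 1..2:
out = (torch.zeros(4, 3), {"k": torch.zeros(4, 5)})
narrowed = apply(out, lambda x: x.narrow(0, 1, 2), torch.Tensor)
print(narrowed[0].shape, narrowed[1]["k"].shape)  # (2, 3) and (2, 5)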
Expand Down Expand Up @@ -223,17 +206,26 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
_, batch_size, batch_start = node.args

# We set its result to the activations, indexed by only the relevant batch idxs.
node.set_value(
util.apply(
activations,
lambda x: x.narrow(0, batch_start, batch_size),
torch.Tensor,
)
value = util.apply(
activations,
lambda x: x.narrow(0, batch_start, batch_size),
torch.Tensor,
)

node.set_value(value)

# Check if through the previous value injection, there was a 'swp' intervention.
# This would mean we want to replace activations for this batch with some other ones.
activations = check_swap(graph, activations, batch_start, batch_size)
if graph.swap is not None:
value = util.apply(graph.swap.args[1], lambda x: x.value, Node)

# Set value of 'swp' node so it destroys itself and listeners.
graph.swap.set_value(True)

# Un-set swap.
graph.swap = None

activations = concat(activations, value, batch_start, batch_size)

return activations
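
With the new control flow, the value spliced back in is the narrowed slice by default and the swap value only when a 'swp' intervention is pending; either way concat runs, which is what keeps the slice in the backward graph. A usage sketch of the new function (assuming it is importable as nnsight.intervention.concat; the tensors are made up):

import torch

from nnsight.intervention import concat

# A module output for a batch of 4: a (hidden, cache) tuple with a dict inside.
activations = (torch.zeros(4, 3), {"key": torch.zeros(4, 5)})

# A replacement value for batch rows 1..2, e.g. from a swap (or just the
# narrowed slice itself in the no-swap case).
value = (torch.ones(2, 3), {"key": torch.ones(2, 5)})

out = concat(activations, value, batch_start=1, batch_size=2)

print(out[0][1:3])          # the spliced-in ones
print(out[1]["key"].shape)  # torch.Size([4, 5]) -- nested structure preserved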

