Merge pull request #30 from JadenFiotto-Kaufman/dev
Dev
JadenFiotto-Kaufman authored Dec 22, 2023
2 parents ffd0d9e + fed3c36 commit f177ccb
Showing 7 changed files with 102 additions and 84 deletions.
38 changes: 33 additions & 5 deletions src/nnsight/__init__.py
@@ -60,19 +60,48 @@ def repeat_interleave(
     )
 
 
-def cpu_wrapper(fn):
+def noop_wrapper(fn):
     @wraps(fn)
-    def cpu(input: torch.Tensor, *args, **kwargs):
+    def noop(input: torch.Tensor, *args, **kwargs):
         if input.device.type == "meta":
             return input
 
         else:
             return fn(input, *args, **kwargs)
 
-    return cpu
+    return noop
 
 
-DEFAULT_PATCHER.add(Patch(torch.Tensor, cpu_wrapper(torch.Tensor.cpu), "cpu"))
+DEFAULT_PATCHER.add(Patch(torch.Tensor, noop_wrapper(torch.Tensor.cpu), "cpu"))
+
+def onehot_wrapper(fn):
+    @wraps(fn)
+    def onehot(input: torch.Tensor, num_classes=-1):
+        if input.device.type == "meta":
+            return torch.zeros((*input.shape, num_classes), device='meta')
+
+        else:
+            return fn(input, num_classes=num_classes)
+
+    return onehot
+
+
+DEFAULT_PATCHER.add(Patch(torch.nn.functional, onehot_wrapper(torch.nn.functional.one_hot), "one_hot"))
+
+def where_wrapper(fn):
+    @wraps(fn)
+    def where(input: torch.Tensor, *args, **kwargs):
+        if input.device.type == "meta":
+            return input.to(torch.int)
+
+        else:
+            return fn(input, *args, **kwargs)
+
+    return where
+
+DEFAULT_PATCHER.add(Patch(torch, where_wrapper(torch.where), "where"))
+
+DEFAULT_PATCHER.add(Patch(torch.Tensor, noop_wrapper(torch.Tensor.tolist), "tolist"))
 
 DEFAULT_PATCHER.__enter__()
 
@@ -95,5 +124,4 @@ def activate_recent_meta():
     def local_scalar_dense_meta(A):
         return 0
 
-
 activate_recent_meta()
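
For context on the hunk above: 'meta' tensors carry shape and dtype but no storage, so data-touching ops such as .cpu() or .tolist() raise on them. A minimal, standalone sketch of the no-op patching pattern — only noop_wrapper mirrors the diff; the rest is illustrative:

```python
# Standalone sketch of the no-op patching pattern; not nnsight's Patcher API.
from functools import wraps

import torch


def noop_wrapper(fn):
    @wraps(fn)
    def noop(input: torch.Tensor, *args, **kwargs):
        # Meta tensors have shape/dtype but no storage, so data-touching
        # ops like .cpu() or .tolist() would raise; short-circuit instead.
        if input.device.type == "meta":
            return input
        return fn(input, *args, **kwargs)

    return noop


patched_cpu = noop_wrapper(torch.Tensor.cpu)

real = torch.ones(2)
fake = torch.ones(2, device="meta")

print(patched_cpu(real))         # tensor([1., 1.]) -- real tensors unaffected
print(patched_cpu(fake).device)  # meta -- returned untouched instead of raising
```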
10 changes: 6 additions & 4 deletions src/nnsight/contexts/Runner.py
@@ -7,7 +7,7 @@
 from .. import CONFIG, pydantics
 from .Invoker import Invoker
 from .Tracer import Tracer
-
+from ..logger import logger
 
 class Runner(Tracer):
     """The Runner object manages the intervention tracing for a given model's _generation method or _run_local method.
@@ -79,6 +79,7 @@ def run_local(self):
 
     def run_server(self):
         # Create the pydantic class for the request.
+
         request = pydantics.RequestModel(
             args=self.args,
             kwargs=self.kwargs,
@@ -95,8 +96,9 @@ def run_server(self):
 
     def blocking_request(self, request: pydantics.RequestModel):
         # Create a socketio connection to the server.
-        sio = socketio.Client()
-        sio.connect(f"wss://{CONFIG.API.HOST}", transports=["websocket"])
+        sio = socketio.Client(logger=logger, reconnection_attempts=5)
+
+        sio.connect(f"wss://{CONFIG.API.HOST}", transports=["websocket"], wait_timeout=240)
 
         # Called when receiving a response from the server.
         @sio.on("blocking_response")
@@ -119,7 +121,7 @@ def blocking_response(data):
             # Or if there was some error.
             elif data.status == pydantics.JobStatus.ERROR:
                 sio.disconnect()
 
         sio.emit(
             "blocking_request",
             request.model_dump(exclude_defaults=True, exclude_none=True),
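
The Runner change hardens the client connection: a logger, a bounded number of reconnection attempts, and a longer handshake deadline. A hedged sketch of the same python-socketio calls in isolation — HOST is a placeholder; nnsight reads the real endpoint from CONFIG.API.HOST and passes its own logger object:

```python
# Sketch under stated assumptions; HOST and the handler body are illustrative.
import socketio

HOST = "example.com"  # placeholder endpoint

sio = socketio.Client(
    logger=True,              # nnsight passes its module logger here
    reconnection_attempts=5,  # stop retrying after five failed reconnects
)


@sio.on("blocking_response")
def blocking_response(data):
    print("server responded:", data)


# wait_timeout=240 extends the connection handshake deadline to four
# minutes, tolerating slow or queued servers.
sio.connect(f"wss://{HOST}", transports=["websocket"], wait_timeout=240)
```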
108 changes: 50 additions & 58 deletions src/nnsight/intervention.py
@@ -77,13 +77,6 @@ def save(self) -> InterventionProxy:
 
         return self
 
-    def retain_grad(self):
-        self.node.graph.add(target=torch.Tensor.retain_grad, args=[self.node])
-
-        # We need to set the values of self to values of self to add this into the computation graph so grad flows through it
-        # This is because in intervene(), we call .narrow on activations which removes it from the grad path
-        self[:] = self
-
     @property
     def token(self) -> TokenIndexer:
         """Property used to do token based indexing on a proxy.
@@ -137,52 +130,42 @@ def value(self) -> Any:
         return self.node.value
 
 
-def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int):
-    # If swap is populated due to a 'swp' intervention.
-    if graph.swap is not None:
-
-        def concat(values):
-            if isinstance(values[0], torch.Tensor):
-                return torch.concatenate(values)
-            elif isinstance(values[0], list) or isinstance(values[0], tuple):
-                return [
-                    concat([value[value_idx] for value in values])
-                    for value_idx in range(len(values[0]))
-                ]
-            elif isinstance(values[0], dict):
-                return {
-                    key: concat([value[key] for value in values])
-                    for key in values[0].keys()
-                }
-
-        # As interventions are scoped only to their relevant batch, if we want to swap in values for this batch
-        # we need to concatenate the batches before and after the relevant batch with the new values.
-        # Getting batch data before.
-        pre = util.apply(
-            activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor
-        )
-        post_batch_start = batch_start + batch_size
-        # Getting batch data after.
-        post = util.apply(
-            activations,
-            lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
-            torch.Tensor,
-        )
-
-        # Second argument of 'swp' interventions is the new value.
-        # Convert all Nodes in the value to their value.
-        value = util.apply(graph.swap.args[1], lambda x: x.value, Node)
-
-        # Concatenate
-        activations = concat([pre, value, post])
-
-        # Set value of 'swp' node so it destroys itself and listeners.
-        graph.swap.set_value(True)
-
-        # Un-set swap.
-        graph.swap = None
-
-    return activations
+def concat(activations: Any, value: Any, batch_start: int, batch_size: int):
+    def _concat(values):
+        if isinstance(values[0], torch.Tensor):
+            return torch.concatenate(values)
+        elif isinstance(values[0], list):
+            return [
+                _concat([value[value_idx] for value in values])
+                for value_idx in range(len(values[0]))
+            ]
+        elif isinstance(values[0], tuple):
+            return tuple(
+                [
+                    _concat([value[value_idx] for value in values])
+                    for value_idx in range(len(values[0]))
+                ]
+            )
+        elif isinstance(values[0], dict):
+            return {
+                key: _concat([value[key] for value in values])
+                for key in values[0].keys()
+            }
+
+    # As interventions are scoped only to their relevant batch, if we want to swap in values for this batch
+    # we need to concatenate the batches before and after the relevant batch with the new values.
+    # Getting batch data before.
+    pre = util.apply(activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor)
+    post_batch_start = batch_start + batch_size
+    # Getting batch data after.
+    post = util.apply(
+        activations,
+        lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
+        torch.Tensor,
+    )
+
+    # Concatenate
+    return _concat([pre, value, post])
 
 
 def intervene(activations: Any, module_path: str, graph: Graph, key: str):
@@ -223,17 +206,26 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
             _, batch_size, batch_start = node.args
 
             # We set its result to the activations, indexed by only the relevant batch idxs.
-            node.set_value(
-                util.apply(
-                    activations,
-                    lambda x: x.narrow(0, batch_start, batch_size),
-                    torch.Tensor,
-                )
-            )
+            value = util.apply(
+                activations,
+                lambda x: x.narrow(0, batch_start, batch_size),
+                torch.Tensor,
+            )
+
+            node.set_value(value)
 
             # Check if through the previous value injection, there was a 'swp' intervention.
             # This would mean we want to replace activations for this batch with some other ones.
-            activations = check_swap(graph, activations, batch_start, batch_size)
+            if graph.swap is not None:
+                value = util.apply(graph.swap.args[1], lambda x: x.value, Node)
+
+                # Set value of 'swp' node so it destroys itself and listeners.
+                graph.swap.set_value(True)
+
+                # Un-set swap.
+                graph.swap = None
+
+                activations = concat(activations, value, batch_start, batch_size)
 
     return activations
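
The refactor above splits the old check_swap into an explicit swap check inside intervene plus a standalone concat. A toy sketch of the splice that concat performs — keep the rows before and after this invocation's batch, recurse through nested containers, and concatenate the swapped-in value in between. Plain tensors and dicts stand in for nnsight's util.apply plumbing:

```python
# Toy restatement of the splice; inputs and shapes are made up for the demo.
import torch


def _concat(values):
    if isinstance(values[0], torch.Tensor):
        return torch.concatenate(values)
    elif isinstance(values[0], list):
        return [_concat([v[i] for v in values]) for i in range(len(values[0]))]
    elif isinstance(values[0], tuple):
        return tuple(_concat([v[i] for v in values]) for i in range(len(values[0])))
    elif isinstance(values[0], dict):
        return {k: _concat([v[k] for v in values]) for k in values[0]}


activations = {"hidden": torch.zeros(4, 2)}  # full batch of 4 rows
value = {"hidden": torch.ones(2, 2)}         # replacement for rows 1..2
batch_start, batch_size = 1, 2

pre = {"hidden": activations["hidden"].narrow(0, 0, batch_start)}
post = {"hidden": activations["hidden"].narrow(0, batch_start + batch_size, 1)}

print(_concat([pre, value, post])["hidden"])  # rows 1-2 are now ones
```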
21 changes: 5 additions & 16 deletions src/nnsight/pydantics/Request.py
@@ -17,7 +17,7 @@ class RequestModel(BaseModel):
     kwargs: Dict
     model_name: str
     batched_input: Union[Any, bytes]
-    intervention_graph: Union[Graph, bytes, Dict[str, NodeModel]]
+    intervention_graph: Union[Graph, bytes]
     generation: bool
     # Edits
     # altered
@@ -27,26 +27,15 @@ class RequestModel(BaseModel):
     blocking: bool = False
 
     @field_serializer("intervention_graph")
-    def intervention_graph_serialize(self, value: Union[str, Graph], _info) -> bytes:
-        if isinstance(value, Graph):
-            nodes = dict()
+    def intervention_graph_serialize(self, value: Graph, _info) -> bytes:
+        value.compile(None)
 
-            for node in value.nodes.values():
-                node = NodeModel.from_node(node)
-                nodes[node.name] = node
+        for node in value.nodes.values():
 
-            value = nodes
+            node.proxy_value = None
 
         return pickle.dumps(value)
 
     @field_serializer("batched_input")
     def serialize(self, value, _info) -> bytes:
         return pickle.dumps(value)
-
-    def graph(self):
-        graph = Graph(None)
-
-        for node in self.intervention_graph.values():
-            NodeModel.to_node(graph, self.intervention_graph, node)
-
-        return graph
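
The serializer now pickles the compiled Graph directly (relying on the Proxy `__getstate__`/`__setstate__` added below) instead of translating it into NodeModel entries, nulling each node's proxy_value first. A toy sketch with stand-in classes — nothing here is nnsight's real API:

```python
# Stand-in classes to show the shape of the new serialization path.
import pickle


class Node:
    def __init__(self, name):
        self.name = name
        self.proxy_value = object()  # stand-in for a fake/meta tensor


class Graph:
    def __init__(self, nodes):
        self.nodes = {node.name: node for node in nodes}

    def compile(self, module):
        pass  # the real Graph resets per-node execution state here


graph = Graph([Node("a"), Node("b")])

graph.compile(None)
for node in graph.nodes.values():
    node.proxy_value = None  # dead weight the server will recompute anyway

payload = pickle.dumps(graph)
print(len(payload), "bytes ready to send")
```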
1 change: 0 additions & 1 deletion src/nnsight/pydantics/tracing.py
@@ -42,7 +42,6 @@ def _dereference(reference: NodeModel.Reference):
         return graph.nodes[node_model.name]
 
     graph.add(
-        value=None,
         target=node_model.target,
         args=args,
         kwargs=kwargs,
1 change: 1 addition & 0 deletions src/nnsight/tracing/Node.py
@@ -130,6 +130,7 @@ def compile(self) -> None:
         self.remaining_listeners = len(self.listeners)
         self.remaining_dependencies = len(self.dependencies)
         self.value = inspect._empty
+        self.meta = dict()
 
     def fulfilled(self) -> bool:
         """Returns true if remaining_dependencies is 0.
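
The hunk above adds one line to Node.compile: alongside the listener and dependency counters, each node now also gets a fresh meta dict when the graph is (re)compiled. A toy restatement of the reset — field names mirror the diff; the surrounding Node machinery is elided:

```python
# Minimal sketch of compile() resetting per-execution bookkeeping.
import inspect


class Node:
    def __init__(self, listeners, dependencies):
        self.listeners = listeners
        self.dependencies = dependencies

    def compile(self) -> None:
        self.remaining_listeners = len(self.listeners)
        self.remaining_dependencies = len(self.dependencies)
        self.value = inspect._empty  # sentinel for "not yet computed"
        self.meta = dict()           # the newly added per-node scratch dict


node = Node(listeners=["a", "b"], dependencies=["c"])
node.compile()
print(node.remaining_listeners, node.remaining_dependencies, node.meta)  # 2 1 {}
```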
7 changes: 7 additions & 0 deletions src/nnsight/tracing/Proxy.py
@@ -27,6 +27,12 @@ def proxy_call(callable: Callable, *args, **kwargs) -> None:
     def __init__(self, node: "Node") -> None:
         self.node = node
 
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, d: dict):
+        self.__dict__ = d
+
     def __call__(self, *args, **kwargs) -> Proxy:
         """
         Calling a Proxy object normally just creates a Proxy.proxy_call operation. However if this call is a method on the root module proxy, it's assumed that one wishes to trace into the method and therefore trace all operations inside it.
@@ -39,6 +45,7 @@ def __call__(self, *args, **kwargs) -> Proxy:
         if self.node.args[0] is self.node.graph.module_proxy.node and not isinstance(
             self.node.proxy_value, torch.nn.Module
         ):
+
             value = self.node.proxy_value.__func__(
                 self.node.graph.module_proxy, *args, **kwargs
             )
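
For context on the new `__getstate__`/`__setstate__`: pickle probes instances for optional hooks via getattr, and a class that intercepts attribute lookup can answer those probes with a bogus object, breaking serialization. Pinning the two methods keeps pickling a plain `__dict__` round-trip. A toy illustration — TracingProxy is hypothetical; nnsight's Proxy records graph operations rather than returning strings:

```python
# Hypothetical proxy showing why explicit pickle hooks matter for
# __getattr__-heavy classes.
import pickle


class TracingProxy:
    def __init__(self, name):
        self.__dict__["name"] = name

    def __getattr__(self, attr):
        # A real tracing proxy would record a new graph operation here.
        return f"<operation {self.name}.{attr}>"

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d: dict):
        self.__dict__ = d


proxy = TracingProxy("layer1")
clone = pickle.loads(pickle.dumps(proxy))

print(clone.name)    # layer1 -- state survived the round-trip
print(clone.output)  # <operation layer1.output> -- __getattr__ still works
```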
