Updates for interaction with NDIF server
JadenFiotto-Kaufman committed Dec 22, 2023
1 parent f6c078a commit dc05d13
Showing 5 changed files with 20 additions and 21 deletions.
10 changes: 6 additions & 4 deletions src/nnsight/contexts/Runner.py
@@ -7,7 +7,7 @@
from .. import CONFIG, pydantics
from .Invoker import Invoker
from .Tracer import Tracer

from ..logger import logger

class Runner(Tracer):
"""The Runner object manages the intervention tracing for a given model's _generation method or _run_local method.
@@ -79,6 +79,7 @@ def run_local(self):

def run_server(self):
# Create the pydantic class for the request.

request = pydantics.RequestModel(
args=self.args,
kwargs=self.kwargs,
@@ -95,8 +96,9 @@ def run_server(self):

def blocking_request(self, request: pydantics.RequestModel):
# Create a socketio connection to the server.
sio = socketio.Client()
sio.connect(f"wss://{CONFIG.API.HOST}", transports=["websocket"])
sio = socketio.Client(logger=logger, reconnection_attempts=5)

sio.connect(f"wss://{CONFIG.API.HOST}", transports=["websocket"], wait_timeout=240)

# Called when receiving a response from the server.
@sio.on("blocking_response")
@@ -119,7 +121,7 @@ def blocking_response(data):
# Or if there was some error.
elif data.status == pydantics.JobStatus.ERROR:
sio.disconnect()

sio.emit(
"blocking_request",
request.model_dump(exclude_defaults=True, exclude_none=True),
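The blocking_request changes above configure the socketio client with the package logger, a bounded number of reconnection attempts, and a longer connection wait. The snippet below is a minimal, self-contained sketch of that blocking round-trip using the python-socketio client; the host string, payload, and status values are placeholders, not values taken from the nnsight codebase.

import socketio

HOST = "example-ndif-host"  # placeholder for CONFIG.API.HOST

sio = socketio.Client(reconnection_attempts=5)

@sio.on("blocking_response")
def blocking_response(data):
    # In the real Runner, the payload is deserialized into a response model and
    # the connection is closed once the job reports COMPLETED or ERROR.
    print("server says:", data)
    if isinstance(data, dict) and data.get("status") in ("COMPLETED", "ERROR"):
        sio.disconnect()

# A generous wait_timeout gives the server time to accept the websocket upgrade.
sio.connect(f"wss://{HOST}", transports=["websocket"], wait_timeout=240)
sio.emit("blocking_request", {"payload": "..."})
sio.wait()  # block until the handler above disconnects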
21 changes: 5 additions & 16 deletions src/nnsight/pydantics/Request.py
@@ -17,7 +17,7 @@ class RequestModel(BaseModel):
kwargs: Dict
model_name: str
batched_input: Union[Any, bytes]
intervention_graph: Union[Graph, bytes, Dict[str, NodeModel]]
intervention_graph: Union[Graph, bytes]
generation: bool
# Edits
# altered
@@ -27,26 +27,15 @@ class RequestModel(BaseModel):
blocking: bool = False

@field_serializer("intervention_graph")
def intervention_graph_serialize(self, value: Union[str, Graph], _info) -> bytes:
if isinstance(value, Graph):
nodes = dict()
def intervention_graph_serialize(self, value: Graph, _info) -> bytes:
value.compile(None)

for node in value.nodes.values():
node = NodeModel.from_node(node)
nodes[node.name] = node
for node in value.nodes.values():

value = nodes
node.proxy_value = None

return pickle.dumps(value)

@field_serializer("batched_input")
def serialize(self, value, _info) -> bytes:
return pickle.dumps(value)

def graph(self):
graph = Graph(None)

for node in self.intervention_graph.values():
NodeModel.to_node(graph, self.intervention_graph, node)

return graph
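The serializer change above pickles the compiled Graph directly instead of rebuilding it from a dictionary of NodeModel entries, clearing each node's proxy_value first so the payload stays picklable. Below is a minimal sketch of that pattern with pydantic v2's field_serializer; RequestSketch and the stand-in Graph class are illustrative, not the real nnsight types.

import pickle
from typing import Union

from pydantic import BaseModel, field_serializer


class Graph:  # stand-in for nnsight.tracing.Graph.Graph
    def __init__(self):
        self.nodes = {}


class RequestSketch(BaseModel):
    model_config = {"arbitrary_types_allowed": True}

    intervention_graph: Union[Graph, bytes]

    @field_serializer("intervention_graph")
    def intervention_graph_serialize(self, value: Graph, _info) -> bytes:
        # Strip anything unpicklable here (the real code clears node.proxy_value),
        # then ship the whole graph as bytes.
        return pickle.dumps(value)


request = RequestSketch(intervention_graph=Graph())
payload = request.model_dump()               # intervention_graph is now bytes
graph = pickle.loads(payload["intervention_graph"])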
1 change: 0 additions & 1 deletion src/nnsight/pydantics/tracing.py
@@ -42,7 +42,6 @@ def _dereference(reference: NodeModel.Reference):
return graph.nodes[node_model.name]

graph.add(
value=None,
target=node_model.target,
args=args,
kwargs=kwargs,
1 change: 1 addition & 0 deletions src/nnsight/tracing/Node.py
@@ -130,6 +130,7 @@ def compile(self) -> None:
self.remaining_listeners = len(self.listeners)
self.remaining_dependencies = len(self.dependencies)
self.value = inspect._empty
self.meta = dict()

def fulfilled(self) -> bool:
"""Returns true if remaining_dependencies is 0.
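Node.compile() now also resets a per-node meta dictionary alongside the listener and dependency counters, so the same graph can be executed more than once. A rough sketch of that reset pattern, using a simplified stand-in class rather than the real Node:

import inspect


class NodeSketch:
    def __init__(self, listeners, dependencies):
        self.listeners = listeners
        self.dependencies = dependencies
        self.compile()

    def compile(self) -> None:
        # Reset per-execution bookkeeping before each run of the graph.
        self.remaining_listeners = len(self.listeners)
        self.remaining_dependencies = len(self.dependencies)
        self.value = inspect._empty   # sentinel for "no value yet"
        self.meta = dict()            # scratch space cleared on every run

    def fulfilled(self) -> bool:
        return self.remaining_dependencies == 0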
8 changes: 8 additions & 0 deletions src/nnsight/tracing/Proxy.py
@@ -27,6 +27,12 @@ def proxy_call(callable: Callable, *args, **kwargs) -> None:
def __init__(self, node: "Node") -> None:
self.node = node

def __getstate__(self):
return self.__dict__

def __setstate__(self, d: dict):
self.__dict__ = d

def __call__(self, *args, **kwargs) -> Proxy:
"""
Calling a Proxy object normally just creates a Proxy.proxy_call operation. However if this call is a method on the root module proxy, it's assumed that one wishes to trace into the method and therefore trace all operations inside it.
@@ -39,6 +45,7 @@ def __call__(self, *args, **kwargs) -> Proxy:
if self.node.args[0] is self.node.graph.module_proxy.node and not isinstance(
self.node.proxy_value, torch.nn.Module
):

value = self.node.proxy_value.__func__(
self.node.graph.module_proxy, *args, **kwargs
)
@@ -65,6 +72,7 @@ def __setitem__(self, key: Union[Proxy, Any], value: Union[Proxy, Any]) -> None:
)

def __getattr__(self, key: Union[Proxy, Any]) -> Proxy:
breakpoint()
return self.node.graph.add(
target=util.fetch_attr,
args=[self.node, key],
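The new __getstate__/__setstate__ hooks let Proxy objects pickle cleanly. A plausible motivation, sketched below with a simplified stand-in class: Proxy overrides __getattr__ to record graph operations, and during unpickling the object's __dict__ starts out empty, so letting pickle probe attributes through that override can misbehave; restoring __dict__ directly sidesteps the custom lookup entirely.

import pickle


class ProxySketch:
    def __init__(self, node):
        self.node = node

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d: dict):
        # Restore state without routing through the __getattr__ override below.
        self.__dict__ = d

    def __getattr__(self, key):
        # Stand-in for the real Proxy, which adds an attribute-access node to
        # its graph and returns a new Proxy.
        return ("getattr", self.node, key)


p = pickle.loads(pickle.dumps(ProxySketch(node="node_0")))
print(p.node)  # "node_0" — found in __dict__, so __getattr__ is never invoked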
