Skip to content

Commit

Permalink
ndif v2 pydantic updates
Browse files Browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Dec 27, 2023
1 parent 6cf0f64 commit 20ef1fe
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 33 deletions.
27 changes: 15 additions & 12 deletions src/nnsight/contexts/Runner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import pickle

import socketio

from .. import CONFIG, pydantics
from ..logger import logger
from .Invoker import Invoker
from .Tracer import Tracer
from ..logger import logger


class Runner(Tracer):
"""The Runner object manages the intervention tracing for a given model's _generation method or _run_local method.
Expand Down Expand Up @@ -79,7 +78,7 @@ def run_local(self):

def run_server(self):
# Create the pydantic class for the request.

request = pydantics.RequestModel(
args=self.args,
kwargs=self.kwargs,
Expand All @@ -98,30 +97,34 @@ def blocking_request(self, request: pydantics.RequestModel):
# Create a socketio connection to the server.
sio = socketio.Client(logger=logger)

sio.connect(f"wss://{CONFIG.API.HOST}", socketio_path="/ws/socket.io", transports=["websocket"])
sio.connect(
f"wss://{CONFIG.API.HOST}",
socketio_path="/ws/socket.io",
transports=["websocket"],
)

# Called when receiving a response from the server.
@sio.on("blocking_response")
def blocking_response(data):
# Load the data into the ResponseModel pydantic class.
data: pydantics.ResponseModel = pickle.loads(data)
response = pydantics.ResponseModel(**data)

# Print response for user ( should be logger.info and have an info handler print to stdout)
print(str(data))
print(str(response))

# If the status of the response is completed, update the local nodes that the user specified to save.
# Then disconnect and continue.
if data.status.value == pydantics.ResponseModel.JobStatus.COMPLETED.value:
for name, value in data.saves.items():
if response.status == pydantics.ResponseModel.JobStatus.COMPLETED:
for name, value in response.saves.items():
self.graph.nodes[name].value = value

self.output = data.output
self.output = response.output

sio.disconnect()
# Or if there was some error.
elif data.status.value == pydantics.ResponseModel.JobStatus.ERROR.value:
elif response.status == pydantics.ResponseModel.JobStatus.ERROR:
sio.disconnect()

sio.emit(
"blocking_request",
request.model_dump(exclude_defaults=True, exclude_none=True),
Expand Down
29 changes: 8 additions & 21 deletions src/nnsight/pydantics/Response.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import pickle
from datetime import datetime
from enum import Enum
from typing import Any, Dict
from typing import Any, Union

import requests
from pydantic import BaseModel
from pydantic import BaseModel, field_validator


class ResponseModel(BaseModel):
Expand All @@ -22,9 +21,9 @@ class JobStatus(Enum):
status: JobStatus
description: str

output: Any = None
received: datetime = None
saves: Dict[str, Any] = None
saves: Union[bytes, Any] = None
output: Union[bytes, Any] = None
session_id: str = None
blocking: bool = False

Expand All @@ -39,19 +38,7 @@ def log(self, logger: logging.Logger) -> ResponseModel:

return self

def update_backend(self, client) -> ResponseModel:
responses_collection = client["ndif_database"]["responses"]

from bson.objectid import ObjectId

responses_collection.replace_one(
{"_id": ObjectId(self.id)}, {"bytes": pickle.dumps(self)}, upsert=True
)

return self

def blocking_response(self, api_url: str) -> ResponseModel:
if self.blocking:
requests.get(f"{api_url}/blocking_response/{self.id}")

return self
@field_validator("output", "saves")
@classmethod
def unpickle(cls, value):
return pickle.loads(value)

0 comments on commit 20ef1fe

Please sign in to comment.