Skip to content

Commit

Permalink
Catch exceptions in some context __exit__ methods and raise them so _…
Browse files Browse the repository at this point in the history
…_exit__ logit is not run
  • Loading branch information
JadenFiotto-Kaufman committed Nov 30, 2023
1 parent dca0226 commit 4c74345
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/nnsight/contexts/Invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def __enter__(self) -> Invoker:
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
pass
if isinstance(exc_val, BaseException):
raise exc_val

def next(self, increment: int = 1) -> None:
"""Designates subsequent interventions should be applied to the next generation for multi-iteration generation runs.
Expand Down
6 changes: 4 additions & 2 deletions src/nnsight/contexts/Runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Runner(Tracer):
def __init__(
self,
*args,
generation:bool = False,
generation: bool = False,
blocking: bool = True,
remote: bool = False,
**kwargs,
Expand All @@ -59,6 +59,8 @@ def __enter__(self) -> Runner:

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""On exit, run and generate using the model whether locally or on the server."""
if isinstance(exc_val, BaseException):
raise exc_val
if self.remote:
self.run_server()
else:
Expand All @@ -83,7 +85,7 @@ def run_server(self):
model_name=self.model.repoid_path_clsname,
batched_input=self.batched_input,
intervention_graph=self.graph,
generation=self.generation
generation=self.generation,
)

if self.blocking:
Expand Down

0 comments on commit 4c74345

Please sign in to comment.