From 4c74345d163215266824a0856cd5375d1041392c Mon Sep 17 00:00:00 2001 From: JadenFiottoKaufman Date: Wed, 29 Nov 2023 23:31:25 -0500 Subject: [PATCH] Catch exceptions in some context __exit__ methods and raise them so __exit__ logit is not run --- src/nnsight/contexts/Invoker.py | 3 ++- src/nnsight/contexts/Runner.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/nnsight/contexts/Invoker.py b/src/nnsight/contexts/Invoker.py index 84899dd9..2e16ff8b 100644 --- a/src/nnsight/contexts/Invoker.py +++ b/src/nnsight/contexts/Invoker.py @@ -72,7 +72,8 @@ def __enter__(self) -> Invoker: return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: - pass + if isinstance(exc_val, BaseException): + raise exc_val def next(self, increment: int = 1) -> None: """Designates subsequent interventions should be applied to the next generation for multi-iteration generation runs. diff --git a/src/nnsight/contexts/Runner.py b/src/nnsight/contexts/Runner.py index 449c38dc..65717aa9 100644 --- a/src/nnsight/contexts/Runner.py +++ b/src/nnsight/contexts/Runner.py @@ -43,7 +43,7 @@ class Runner(Tracer): def __init__( self, *args, - generation:bool = False, + generation: bool = False, blocking: bool = True, remote: bool = False, **kwargs, @@ -59,6 +59,8 @@ def __enter__(self) -> Runner: def __exit__(self, exc_type, exc_val, exc_tb) -> None: """On exit, run and generate using the model whether locally or on the server.""" + if isinstance(exc_val, BaseException): + raise exc_val if self.remote: self.run_server() else: @@ -83,7 +85,7 @@ def run_server(self): model_name=self.model.repoid_path_clsname, batched_input=self.batched_input, intervention_graph=self.graph, - generation=self.generation + generation=self.generation, ) if self.blocking: