diff --git a/src/nnsight/tracing/Node.py b/src/nnsight/tracing/Node.py index 42c672e7..a81458f3 100755 --- a/src/nnsight/tracing/Node.py +++ b/src/nnsight/tracing/Node.py @@ -1,6 +1,7 @@ from __future__ import annotations import inspect +import sys import traceback import weakref from collections import defaultdict @@ -384,25 +385,26 @@ def execute(self) -> None: # Set value. self.set_value(output) - + except protocols.EarlyStopProtocol.EarlyStopException as e: + raise e except Exception as e: if self.graph and self.graph.debug: - if type(e) != protocols.EarlyStopProtocol.EarlyStopException: - if self.attached(): - print(f"\n{self.meta_data['traceback']}\n" + \ - f"NNsightError: {str(e)}.\n") - - # Kill all the graphs in the Session to signal that an error occured - # This way only the traceback of the Node responsible for the error gets printed out - self.graph.alive = False - if protocols.BridgeProtocol.has_bridge(self.graph): - bridge = protocols.BridgeProtocol.get_bridge(self.graph) - - def kill_graph(graph: "Graph") -> None: - """Sets the graph.alive attribute to False""" - graph.alive = False - - [kill_graph(g) for g in bridge.graph_stack] + sys.tracebacklimit = 0 + if self.attached(): + print(f"\n{self.meta_data['traceback']}") + + # Kill all the graphs in the Session to signal that an error occured + # This way only the traceback of the Node responsible for the error gets printed out + self.graph.alive = False + if protocols.BridgeProtocol.has_bridge(self.graph): + bridge = protocols.BridgeProtocol.get_bridge(self.graph) + + def kill_graph(graph: "Graph") -> None: + """Sets the graph.alive attribute to False""" + graph.alive = False + + [kill_graph(g) for g in bridge.graph_stack] + raise util.NNsightError(str(e)) from None else: raise type(e)( f"Above exception occured when executing Node: '{self.name}' in Graph: '{self.graph.id}'" diff --git a/src/nnsight/util.py b/src/nnsight/util.py index 42d52956..67f2bb7d 100755 --- a/src/nnsight/util.py +++ b/src/nnsight/util.py @@ -146,3 +146,8 @@ def forward(self, *args, **kwargs): args = args[0] return args + +class NNsightError(Exception): + """NNsight Execption class for raising error during execution.""" + + pass \ No newline at end of file