Skip to content

Commit

Permalink
Merge pull request #300 from ndif-team/hacks2.0
Browse files Browse the repository at this point in the history
Hacks2.0
  • Loading branch information
JadenFiotto-Kaufman authored Dec 5, 2024
2 parents 8672ddb + d0b36a4 commit 95854a0
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 88 deletions.
9 changes: 7 additions & 2 deletions src/nnsight/tracing/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ..graph import Graph, Proxy
from ..protocols import StopProtocol


class Backend:

def __call__(self, graph: Graph) -> None:
Expand All @@ -25,7 +24,13 @@ def __call__(self, graph: Graph) -> None:
graph.nodes[-1].execute()

if self.injection:
frame = inspect.currentframe().f_back.f_back.f_back.f_back

from ..contexts import Context

frame = inspect.currentframe().f_back
while frame.f_back is not None and 'self' in frame.f_locals and isinstance(frame.f_locals['self'], Context):
frame = frame.f_back

for key, value in frame.f_locals.items():
if isinstance(value, Proxy) and value.node.done:
frame.f_locals[key] = value.value
Expand Down
10 changes: 8 additions & 2 deletions src/nnsight/tracing/contexts/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,20 @@ def __init__(
super().__init__(*args, **kwargs)

self.args = [condition, branch]
self.index = None

def else_(self, condition: Optional[Any] = None):

return Condition(
condition,
branch=self.graph.nodes[self.graph[-1].index + 1],
branch=self.graph.nodes[self.index],
parent=self.graph.stack[-1],
)

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
super().__exit__(exc_type, exc_val, exc_tb)

self.index = self.graph.nodes[-1].index

@classmethod
def execute(cls, node: NodeType):
Expand Down
3 changes: 3 additions & 0 deletions src/nnsight/tracing/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def clean(self, start: Optional[int] = None):
Args:
start (Optional[int], optional): `Node` index to start cleaning up from. Defaults to None.
"""

if len(self) == 0:
return

if start is None:
start = self[0].index
Expand Down
8 changes: 4 additions & 4 deletions src/nnsight/tracing/graph/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,17 +286,17 @@ def __iter__(self) -> Iterator[Self]:
raise Exception('Iteration control flow encountered but "CONFIG.APP.CONTROL_FLOW_HACKS" is set to False')

from ..hacks import iterator

return iterator.handle_iterator(inspect.currentframe().f_back, self)
return iterator.handle_proxy(inspect.currentframe().f_back, self)

def __bool__(self) -> Self:

if not CONFIG.APP.CONTROL_FLOW_HACKS:
raise Exception('Conditional control flow encountered but "CONFIG.APP.CONTROL_FLOW_HACKS" is set to False')

from ..hacks import conditional

return conditional.handle_conditional(inspect.currentframe().f_back, self)
return conditional.handle_proxy(inspect.currentframe().f_back, self)

def __instancecheck__(self, __instance: Any) -> bool:
return self.node.fake_value.__instancecheck__(__instance)
Expand Down
23 changes: 23 additions & 0 deletions src/nnsight/tracing/hacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import ast
from types import FrameType

from ..graph import Graph
from .conditional import handle as handle_conditional
from .iterator import handle as handle_iterator


def handle_inner(node:ast.stmt, frame: FrameType, graph: Graph):

if isinstance(node, ast.If):

handle_conditional(node, frame, graph)

return True

elif isinstance(node, ast.For):

handle_iterator(node, frame, graph)

return True

return False
67 changes: 67 additions & 0 deletions src/nnsight/tracing/hacks/comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import ast
import ctypes
import inspect
import sys
from types import FrameType
from typing import TYPE_CHECKING

from ..contexts import Iterator
from ..graph import Graph
from .util import execute, execute_body, execute_until, visit

if TYPE_CHECKING:
from ..graph import Proxy

COMPS = [ast.SetComp, ast.DictComp, ast.ListComp, ast.GeneratorExp]

def handle(node: ast.For, frame: FrameType, graph: Graph):

iter_expr = ast.Expression(
body=node.iter, lineno=node.lineno, col_offset=node.col_offset
)

iter = execute(iter_expr, frame)

context = Iterator(iter, parent=graph)

target = node.target

with context as item:
if isinstance(target, ast.Name):
frame.f_locals[target.id] = item
elif isinstance(target, ast.Tuple):
for t, v in zip(target.elts, item):
if isinstance(t, ast.Name):
frame.f_locals[t.id] = v

ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(frame), 0)

execute_body(node.body, frame, context.graph)


def handle_proxy(node:ast.stmt, frame: FrameType, collection:"Proxy"):

graph = collection.node.graph


iterator = Iterator(collection, parent=graph)

item = iterator.__enter__()

def callback(new_frame:FrameType, list_proxy, iterator:Iterator):


key, result = next(iter(new_frame.f_locals.items()))
print(node, node.elt.elt.ctx.__dict__)

# list_proxy.append(result[0])

# new_frame.f_locals[key] = list_proxy
# ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(new_frame), 0)

iterator.__exit__(None, None, None)


execute_until(frame.f_lineno -1, frame.f_lineno - 1, frame, callback= lambda new_frame: callback(new_frame, [], iterator))

return iter([item])
90 changes: 42 additions & 48 deletions src/nnsight/tracing/hacks/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,43 @@
from types import FrameType
from typing import TYPE_CHECKING

from ..contexts import Condition, Context
from .util import execute, execute_body, execute_until

from ..contexts import Condition
from .util import execute, execute_body, execute_until, visit
from ..graph import Graph
if TYPE_CHECKING:
from ..graph import Proxy

def get_else(node: ast.If):

return (
node.orelse[0]
if isinstance(node.orelse[0], ast.If)
else ast.If(
test=ast.Constant(value=None),
body=node.orelse,
orelse=[],
lineno=node.lineno,
col_offset=node.col_offset,
)
)

def handle(node: ast.If, frame:FrameType, graph:Graph, branch:Condition = None):

condition_expr = ast.Expression(
body=node.test, lineno=node.lineno, col_offset=node.col_offset
)

condition = execute(condition_expr, frame)

context = Condition(condition, parent = graph) if branch is None else branch.else_(condition)

def handle_conditional(frame: FrameType, condition: "Proxy"):
with context as branch:
execute_body(node.body, frame, branch.graph)

line_no = frame.f_lineno
source_lines, _ = inspect.getsourcelines(frame)
source = "".join(source_lines)
tree = ast.parse(source)
if node.orelse:
return handle(get_else(node), frame, graph, branch)

def handle_proxy(frame: FrameType, condition: "Proxy"):

class Visitor(ast.NodeVisitor):
def __init__(self, line_no):
Expand All @@ -28,53 +52,23 @@ def visit_If(self, node):
self.target = node
self.generic_visit(node)

visitor = Visitor(line_no)
visitor.visit(tree)

if_node = visitor.target
visitor = visit(frame, Visitor)

if_node:ast.If = visitor.target

graph = condition.node.graph

branch = Condition(condition, parent=graph)

def get_else(node: ast.If):

return (
node.orelse[0]
if isinstance(node.orelse[0], ast.If)
else ast.If(
test=ast.Constant(value=None),
body=node.orelse,
orelse=[],
lineno=node.lineno,
col_offset=node.col_offset,
)
)

def evaluate_and_execute(node: ast.stmt):

nonlocal branch

if isinstance(node, ast.If):

condition_expr = ast.Expression(
body=node.test, lineno=node.lineno, col_offset=node.col_offset
)

condition = execute(condition_expr, frame)

with branch.else_(condition) as branch:
execute_body(node.body, frame)

if node.orelse:
return evaluate_and_execute(get_else(node))

def callback(node: ast.If, context: Context, frame: FrameType):
def callback(node: ast.If, frame: FrameType, graph:Graph, branch:Condition):

branch.__exit__(None, None, None)

if node.orelse:
evaluate_and_execute(get_else(if_node))
handle(get_else(if_node), frame, graph, branch)

branch.__enter__()
execute_until(branch, if_node, frame, callback=callback)
end = frame.f_lineno + (if_node.end_lineno - if_node.lineno)
execute_until(frame.f_lineno, end, frame, callback=lambda _: callback(if_node, frame, graph, branch))

return True
return True
Loading

0 comments on commit 95854a0

Please sign in to comment.