-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtailrec.py
More file actions
130 lines (98 loc) · 3.91 KB
/
tailrec.py
File metadata and controls
130 lines (98 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Tail recursion should be as simple as applying a decorator.

This module uses Python's AST to rewrite the decorated function so that the
problem reduces to the existing Tailrec implementation.
"""
import ast
from inspect import getsource, isclass
from textwrap import dedent
__all__ = ["ASTTailrec"]
class Tailrec:
    """
    Callable wrapper that runs a function as a trampoline.

    The wrapped function signals a tail call by returning a `Tailcall`
    (normally built through `recur`); `__call__` keeps evaluating those
    objects in a loop until something else is returned.
    """

    def __init__(self, oper):
        # Wrapping an existing Tailrec re-uses its raw function instead of
        # stacking one wrapper on top of another.
        self.func = oper.func if isinstance(oper, Tailrec) else oper

    def __call__(self, *args, **kwargs):
        """Call the wrapped function and resolve every pending tail call."""
        result = self.func(*args, **kwargs)
        while True:
            if not isinstance(result, Tailcall):
                return result
            result = result.evaluate()

    def recur(self, *args, **kwargs):
        """Return a deferred call of the wrapped function with these arguments."""
        pending = Tailcall(self.func, args, kwargs)
        return pending
class Tailcall:
    """
    Record of a call that has been requested but not yet performed.

    Stores the target callable together with its positional and keyword
    arguments; `evaluate` performs the call.
    """

    def __init__(self, oper, args, kwargs):
        target = oper
        # Store the raw function when handed a Tailrec wrapper.
        if isinstance(target, Tailrec):
            target = target.func
        self.func, self.args, self.kwargs = target, args, kwargs

    def evaluate(self):
        """Perform the stored call and return whatever it returns."""
        positional, keyword = self.args, self.kwargs
        return self.func(*positional, **keyword)
class TailTransformer(ast.NodeTransformer):
    """
    Rewrite `return name(...)` statements into `return name.recur(...)`.

    Only return statements whose value is a direct call to the bare name
    given at construction are rewritten. Every other return statement
    (attribute calls such as `obj.method()`, calls to other names,
    non-call values, bare `return`) is left untouched.
    """

    def __init__(self, name, *args, **kwargs):
        """
        :param name: name of the function whose self-calls in return
            position should be redirected to `name.recur`.
        """
        self.name = name
        # Zero-argument super() binds to this instance; the one-argument form
        # super(TailTransformer) is unbound and would only initialise the
        # super object itself.
        super().__init__(*args, **kwargs)

    def visit_Return(self, node):
        """Redirect a directly returned self-call to `<name>.recur`."""
        call = node.value
        if (
            isinstance(call, ast.Call)
            # The callee may be any expression (ast.Attribute, another
            # ast.Call, ...); only ast.Name nodes carry an `id`.
            and isinstance(call.func, ast.Name)
            and call.func.id == self.name
        ):
            # Since the dot syntax means that this is an attribute call and
            # not a function call, we have to create a new node.
            call.func = ast.Attribute(value=call.func, attr="recur", ctx=ast.Load())
            # Creating a new node doesn't always play nice with compile,
            # so fix missing locations automatically.
            ast.fix_missing_locations(node)
        return node
def replace_decorator(func_tree):
    """
    Swap the `ASTTailrec` decorator for `Tailrec` on the first definition.

    The ASTTailrec decorator wants to get rid of itself and replace itself
    with Tailrec so the problem reduces to an already solved one. Both the
    bare form (`@ASTTailrec`) and the qualified form (`@pkg.ASTTailrec`) are
    replaced by the bare name `Tailrec`. Any other decorator, including
    call-style ones such as `@wraps(f)`, is left untouched.

    :param func_tree: parsed module whose first statement is the decorated
        definition; its decorator list is modified in place.
    """
    decorators = func_tree.body[0].decorator_list
    for index, dec in enumerate(decorators):
        if isinstance(dec, ast.Name):
            if dec.id == 'ASTTailrec':
                dec.id = 'Tailrec'
        elif isinstance(dec, ast.Attribute) and dec.attr == 'ASTTailrec':
            # A qualified reference has no `id`; substitute a plain name node
            # that keeps the original source position.
            replacement = ast.Name(id='Tailrec', ctx=ast.Load())
            decorators[index] = ast.copy_location(replacement, dec)
def ASTTailrec(func):
    """
    Decorator that turns direct self-calls in return position into tail calls.

    This approach involves modifying the AST so we can just stick a decorator
    on, such as
    ```
    @ASTTailrec
    def fac(n, k=1):
        if n == 1: return k
        return fac(n-1, k*n)
    ```
    The function's source is parsed, the `ASTTailrec` decorator is replaced
    by `Tailrec`, every `return fac(...)` is rewritten to
    `return fac.recur(...)`, and the result is compiled and executed.

    NOTE: the rewritten definition is executed in a copy of the function's
    globals taken at decoration time, so names added to the original module
    afterwards are not visible to the returned function.

    This function has been heavily inspired by Robin Hilliard's pipeop library
    at https://github.com/robinhilliard/pipes. It was used as reference when
    developing this decorator.

    :param func: the function to rewrite; its source must be retrievable
        with `inspect.getsource`.
    :return: whatever the rewritten definition binds to the function's name
        (a `Tailrec` instance when `ASTTailrec` is the outermost decorator).
    :raises TypeError: if `func` is a class.
    """
    if isclass(func):
        raise TypeError("Cannot apply tail recursion to a class")

    in_context = func.__globals__
    new_context = {"Tailrec": Tailrec, "Tailcall": Tailcall}
    # These need to be available to the rewritten code; if the caller's
    # module already defines them, let that definition take precedence.
    new_context.update(in_context)

    # Parse the function's own source (dedented so nested definitions parse).
    source = getsource(func)
    tree = ast.parse(dedent(source))

    # Shift line numbers back to their position in the file for debuggers.
    first_line_number = func.__code__.co_firstlineno
    ast.increment_lineno(tree, first_line_number - 1)

    # func.__name__ is not reliable in case of other decorators that do not
    # use `functools.wraps`, so take the name from the parsed definition.
    func_name = tree.body[0].name

    # Replace ourselves with the standard Tailrec decorator ...
    replace_decorator(tree)
    # ... and route every returned self-call through func_name.recur.
    tree = TailTransformer(func_name).visit(tree)

    # Not every namespace defines __file__ (e.g. notebooks or exec'd code);
    # fall back to the filename recorded on the original code object.
    filename = new_context.get('__file__') or func.__code__.co_filename
    code = compile(tree, filename=filename, mode='exec')

    # Execute the rewritten definition in the scope of new_context and hand
    # back what it bound to the function's name.
    exec(code, new_context)
    return new_context[func_name]