
ssa self logic doesn't work with list member #61

Open
leonardt opened this issue Aug 21, 2020 · 6 comments

Comments

@leonardt
Owner

leonardt commented Aug 21, 2020

This code:

```python
import inspect

from ast_tools.passes import apply_ast_passes, ssa, loop_unroll
from ast_tools.macros import unroll


class Foo:
    def __init__(self):
        self.x = [i for i in range(4)]

    @apply_ast_passes([loop_unroll(), ssa()])
    def __call__(self, I: int, shift: bool):
        O = self.x[-1]
        if shift:
            for i in unroll(range(2, -1, -1)):
                self.x[i + 1] = self.regs[i]
            self.x[0] = I
        return O

print(inspect.getsource(Foo.__call__))
```

produces

```python
def __call__(self, I: int, shift: bool):
    O0 = self.x[-1]
    self.x[2 + 1] = self.regs[2]
    self.x[1 + 1] = self.regs[1]
    self.x[0 + 1] = self.regs[0]
    self.x[0] = I
    __return_value0 = O0
    return __return_value0
```

Notice that the `shift` input is ignored in the rewritten call: the `if shift:` guard is gone, so the list is updated unconditionally.

If instead we use explicit attributes for each list member, it works as expected.

```python
import inspect

from ast_tools.passes import apply_ast_passes, ssa, loop_unroll
from ast_tools.macros import unroll


class Foo:
    def __init__(self):
        self.x0 = 0
        self.x1 = 1
        self.x2 = 2
        self.x3 = 3

    @apply_ast_passes([loop_unroll(), ssa()])
    def __call__(self, I: int, shift: bool):
        O = self.x3
        if shift:
            self.x3 = self.x2
            self.x2 = self.x1
            self.x1 = self.x0
            self.x0 = I
        return O

print(inspect.getsource(Foo.__call__))
```

produces

```python
def __call__(self, I: int, shift: bool):
    self_x30 = self.x3
    self_x20 = self.x2
    self_x10 = self.x1
    self_x00 = self.x0
    O0 = self_x30
    self_x31 = self_x20
    self_x21 = self_x10
    self_x11 = self_x00
    self_x01 = I
    self_x02 = self_x01 if shift else self_x00
    self_x12 = self_x11 if shift else self_x10
    self_x22 = self_x21 if shift else self_x20
    self_x32 = self_x31 if shift else self_x30
    __return_value0 = O0
    self.x3 = self_x32
    self.x2 = self_x22
    self.x1 = self_x12
    self.x0 = self_x02
    return __return_value0
```

@cdonovick I can look into implementing this, but do you have any guidance on whether/how we could support it?

@cdonovick
Collaborator

I think it's incredibly tricky. Some sort of variable-name interpolation that makes using explicit attributes ergonomic would probably be a better solution.
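As a rough illustration of what such interpolation would need to expand to, the sketch below uses plain Python string templating (not an existing ast_tools feature) to generate the explicit-attribute shift register for any size. The helper name `make_shift_source` and the `x0`…`x{n-1}` naming scheme are assumptions for illustration; for `n=4` the generated body matches the hand-written explicit-attribute `__call__` in the issue.

```python
def make_shift_source(name: str, n: int) -> str:
    """Hypothetical helper: emit an explicit-attribute shift register.

    Each list slot self.<name>[k] becomes a scalar attribute self.<name>k,
    so an SSA pass would only ever see plain attribute reads and writes.
    """
    lines = [
        "def __call__(self, I: int, shift: bool):",
        f"    O = self.{name}{n - 1}",
        "    if shift:",
    ]
    # Shift from the top down so each slot reads its neighbour's old value.
    for k in range(n - 1, 0, -1):
        lines.append(f"        self.{name}{k} = self.{name}{k - 1}")
    lines.append(f"        self.{name}0 = I")
    lines.append("    return O")
    return "\n".join(lines)


src = make_shift_source("x", 4)
print(src)
```

The point of real symbol interpolation would be to get this expansion without the user writing (or templating) each attribute name by hand.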

@leonardt
Owner Author

Hmm, yeah, I thought about this some more. It seems that in sequential we don't actually support real Python lists; it's more just syntax sugar for a fixed-size array (since the call method can't dynamically change the list's contents, e.g. by appending or removing elements). So, given that restriction, we can "unroll" the array into scalar values in a pass that runs before SSA.
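A minimal sketch of such a pre-SSA pass, written against Python's standard `ast` module rather than ast_tools' own pass machinery. It assumes Python 3.9+ (for `ast.unparse` and un-wrapped subscript indices), that the list is only accessed as `self.<name>[k]`, that every index folds to an integer constant once unrolling is done (e.g. `2 + 1` or `-1`), and that the list length is known up front. `ScalarizeList` and `fold_index` are hypothetical names.

```python
import ast


def fold_index(node: ast.expr) -> int:
    """Fold a constant index expression such as `2 + 1` or `-1` to an int."""
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return -fold_index(node.operand)
    if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub)):
        left, right = fold_index(node.left), fold_index(node.right)
        return left + right if isinstance(node.op, ast.Add) else left - right
    raise ValueError(f"non-constant index: {ast.unparse(node)}")


class ScalarizeList(ast.NodeTransformer):
    """Rewrite self.<name>[k] into the scalar attribute self.<name>k."""

    def __init__(self, name: str, size: int):
        self.name = name
        self.size = size

    def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
        self.generic_visit(node)
        base = node.value
        if not (isinstance(base, ast.Attribute)
                and isinstance(base.value, ast.Name)
                and base.value.id == "self"
                and base.attr == self.name):
            return node  # some other subscript; leave it alone
        k = fold_index(node.slice)
        if k < 0:
            k += self.size  # resolve Python-style negative indices
        if not 0 <= k < self.size:
            raise IndexError(f"self.{self.name}[{k}] is out of range")
        # Keep the original Load/Store context so assignments stay assignments.
        new = ast.Attribute(value=base.value, attr=f"{self.name}{k}", ctx=node.ctx)
        return ast.copy_location(new, node)


# A hand-written approximation of the unrolled body, with self.x assumed
# in place of self.regs and the `if shift:` guard still present.
source = """\
def __call__(self, I, shift):
    O = self.x[-1]
    if shift:
        self.x[2 + 1] = self.x[2]
        self.x[1 + 1] = self.x[1]
        self.x[0 + 1] = self.x[0]
        self.x[0] = I
    return O
"""
tree = ScalarizeList("x", 4).visit(ast.parse(source))
result = ast.unparse(ast.fix_missing_locations(tree))
print(result)
```

The printed function has the same shape as the explicit-attribute version (`O = self.x3`, `self.x3 = self.x2`, ...), which the SSA pass already handles. The cost of this design is that any index that does not fold to a constant is rejected rather than supported.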

@leonardt
Owner Author

Or we could just provide a symbol interpolation feature so the user can effectively do the same thing (again, this would spare the user from managing this conversion by hand).

@rdaly525
Collaborator

I'm not sure if this would help, but I think PyTorch has a similar issue and resolves it with a custom list type: https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html

@leonardt
Owner Author

Yeah, I think we could define a syntax layer that lets you use a list abstraction inside `__init__` (e.g. for generator code that builds or appends to a list based on some parameter). Then inside `__call__` we can assume the list size is fixed (i.e. the user can only index the list to refer to elements), so at that point we can "desugar" the list into individual elements to keep the subsequent pass logic simple. So the special list object can be used differently in different stages (another way to view it: inside `__call__` the user has a restricted view of the list, which could be defined by a custom list type). There might be some concern about code-size explosion, but I think it should be fine for the common case.
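A minimal sketch of that two-stage list in plain Python, independent of ast_tools and sequential: while building (e.g. in `__init__`) it behaves like an appendable list, and `freeze()` returns a fixed-length view that only supports integer indexing, which is the restriction that would let a pass desugar it into scalars. The names `StagedList`, `FrozenView` and `freeze` are hypothetical.

```python
from typing import Generic, Iterable, List, TypeVar

T = TypeVar("T")


class FrozenView(Generic[T]):
    """Fixed-length view for use inside __call__: integer indexing only."""

    def __init__(self, items: List[T]):
        self._items = items

    def __len__(self) -> int:
        return len(self._items)

    @staticmethod
    def _check(index: int) -> int:
        # Reject slices and other non-integer keys; bounds (including
        # negative indices) are still checked by the underlying list.
        if not isinstance(index, int):
            raise TypeError("only integer indices are allowed")
        return index

    def __getitem__(self, index: int) -> T:
        return self._items[self._check(index)]

    def __setitem__(self, index: int, value: T) -> None:
        self._items[self._check(index)] = value


class StagedList(Generic[T]):
    """Growable while building (e.g. in __init__), then frozen to a fixed size."""

    def __init__(self, items: Iterable[T] = ()):
        self._items: List[T] = list(items)

    def append(self, value: T) -> None:
        self._items.append(value)

    def freeze(self) -> "FrozenView[T]":
        # Copy so later appends to the builder cannot resize the view.
        return FrozenView(list(self._items))


regs: StagedList[int] = StagedList()
for i in range(4):  # generator-style construction driven by a parameter
    regs.append(i)
x = regs.freeze()
x[0] = 7  # element updates are allowed
print(len(x), x[0], x[-1])  # 4 7 3
```

Because a frozen view can only be indexed, every access inside `__call__` takes the form `x[k]`, which is the shape a scalarizing pass needs; whether each `k` is a constant is still up to the unroller.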

@rsetaluri
Collaborator

From @cdonovick:

Basically, the logic at https://github.com/leonardt/ast_tools/blob/master/ast_tools/passes/ssa.py#L546-L567 would have to be extended with similar logic for array indices. For example:

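```python
def foo(arr, cond):
    arr[0] = 5
    if cond:
        arr[1] = 2
```

becomes:

```python
def foo(arr, cond):
    # load each indexed element into a scalar
    arr_0 = arr[0]
    arr_1 = arr[1]
    arr_0_0 = 5
    arr_1_0 = 2
    # merge point of the `if`: keep the new value only when cond holds
    arr_1_1 = arr_1_0 if cond else arr_1
    # write the final SSA values back
    arr[0] = arr_0_0
    arr[1] = arr_1_1
```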
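One quick way to sanity-check a desugaring like this is to run the original and the rewritten function side by side on concrete inputs. The sketch below does that in plain Python; `foo_original`, `foo_ssa` and `equivalent` are hypothetical names, and it assumes `cond` is passed in as an argument and `arr` has at least two elements.

```python
import itertools


def foo_original(arr, cond):
    arr[0] = 5
    if cond:
        arr[1] = 2


def foo_ssa(arr, cond):
    # Hand-desugared version: every indexed element becomes a scalar.
    arr_0 = arr[0]
    arr_1 = arr[1]
    arr_0_0 = 5
    arr_1_0 = 2
    arr_1_1 = arr_1_0 if cond else arr_1  # merge point of the `if`
    arr[0] = arr_0_0
    arr[1] = arr_1_1


def equivalent(f, g, inputs):
    """Return True if f and g mutate copies of every input identically."""
    for arr, cond in inputs:
        a, b = list(arr), list(arr)
        f(a, cond)
        g(b, cond)
        if a != b:
            return False
    return True


cases = list(itertools.product([[0, 0], [1, 9], [7, 7, 7]], [True, False]))
print(equivalent(foo_original, foo_ssa, cases))  # True
```

With a write-back of `arr[0] = arr_0` instead of `arr_0_0`, this check would report `False`, since the store of `5` would be lost.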
