ExportedProgram for
```python
import torch

class AvgPool2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=6,
        )

    def forward(self, x):
        return self.ap2d(x)
```
produces the call to AvgPool2d as torch.ops.aten.avg_pool2d.default(x, [6, 6], [6, 6]). This matches the documented behavior of the kernel_size parameter (https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html), which states that a single integer value is used for both the height and width dimensions. Per the documentation, the only other accepted value for kernel_size is a tuple of two integers. However, a tuple with a single element works as well:
```python
class AvgPool2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=(6,)
        )

    def forward(self, x):
        return self.ap2d(x)
```
and the ExportedProgram contains the call to AvgPool2d as torch.ops.aten.avg_pool2d.default(x, [6], [6]). Note that the kernel value is not repeated here, even though that is what happens when the code is executed eagerly in Python.
This ExportedProgram triggers an assertion when lowering the resulting Torch IR to TOSA/Linalg/StableHLO, since those lowerings assume that the kernel list has two elements.
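For reference, a minimal sketch along these lines (the names and shapes below are just for illustration) shows the eager behavior and the two exported graphs side by side:

```python
import torch

class Pool(torch.nn.Module):
    def __init__(self, ks):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(kernel_size=ks)

    def forward(self, x):
        return self.ap2d(x)

x = torch.randn(1, 3, 12, 12)
# Eagerly, both variants behave the same (the single element is repeated)...
assert torch.allclose(Pool(6)(x), Pool((6,))(x))
# ...but the exported graphs differ: kernel_size/stride come out as
# [6, 6] in the first case and as [6] in the second.
print(torch.export.export(Pool(6), (x,)).graph)
print(torch.export.export(Pool((6,)), (x,)).graph)
```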
So I think this can be fixed by one of the following approaches:
- Change the behavior of ExportedProgram in the second scenario to match the first one. I am not familiar with the PyTorch codebase, so I'm not sure where to make that change; if anyone knows where to start looking, I'd appreciate a pointer.
- Fix the individual lowerings, but that means repeating the same logic in three different places.
- In Torch IR, before any of the lowerings run (possibly when DecomposeComplexOps is called), extend the kernel param of the torch.aten.avg_pool2d op to the correct size, so the individual lowerings don't need to be fixed (see the sketch after this list).
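To illustrate what I mean by option 3: the actual fix would live in a torch-mlir pass (in C++), but the normalization itself is small. Here is a rough Python sketch of the equivalent rewrite on the exported FX graph -- the node matching below is only an illustration of the intent, not how the pass would actually be written:

```python
import torch

def expand_pool_kernel(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Expand single-element kernel_size/stride lists on avg_pool2d calls
    to two elements, mirroring what eager execution does."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.avg_pool2d.default:
            args = list(node.args)
            # args[1] is kernel_size; args[2] (when present) is stride.
            for i in (1, 2):
                if i < len(args) and isinstance(args[i], (list, tuple)) and len(args[i]) == 1:
                    args[i] = [args[i][0]] * 2
            node.args = tuple(args)
    gm.recompile()
    return gm
```

Running this over the graph module of the second ExportedProgram above should turn avg_pool2d.default(x, [6], [6]) into avg_pool2d.default(x, [6, 6], [6, 6]) before any lowering runs.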
I'm leaning towards option 3 (since I don't know how to make option 1 work) -- is that the correct approach? If so, which pass would be the right place to add the logic? AFAICT none of the existing passes do a similar transform where an op is replaced by the same op with different params. Should I add a new pass?
@sjarus, @vivekkhandelwal1 -- any thoughts? Thanks!