[Urgent] PyTorch GroupNorm and Meshgrid conversion fails for dynamic input #1303

Closed
CanyonWind opened this issue Sep 20, 2021 · 11 comments · Fixed by #1922
Labels
bug Unexpected behaviour that should be corrected (type) Flexible Shape PyTorch (traced)

Comments

@CanyonWind

CanyonWind commented Sep 20, 2021

🐞Describe the bug

  • When converting a PyTorch model that uses GroupNorm (GN) with a dynamic input shape, the GN conversion fails.
  • The failure occurs when converting the traced PyTorch model to Core ML.
  • It seems that here h and w are expected to be integers, but for a dynamic-input model they are symbolic placeholders.
  • Any advice or quick hack would be much appreciated.

Trace

Please run the code below to see the error.

To Reproduce

  • Here is the minimal code to reproduce the error:
import torch
import torch.nn as nn
import coremltools as ct


class DynamicGN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3,
                         stride=1, padding=1, bias=False)
        self.gn = nn.GroupNorm(num_groups=4, num_channels=16)
    def forward(self, x):
        y = self.gn(self.conv(x))
        return y

def main():
    model = DynamicGN()
    input = torch.ones((1,3,16,16))
    output = model(input)
    traced_model = torch.jit.trace(model, input, check_trace=True)
    img_shape = ct.Shape(shape=(1, 3, ct.RangeDim(8, 128), ct.RangeDim(8, 128)))
    img = ct.TensorType(name='image', shape=img_shape)
    mlmodel = ct.convert(model=traced_model, inputs=[img])

if __name__ == '__main__':
    main()

Error:

ValueError: Cannot add const [1, 4, 4, is2, is3]

System environment (please complete the following information):

  • coremltools version (e.g., 3.0b5): 5.0b2
  • OS (e.g., MacOS, Linux): MacOS 11.3.1
  • macOS version (if applicable):
  • XCode version (if applicable):
  • How you install python (anaconda, virtualenv, system):
  • python version (e.g. 3.7): 3.8.5
  • any other relevant information:
    • PyTorch 1.9.0
@CanyonWind CanyonWind added the bug Unexpected behaviour that should be corrected (type) label Sep 20, 2021
@CanyonWind
Author

By any chance, could someone help or share a hint? Thanks!

@aseemw
Collaborator

aseemw commented Sep 23, 2021

Hi, Thanks for filing this issue with the reproducible case.

There is definitely a bug in the translation logic of group norm and most likely also an issue in the type inference of the reshape op. We will look into it.
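
For reference, the arithmetic that the converted graph has to reproduce can be written without any framework. The sketch below is plain Python, not coremltools code: `group_norm_reference` is a hypothetical helper name, it handles a single sample of shape (C, H, W), and it leaves out the optional per-channel weight and bias. It assumes the usual GroupNorm definition (biased variance, eps added inside the square root).

```python
import math

def group_norm_reference(x, num_groups, eps=1e-5):
    """Normalize a nested list of shape (C, H, W) group by group."""
    c = len(x)
    group_size = c // num_groups
    out = []
    for g in range(num_groups):
        channels = x[g * group_size:(g + 1) * group_size]
        # One mean/variance per group, over (C // G) * H * W values.
        values = [v for ch in channels for row in ch for v in row]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
        inv_std = 1.0 / math.sqrt(var + eps)
        for ch in channels:
            out.append([[(v - mean) * inv_std for v in row] for row in ch])
    return out

# H and W are never hard-coded, so any spatial size works.
sample = [[[float(c + h + w) for w in range(3)] for h in range(2)] for c in range(4)]
normed = group_norm_reference(sample, num_groups=2)
```

Nothing in the computation depends on knowing H and W ahead of time, so the dynamic-shape failure looks like a limitation of how the reshape is expressed in the converter rather than of the operation itself.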

Meanwhile, here is a workaround.

For the code snippet you have, the following updated logic should work. That is, you can update the conversion script as follows:

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.frontend.torch.torch_op_registry  import _TORCH_OPS_REGISTRY
from coremltools.converters.mil.frontend.torch.ops import _std
import builtins

del _TORCH_OPS_REGISTRY["group_norm"]

@register_torch_op
def group_norm(context, node):
    inputs = [context[name] for name in node.inputs]
    x = inputs[0]
    num_groups = inputs[1].val
    weight = inputs[2]
    bias = inputs[3]
    eps = inputs[4]
    n,c,h,w = x.shape[0],x.shape[1],x.shape[2],x.shape[3]
    num_groups = builtins.min(num_groups,c)

    # this is different
    x = mb.expand_dims(x=x, axes=[0])
    x = mb.reshape(x=x, shape=[n, num_groups, c // num_groups, 0, 0])
   
    mean = mb.reduce_mean(x=x, axes=[2,3,4], keep_dims=True)
    var = _std(x,[2,3,4],True,False,eps.val)
    x = mb.sub(x=x,y=mean)
    x = mb.real_div(x=x,y=var)
 
    # this is different    
    x = mb.squeeze(x=x, axes=[0])
    x = mb.reshape(x=x, shape=[n, c, 0, 0])
    
    if weight is not None:
        weight = mb.reshape(x=weight, shape=[1,c,1,1])
        x = mb.mul(x=x,y=weight)
    if bias is not None:
        bias = mb.reshape(x=bias, shape=[1,c,1,1])
        x = mb.add(x=x,y=bias)
    context.add(x,node.name)


class DynamicGN(nn.Module):
     ...

def main():
    ...


if __name__ == '__main__':
    main()

The above script will work if you also update or suppress this error by replacing this line with

if 0 in op.shape.val and len(op.shape.val) != op.x.rank:
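
To see why the workaround wraps the reshapes in expand_dims and squeeze, it helps to model the rule that the error message implies: a 0 in the target shape copies the input dimension at the same index, so it is only allowed when the target has as many entries as the input has dimensions. The sketch below is plain Python and only an assumption about that rule, not the coremltools implementation; `resolve_reshape` is a hypothetical helper, and the strings "H" and "W" stand in for the symbolic dimensions.

```python
def resolve_reshape(input_shape, target_shape):
    """Replace each 0 in target_shape with the input dim at the same index."""
    if 0 in target_shape and len(target_shape) != len(input_shape):
        raise ValueError("Use 0 in shape only if len(shape) == x.rank")
    return [input_shape[i] if t == 0 else t for i, t in enumerate(target_shape)]

x_shape = [1, 16, "H", "W"]            # N, C, H, W with symbolic H and W

# expand_dims(axes=[0]) makes the input rank 5, so H and W sit at indices 3 and 4,
# exactly where the two zeros of the grouped shape are.
expanded = [1] + x_shape               # [1, 1, 16, "H", "W"]
grouped = resolve_reshape(expanded, [1, 4, 4, 0, 0])

# squeeze(axes=[0]) drops the batch dim (this relies on N == 1), which moves
# H and W back to indices 2 and 3 for the final reshape.
squeezed = grouped[1:]                 # [4, 4, "H", "W"]
restored = resolve_reshape(squeezed, [1, 16, 0, 0])
```

Under this model, reshaping the rank-4 input straight to `[1, 4, 4, 0, 0]` would raise the ValueError, which is why the rank is adjusted first.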

@CanyonWind
Author

Thank you @aseemw! It works smoothly.

I found another problem related to dynamic input: meshgrid conversion also fails. This may be off-topic for the issue title, but any help would be much appreciated.

Minimum code to reproduce:

class DynamicMeshgrid(nn.Module):
    def __init__(self):
        super(DynamicMeshgrid, self).__init__()
    
    def forward(self, rows, cols):
        grid_x, grid_y = torch.meshgrid(rows, cols)
        return grid_x, grid_y

def main():
    model = DynamicMeshgrid()
    xy_poses = [torch.linspace(-1, 1, 8), torch.linspace(-1, 1, 10)]
    pytorch_output = model(xy_poses[0], xy_poses[1])
    traced_model = torch.jit.trace(model, xy_poses, check_trace=True)
    pos_shape_1 = pos_shape_2  = ct.Shape(shape=(ct.RangeDim(1, 16),))  # dynamic size fails
    # pos_shape_1 = ct.Shape(shape=(8,))  # fixed size works well
    # pos_shape_2 = ct.Shape(shape=(10,))
    xy_poses_dynamic = [ct.TensorType(name='x_pos', shape=pos_shape_1), ct.TensorType(name='y_pos', shape=pos_shape_2)]
    mlmodel = ct.convert(model=traced_model, inputs=xy_poses_dynamic)
    coreml_output = mlmodel.predict({'x_pos': xy_poses[0].numpy(), 'y_pos': xy_poses[1].numpy()}, useCPUOnly=True)
    print('Done')

Error:

TypeError: cannot determine truth value of Relational

Trace the bug

  • I followed this PR to update meshgrid; it works for fixed-size inputs.
  • In meshgrid, the repetition counts (reps) are assumed to be fixed integers.
  • I tried letting reps accept symbols as well, but then the following mb.tile step fails, because tile also assumes reps are fixed integers.
  • I then checked the tile implementation and found that we might be able to pass two inputs to tile: one for the feature and one for the symbolic shape. I'm stuck here because I'm not sure how to create the symbolic shape or how to feed mb.tile multiple inputs.

Any suggestion would be appreciated, thanks!

@CanyonWind CanyonWind changed the title [Urgent] PyTorch GroupNorm conversion fails for dynamic input [Urgent] PyTorch GroupNorm and Meshgrid conversion fails for dynamic input Sep 24, 2021
@aseemw
Collaborator

aseemw commented Sep 24, 2021

Hmm, the tile op with multiple inputs should work, although the shape of the second input (reps) must be known fully.

example:

import coremltools as ct
from coremltools.converters.mil import Builder as mb
import numpy as np

@mb.program(input_specs=[mb.TensorSpec(shape=(2, 3)), mb.TensorSpec(shape=(2,))])
def prog(x, reps):
    x = mb.tile(x=x, reps=reps)
    return x


model = ct.convert(prog)
model.save("dynamic_tile.mlmodel")

input_dict = {}
input_dict["x"] = np.array([[1, 2, 3] , [4, 5, 6]], dtype=np.float32)  # np.float was removed in NumPy 1.24
input_dict["reps"] = np.array([1, 5], dtype=np.float32)
out = model.predict(input_dict)["tile_0"]

print(input_dict["x"])
print(out.shape)
print(out)

[[1. 2. 3.]
 [4. 5. 6.]]
(2, 15)
[[1. 2. 3. 1. 2. 3. 1. 2. 3. 1. 2. 3. 1. 2. 3.]
 [4. 5. 6. 4. 5. 6. 4. 5. 6. 4. 5. 6. 4. 5. 6.]]

@CanyonWind
Author

CanyonWind commented Sep 25, 2021

Thanks for the prompt response!

I'm a bit confused by the example. It still seems to take a single input for mb.tile, which is the 2x3 numpy array.

Let me rephrase the question. meshgrid takes two 1-D vectors, v1 and v2 (in my case both have a dynamic size). The vectors have shapes (is0,) and (is1,) respectively.

In current coremltools implementation, it first gets dim_tuple by:

# for dynamic shape, dim_tuple = (is0, is1)
dim_tuple = tuple(tensor_var.shape[0] for tensor_var in inputs)

It then expands the dimensions, computes reps, and tiles each 1-D vector:

size = len(inputs)
grids = []
for i in range(size):
    view_shape = [1] * size
    view_shape[i] = -1
    view_shape = tuple(view_shape)  # (-1, 1) and (1, -1)
    # reshape to (is0, 1) and (1, is1)
    view = mb.reshape(
        x=inputs[i], shape=view_shape, name=node.name + "_view_" + str(i)
    )
    # get reps
    # For dynamic shapes, ds is is0/is1, which breaks the list comprehension below.
    # Tried replacing the condition with: (not isinstance(ds, int) or ds > 0) and ts == 1.
    # Then reps become [1, is1] and [is0, 1], but mb.tile breaks because reps contain symbols, not only integers.
    reps = [
        ds if ds > 0 and ts == 1 else 1 for ts, ds in zip(view.shape, dim_tuple)
    ]

    # tile according to reps
    # Ideally, for dynamic shapes, reps = (1, is1) and (is0, 1), but symbolic reps are not accepted here.
    # This is where I am stuck.
    expand = mb.tile(x=view, reps=reps, name=node.name + "_expand_" + str(i))
    grids.append(expand)

In short: is there a way to tile (or apply a tile-like function to) vector1 according to the dynamic shape of vector2, and vice versa?

Thanks for the patience!
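
The behaviour being asked for can be stated in plain Python, which at least pins down what a dynamic version has to compute. This is only a reference sketch, not a coremltools converter: `tile_2d` and `meshgrid_ij` are hypothetical helpers, and the grids follow the "ij" layout that the repro above relies on. The point is that each view's reps come from the length of the other vector, which is only known at run time.

```python
def tile_2d(matrix, reps):
    """Repeat a 2-D nested list reps[0] times along rows and reps[1] times along columns."""
    row_reps, col_reps = reps
    return [row * col_reps for _ in range(row_reps) for row in matrix]

def meshgrid_ij(rows, cols):
    view_x = [[r] for r in rows]        # shape (len(rows), 1)
    view_y = [list(cols)]               # shape (1, len(cols))
    # reps are read from the other input's runtime length, not from constants.
    grid_x = tile_2d(view_x, (1, len(cols)))
    grid_y = tile_2d(view_y, (len(rows), 1))
    return grid_x, grid_y

gx, gy = meshgrid_ij([0.0, 1.0], [10.0, 20.0, 30.0])
```

In a converter, the equivalent would presumably be to read those lengths from the inputs' runtime shapes and pass them to tile as a tensor rather than as a Python list of integers; whether tile accepts that in a given coremltools version is an assumption that would need checking.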

@CanyonWind
Author

Could anyone please share some hints on fixing the above 👆 meshgrid with dynamic input sizes? Would really appreciate it.

@iamgeo92

@aseemw Your fix for groupnorm does not work on coremltools==6.2 .

I get this error now:

op_mapping.py", line 1926, in reshape
    raise ValueError(msg)
ValueError: Use 0 in shape only if len(shape) == x.rank. Report bug.

Use this to reproduce

import torch
import torch.nn as nn
import coremltools as ct

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.frontend.torch.torch_op_registry  import _TORCH_OPS_REGISTRY
from coremltools.converters.mil.frontend.torch.ops import _std
import builtins

del _TORCH_OPS_REGISTRY["group_norm"]

@register_torch_op
def group_norm(context, node):
    inputs = [context[name] for name in node.inputs]
    x = inputs[0]
    num_groups = inputs[1].val
    weight = inputs[2]
    bias = inputs[3]
    eps = inputs[4]
    n,c,h,w = x.shape[0],x.shape[1],x.shape[2],x.shape[3]
    num_groups = builtins.min(num_groups,c)

    # this is different
    x = mb.expand_dims(x=x, axes=[0])
    x = mb.reshape(x=x, shape=[n, num_groups, c // num_groups, 0, 0])
   
    mean = mb.reduce_mean(x=x, axes=[2,3,4], keep_dims=True)
    var = _std(x,[2,3,4],True,False,eps.val)
    x = mb.sub(x=x,y=mean)
    x = mb.real_div(x=x,y=var)
 
    # this is different    
    x = mb.squeeze(x=x, axes=[0])
    x = mb.reshape(x=x, shape=[n, c, 0, 0])
    
    if weight is not None:
        weight = mb.reshape(x=weight, shape=[1,c,1,1])
        x = mb.mul(x=x,y=weight)
    if bias is not None:
        bias = mb.reshape(x=bias, shape=[1,c,1,1])
        x = mb.add(x=x,y=bias)
    context.add(x,node.name)




class DynamicGN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3,
                         stride=1, padding=1, bias=False)
        self.gn = nn.GroupNorm(num_groups=4, num_channels=16)
    def forward(self, x):
        y = self.gn(self.conv(x))
        return y

def main():
    model = DynamicGN()
    input = torch.ones((1,3,16,16))
    output = model(input)
    traced_model = torch.jit.trace(model, input, check_trace=True)
    img_shape = ct.Shape(shape=(1, 3, ct.RangeDim(8, 128), ct.RangeDim(8, 128)))
    img = ct.TensorType(name='image', shape=img_shape)
    mlmodel = ct.convert(model=traced_model, inputs=[img])

if __name__ == '__main__':
    main()

@CanyonWind can you pls check

cc - @bhushan23 ,

@Volutionn

I'm facing the exact same error with Meshgrid. Any update regarding that?

TypeError: cannot determine truth value of Relational

@sercand
Contributor

sercand commented Jul 21, 2023

To fix GroupNorm, I used the following code on coremltools 7.0b1, editing

def group_norm(context, node):

from coremltools.converters.mil.frontend.torch.torch_op_registry import _TORCH_OPS_REGISTRY, register_torch_op
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs, _std
from coremltools.converters.mil.mil.types.symbolic import is_symbolic,any_symbolic
import builtins

del _TORCH_OPS_REGISTRY["group_norm"]

@register_torch_op
def group_norm(context, node):
    inputs = _get_inputs(context, node, expected=6)
    x = inputs[0]
    num_groups = inputs[1].val
    weight = inputs[2]
    bias = inputs[3]
    eps = inputs[4]
    n,c = x.shape[0],x.shape[1] # at minimum (N, C) required
    input_shape = [*x.shape] # n, c, *
    num_groups = builtins.min(num_groups,c)
    new_shape = [n, num_groups, c//num_groups]
    
    # Create the new_shape and input_shape to support dynamic sizes.
    if not any_symbolic(x.shape[2:]):
        new_shape += [*x.shape[2:]] # adds remaining dims
    else:
        ss = mb.shape(x=x)
        for i,v in enumerate(x.shape[2:]):
            if is_symbolic(v):
                x1 = mb.gather(x=ss, indices=i+2, axis=0)
                new_shape.append(x1)
            else:
                new_shape.append(v)
        new_shape = mb.concat(values=new_shape, axis=0)

    if any_symbolic(input_shape):
        ss = mb.shape(x=x)
        for i,v in enumerate(input_shape):
            if is_symbolic(v):
                x1 = mb.gather(x=ss, indices=i, axis=0)
                input_shape[i] = x1
        input_shape = mb.concat(values=input_shape, axis=0)

    num_extra_axes = len(x.shape[2:])
    axes_ = [int(i) for i in range(2, 2 + num_extra_axes + 1)]
    weight_shape, bias_shape = [1,c], [1,c]
    weight_shape += [1 for _ in range(num_extra_axes)]
    bias_shape += [1 for _ in range(num_extra_axes)]
    x = mb.reshape(x = x, shape=new_shape)
    mean = mb.reduce_mean(x=x, axes=axes_, keep_dims=True)
    var = _std(x,axes_,True,False,eps.val)
    x = mb.sub(x=x,y=mean)
    x = mb.real_div(x=x,y=var)
    x = mb.reshape(x=x, shape=input_shape)
    if weight is not None:
        weight = mb.reshape(x=weight, shape=weight_shape)
        x = mb.mul(x=x,y=weight)
    if bias is not None:
        bias = mb.reshape(x=bias, shape=bias_shape)
        x = mb.add(x=x,y=bias)
    context.add(x,node.name)

I came up with this idea by looking at how torch.reshape is converted with dynamic shapes elsewhere.
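
The shape-building part of that fix can be illustrated outside of MIL. The sketch below is plain Python under simplifying assumptions: `grouped_shape` is a hypothetical helper, `None` marks a dimension that is symbolic at conversion time, and `runtime_shape` plays the role of the shape read from the tensor at run time. Static dimensions are kept as constants and only the symbolic ones are looked up.

```python
def grouped_shape(static_shape, runtime_shape, num_groups):
    """Build [N, G, C // G, *spatial], reading unknown dims from runtime_shape."""
    n, c = static_shape[0], static_shape[1]
    new_shape = [n, num_groups, c // num_groups]
    for i, dim in enumerate(static_shape[2:], start=2):
        # Known dims stay constant; unknown ones are fetched at run time.
        new_shape.append(runtime_shape[i] if dim is None else dim)
    return new_shape

shape_5d = grouped_shape([1, 16, None, None], [1, 16, 24, 40], num_groups=4)
# The statistics are reduced over everything after the group axis.
reduce_axes = list(range(2, len(shape_5d)))
```

Because the spatial dimensions are looped over rather than unpacked as h and w, the same logic covers inputs with one, two, or three trailing dimensions.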

@TobyRoseman
Collaborator

@sercand - can you put up a pull request with your fix and a unit test?

@sercand
Contributor

sercand commented Jul 26, 2023

@sercand - can you put up a pull request with your fix and a unit test?

I will be working on it.
edit: opened #1922 for the issue
