
Avoid fp32 cast for Torch div operator #2241

Closed
HennerM wants to merge 3 commits

Conversation

@HennerM commented Jun 17, 2024

The `div` Torch op was always casting both operands to fp32, even if both operands are of type fp16. This cast should get removed by the `"common::add_fp16_cast"` optimization pass. However, it causes issues during the PyTorch conversion. For example, let's say we have a forward method like this:

```python
import torch


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 1)

    def forward(self, x, y):  # both fp16 tensors, shape [1, 16]
        r = x / y  # r is now fp32
        return self.proj(r)  # Problem
```

Now if we have moved the model (and its parameters) to fp16 with e.g. `m = Foo().to(torch.float16)`, we get an error at conversion time:

> In op, of type linear, named linear_0, the named input `bias` must have the same data type as the named input `x`. However, bias has dtype fp16 whereas x has dtype fp32.

This is because the result of the `div` operation stays fp32, which doesn't match the resulting type of the PyTorch expression.
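
For reference, a conversion call along the lines below surfaces the error. It is an illustrative sketch, not part of the original report: it reuses `Foo` and the `import torch` from the snippet above, and the converter options shown are assumptions.

```python
import coremltools as ct
import numpy as np

m = Foo().to(torch.float16).eval()
x = torch.randn(1, 16, dtype=torch.float16)
y = torch.randn(1, 16, dtype=torch.float16)

# Tracing succeeds; the conversion step raises the dtype-mismatch error quoted above.
mlmodel = ct.convert(
    torch.jit.trace(m, (x, y)),
    inputs=[
        ct.TensorType(name="x", shape=x.shape, dtype=np.float16),
        ct.TensorType(name="y", shape=y.shape, dtype=np.float16),
    ],
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT16,
    minimum_deployment_target=ct.target.iOS17,
)
```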

HennerM added 3 commits June 17, 2024 14:52
@TobyRoseman (Collaborator)

Please add a unit test to `test_torch_ops.py` which fails without your fix but passes with your fix.
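
A minimal sketch of what such a test could look like, written against `ct.convert` directly rather than the conventions in `test_torch_ops.py`; the structure and names here are assumptions, not the actual test added in this PR:

```python
import numpy as np
import torch

import coremltools as ct


def test_div_keeps_fp16_dtype():
    # An fp16 model whose forward uses the div op followed by a linear layer
    # should convert without a dtype-mismatch error.
    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(16, 1)

        def forward(self, x, y):
            return self.proj(x / y)

    x = torch.randn(1, 16, dtype=torch.float16)
    y = torch.randn(1, 16, dtype=torch.float16)
    traced = torch.jit.trace(Foo().half().eval(), (x, y))

    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="x", shape=x.shape, dtype=np.float16),
            ct.TensorType(name="y", shape=y.shape, dtype=np.float16),
        ],
        outputs=[ct.TensorType(name="output")],
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS17,
    )
    assert mlmodel is not None
```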

@jeethu commented Jul 11, 2024

I stumbled across the same issue and managed to debug it. It turns out the root cause isn't the div op. This happens because the torch converter casts inputs to fp32 here.

Here's a minimal repro that does not use the div op and still fails with the same error:

```python
import coremltools as ct
import numpy as np
import torch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 1)

    def forward(self, x):
        return self.proj(x)


x = torch.randn(1, 16, dtype=torch.float16)

with torch.no_grad():
    mlmodel = ct.convert(
        torch.jit.trace(Net().half().eval(), x),
        inputs=[ct.TensorType(name="x", shape=x.shape, dtype=np.float16)],
        outputs=[ct.TensorType(name="output")],
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS17,
    )
```

This fails with the same exception as the snippet above with the div op:

> ValueError: In op, of type linear, named linear_0, the named input `bias` must have the same data type as the named input `weight`. However, bias has dtype fp16 whereas weight has dtype fp32.

I've got a fix for this in #2274.

@YifanShenSZ (Collaborator)

Hi @HennerM, `inputs=[ct.TensorType(dtype=np.float16)]` and `compute_precision=ct.precision.FLOAT16` are enough to obtain an fp16-input, fp16-computation Core ML model. There is no need to make the PyTorch model itself fp16.

Concretely, we internally translate the torch model in fp32. Then:

  • If given `compute_precision=ct.precision.FLOAT16`, we will insert fp16 casts to make the computation (i.e. weights & activations) fp16.
  • If given `inputs=[ct.TensorType(name="x", shape=x.shape, dtype=np.float16)]`, we will change the input signature for `x` to fp16.
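
For illustration, a minimal sketch of that workflow applied to the model from the description: the PyTorch model stays in fp32, and the converter handles both the fp16 input signature and the fp16 computation. The iOS16 deployment target is an assumption, chosen because fp16 model inputs require it.

```python
import coremltools as ct
import numpy as np
import torch


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 1)

    def forward(self, x, y):
        return self.proj(x / y)


# Trace with fp32 example inputs; no .half() / .to(torch.float16) needed.
x = torch.randn(1, 16)
y = torch.randn(1, 16)

mlmodel = ct.convert(
    torch.jit.trace(Foo().eval(), (x, y)),
    inputs=[
        ct.TensorType(name="x", shape=x.shape, dtype=np.float16),
        ct.TensorType(name="y", shape=y.shape, dtype=np.float16),
    ],
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT16,
    minimum_deployment_target=ct.target.iOS16,
)
```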
