Labels: needs-triage, type: bug
Description
Behavior
Importing a HuggingFace BERT model (bert-base-multilingual-uncased) that was exported with torch.export into TVM through the Relax PyTorch frontend (from_exported_program) fails with:
KeyError: 'bert.embeddings.position_ids'
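In HF BERT, position_ids is registered as a non-persistent buffer (register_buffer(..., persistent=False)), so it is omitted from the module's state_dict. torch.export still lifts it as a buffer input, but stores its tensor in ExportedProgram.constants instead of ExportedProgram.state_dict. The Relax importer assumes every lifted buffer can be read from exported_program.state_dict[spec.target], so the lookup raises a KeyError.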
Full output:

```
torch.export: OK
Traceback (most recent call last):
  ...
  File ".../tvm/relax/frontend/torch/exported_program_translator.py", line 620, in create_input_vars
    torch_shape = exported_program.state_dict[spec.target].shape
KeyError: 'bert.embeddings.position_ids'
```
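A minimal sketch of a lookup that respects this split, assuming the fix belongs where create_input_vars reads state_dict. This is not TVM's actual patch: `lookup_lifted_tensor`, `StubSpec`, and `StubExportedProgram` are hypothetical names, and the stubs only mirror the `target`/`persistent` fields of torch.export's `InputSpec` and the `state_dict`/`constants` mappings of `ExportedProgram`, so the logic runs without torch or TVM.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class StubSpec:
    # Stand-in for torch.export's InputSpec; only the fields used below.
    target: str
    persistent: Optional[bool] = None  # None for parameters, True/False for buffers


@dataclass
class StubExportedProgram:
    # Stand-in for ExportedProgram's two tensor stores.
    state_dict: Dict[str, Any] = field(default_factory=dict)
    constants: Dict[str, Any] = field(default_factory=dict)


def lookup_lifted_tensor(ep, spec):
    """Hypothetical helper: find the tensor backing a lifted parameter or buffer input."""
    if spec.persistent is False:
        # Non-persistent buffers are kept in ep.constants, not ep.state_dict.
        return ep.constants[spec.target]
    return ep.state_dict[spec.target]


ep = StubExportedProgram(
    state_dict={"bert.embeddings.word_embeddings.weight": "W"},
    constants={"bert.embeddings.position_ids": "P"},
)
print(lookup_lifted_tensor(ep, StubSpec("bert.embeddings.position_ids", persistent=False)))  # P
```

In the importer, the failing line would then read roughly `torch_shape = lookup_lifted_tensor(exported_program, spec).shape`. Branching on `persistent` mirrors the invariant torch.export itself checks when validating a program; a plain "try state_dict, then constants" fallback would also work.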
Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: v0.21.0 (release)
- Python: 3.10.16
- LLVM: 17.0.6
- PyTorch: 2.8.0
Steps to reproduce
```python
import torch
import torch.nn as nn
from transformers import AutoModel


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
        self.cls = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, x, mask=None):
        # Classify from the [CLS] token's final hidden state.
        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
        return self.cls(out)


def main():
    torch.manual_seed(0)
    m = M().eval()
    x = torch.randint(0, 30522, (2, 16))  # token ids, within the model's vocabulary
    mask = torch.ones_like(x)

    from torch.export import export as torch_export
    ep = torch_export(m, (x, mask))
    print("torch.export: OK")

    from tvm.relax.frontend.torch import from_exported_program
    mod = from_exported_program(ep)  # <-- raises KeyError


if __name__ == "__main__":
    main()
```