Mismatch between saved weights and model description #3

Open
jeremy9959 opened this issue Feb 5, 2024 · 1 comment

jeremy9959 commented Feb 5, 2024

The generate.py script won't run because the weights on Hugging Face are incompatible with the model architecture in the repository.

Here's a greatly simplified part of the file generate.py:

from utils import *
from config import *
from transformers import GPT2Config
import requests
import torch  # torch.load is used below (utils may already re-export torch via the star import)
from tqdm import tqdm

# download the released checkpoint from Hugging Face, with a progress bar
filename = "weights.pth"
url = "https://huggingface.co/sander-wood/tunesformer/resolve/main/weights.pth"
response = requests.get(url, stream=True)
total_size = int(response.headers.get("content-length", 0))
chunk_size = 10

with open(filename, "wb") as file, tqdm(
    desc=filename,
    total=total_size,
    unit="B",
    unit_scale=True,
    unit_divisor=1024,
) as bar:
    for data in response.iter_content(chunk_size=chunk_size):
        size = file.write(data)
        bar.update(size)

patchilizer = Patchilizer()
patch_config = GPT2Config(
    num_hidden_layers=PATCH_NUM_LAYERS,
    max_length=PATCH_LENGTH,
    max_position_embeddings=PATCH_LENGTH,
    vocab_size=1,
)
char_config = GPT2Config(
    num_hidden_layers=CHAR_NUM_LAYERS,
    max_length=PATCH_SIZE,
    max_position_embeddings=PATCH_SIZE,
    vocab_size=128,
)
model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)

# torch.load succeeds; the load_state_dict call below is what raises
checkpoint = torch.load("weights.pth")
model.load_state_dict(checkpoint["model"])

The result of running this is:

weights.pth: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 0.99G/0.99G [06:52<00:00, 2.57MB/s]
Traceback (most recent call last):
  File "/home/jet08013/GitHub/tunesformer/jeremy.py", line 41, in <module>
    model.load_state_dict(checkpoint["model"])
  File "/home/jet08013/anaconda3/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TunesFormer:
        Unexpected key(s) in state_dict: "patch_level_decoder.base.h.0.attn.bias", 
"patch_level_decoder.base.h.0.attn.masked_bias", 
"patch_level_decoder.base.h.1.attn.bias", "patch_level_decoder.base.h.1.attn.masked_bias", 
"patch_level_decoder.base.h.2.attn.bias", "patch_level_decoder.base.h.2.attn.masked_bias", 
"patch_level_decoder.base.h.3.attn.bias", "patch_level_decoder.base.h.3.attn.masked_bias",
"patch_level_decoder.base.h.4.attn.bias", "patch_level_decoder.base.h.4.attn.masked_bias", 
"patch_level_decoder.base.h.5.attn.bias", "patch_level_decoder.base.h.5.attn.masked_bias",
"patch_level_decoder.base.h.6.attn.bias", "patch_level_decoder.base.h.6.attn.masked_bias", 
"patch_level_decoder.base.h.7.attn.bias", "patch_level_decoder.base.h.7.attn.masked_bias", 
"patch_level_decoder.base.h.8.attn.bias", "patch_level_decoder.base.h.8.attn.masked_bias", 
"char_level_decoder.base.transformer.h.0.attn.bias", "char_level_decoder.base.transformer.h.0.attn.masked_bias", 
"char_level_decoder.base.transformer.h.1.attn.bias", "char_level_decoder.base.transformer.h.1.attn.masked_bias", 
"char_level_decoder.base.transformer.h.2.attn.bias", "char_level_decoder.base.transformer.h.2.attn.masked_bias".

It looks like the saved weights include bias and masked_bias buffers for the attention layers that aren't present in the model description.
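
For reference, here's a quick way to see exactly which keys disagree, reusing the model and checkpoint objects from the script above (just a diagnostic sketch, not code from the repository):

saved_keys = set(checkpoint["model"].keys())
model_keys = set(model.state_dict().keys())
print("only in checkpoint:", sorted(saved_keys - model_keys))
print("only in model:", sorted(model_keys - saved_keys))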

jeremy9959 (Author) commented

Incidentally, if you filter out the extra weights from the state_dict, the program works and seems to generate perfectly nice tunes:

    import re

    checkpoint = torch.load("weights.pth")
    # drop the attn bias / masked_bias buffers that the current model doesn't declare
    fixed_weights = {
        k: v
        for k, v in checkpoint["model"].items()
        if not re.search(r"\.attn\.bias|\.attn\.masked_bias", k)
    }
    model.load_state_dict(fixed_weights)
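
Alternatively, load_state_dict takes a strict flag; with strict=False PyTorch skips the unexpected keys and returns them instead of raising, so something like this should also work (a shorter but untested sketch):

    checkpoint = torch.load("weights.pth")
    # strict=False ignores keys the model doesn't declare and reports them back
    result = model.load_state_dict(checkpoint["model"], strict=False)
    print(result.unexpected_keys)  # should list just the attn.bias / attn.masked_bias buffers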
