Calling model.save_adapter() after training does not correctly save adapter parameters. #489

Open
jdwx opened this issue Feb 7, 2023 · 1 comment
Labels: bug (Something isn't working), external-dependency (Related to a third-party package or service)

jdwx commented Feb 7, 2023

Environment info

  • adapter-transformers version: 3.2.0a0 (2af89bd)
  • Platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.35
  • Python version: 3.10.6
  • Huggingface_hub version: 0.12.0
  • PyTorch version (GPU?): 1.13.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: DeepSpeed

Information

Model I am using (Bert, XLNet ...): GPT-J-6B

Language I am using the model on (English, Chinese ...): English

Adapter setup I am using (if any): Default from add_adapter()

The problem arises when using:

  • [ ] the official example scripts: (give details below)
  • [x] my own modified scripts: (give details below)

A simple script to try out training an adapter for a large model using DeepSpeed.

The task I am working on is:

  • [ ] an official GLUE/SQUaD task: (give the name)
  • [x] my own task or dataset: (give details below)

I would like to add Adapters support to my GPT-J fine-tuning example code. For people working with consumer hardware, I suspect that creating adapters fits their resources and use cases better than fine-tuning the entire model.

To reproduce

Steps to reproduce the behavior:

  1. Create a GPTJAdapterModel.
  2. Add an adapter and a causal lm head, and activate the adapter.
  3. Train the adapter for one epoch using DeepSpeed.
  4. Save the adapter.

I will include relevant examples below, as even the minimal code is over 100 lines.

Expected behavior

Saving an adapter after training should write out the trained adapter's parameters in a form suitable for loading later.

Instead, it seems that although the adapter clearly works during training, saving it afterwards does not produce a usable checkpoint.

If I comment out the trainer.train() line in the example program shown below, the saved adapter looks like this:

-rw-rw-r-- 1 jdw jdw      1031 Feb  7 13:05 adapter_config.json
-rw-rw-r-- 1 jdw jdw       385 Feb  7 13:05 head_config.json
-rw-rw-r-- 1 jdw jdw 117723377 Feb  7 13:05 pytorch_adapter.bin
-rw-rw-r-- 1 jdw jdw 412877644 Feb  7 13:05 pytorch_model_head.bin

This is correct/reasonable, and the adapter can be loaded and used for inference, so everything seems to work, though it produces gibberish because the weights are still untrained and random:

When it comes to Transformers, Adapters are an exciting [ Packagehex socio shortestrating ]
When it comes to Transformers, Adapters are an exciting [ glorious equations equations lawmaker 900 ]
When it comes to Transformers, Adapters are an exciting [ corruption Vice withdrawischerTW ]

If I uncomment the trainer.train() line, the adapter appears to train well, with consistently decreasing loss for many epochs.

But after training, the saved adapter is much too small:

-rw-rw-r-- 1 jdw jdw   1031 Feb  7 13:22 adapter_config.json
-rw-rw-r-- 1 jdw jdw    385 Feb  7 13:22 head_config.json
-rw-rw-r-- 1 jdw jdw 280305 Feb  7 13:22 pytorch_adapter.bin
-rw-rw-r-- 1 jdw jdw    780 Feb  7 13:22 pytorch_model_head.bin

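For what it's worth, the saved checkpoint can be inspected directly to see what was actually written. A minimal sketch using plain torch (the path is the output directory used above):

# Minimal sketch: print the tensor shapes stored in the saved adapter file.
import torch

state_dict = torch.load("out/pytorch_adapter.bin", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))

Judging from the size-mismatch errors below, the adapter_down / adapter_up weights in this file end up with shape torch.Size([0]) after training.
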
Predictably, loading this adapter with the same code that worked above produces errors:

Traceback (most recent call last):
  File "/home/jdw/gptj-finetune-adapters/use-adapter.py", line 39, in <module>
    use_adapter()
  File "/home/jdw/gptj-finetune-adapters/use-adapter.py", line 14, in use_adapter
    adapter_name = model.load_adapter("out/")
  File "/home/jdw/gptj-finetune-adapters/venv/lib/python3.10/site-packages/transformers/adapters/model_mixin.py", line 1122, in load_adapter
    return super().load_adapter(
  File "/home/jdw/gptj-finetune-adapters/venv/lib/python3.10/site-packages/transformers/adapters/model_mixin.py", line 701, in load_adapter
    load_dir, load_name = loader.load(
  File "/home/jdw/gptj-finetune-adapters/venv/lib/python3.10/site-packages/transformers/adapters/loading.py", line 459, in load
    missing_keys, _ = self.weights_helper.load_weights(
  File "/home/jdw/gptj-finetune-adapters/venv/lib/python3.10/site-packages/transformers/adapters/loading.py", line 158, in load_weights
    missing_keys, unexpected_keys = self._load_module_state_dict(
  File "/home/jdw/gptj-finetune-adapters/venv/lib/python3.10/site-packages/transformers/adapters/loading.py", line 117, in _load_module_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for GPTJAdapterModel:
	size mismatch for transformer.h.0.output_adapters.adapters.example.adapter_down.0.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([256, 4096]).
	size mismatch for transformer.h.0.output_adapters.adapters.example.adapter_up.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([4096, 256]).
	size mismatch for transformer.h.1.output_adapters.adapters.example.adapter_down.0.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([256, 4096]).
	size mismatch for transformer.h.1.output_adapters.adapters.example.adapter_up.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([4096, 256]).

(many more size mismatch lines omitted)

I would be sure this is a bug on my end, except that the same code for creating, saving, and loading the adapter works as expected if the training is commented out. It still may well be my error, but in case it isn't, I figured I should report it.
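
In case it helps with triage: the torch.Size([0]) shapes look like what DeepSpeed ZeRO stage 3 leaves behind when parameters are still partitioned across ranks at save time, but that is only a guess on my part. If partitioning is indeed the cause, gathering the adapter weights before calling save_adapter() might work around it. A rough, untested sketch (the name filters are based on the parameter names in the traceback above):

# Untested sketch: materialize ZeRO-3 partitioned adapter weights before saving.
# Assumes DeepSpeed stage 3 partitioning is why the saved tensors are empty.
import deepspeed
import torch.distributed as dist

trainer.train()

# Collect just the adapter and head parameters; gathering the full 6B model
# on every rank would be much heavier.
adapter_params = [
    p for n, p in model.named_parameters()
    if ".adapters.example." in n or n.startswith("heads.")
]

# Inside this block the partitioned parameters are gathered back to full size.
with deepspeed.zero.GatheredParameters(adapter_params, modifier_rank=None):
    if not dist.is_initialized() or dist.get_rank() == 0:
        model.save_adapter("out/", "example")

I have not verified this workaround; I mention it mainly in case it points at where the interaction with DeepSpeed goes wrong.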

Sample Code

The training script (run as deepspeed make-adapter.py):

import pickle
from transformers.adapters import GPTJAdapterModel
from transformers import AdapterTrainer, GPTJConfig, TrainingArguments


deepspeed_args = {
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 12,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": False
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": False
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "gather_16bit_weights_on_model_save": True
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": False
}

training_args = TrainingArguments(
    deepspeed=deepspeed_args,
    do_train=True,
    do_eval=True,
    fp16=True,
    evaluation_strategy="epoch",
    gradient_checkpointing=True,
    output_dir="out",
    overwrite_output_dir=True,
    eval_steps=1,
    warmup_steps=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=32,
    num_train_epochs=1,
    learning_rate=1e-04,
    save_strategy="no",
)


def get_data_from_pickle(file_path: str):
    with open(file_path, "rb") as f:
        try:
            while True:
                yield pickle.load(f)
        except EOFError:
            pass


def get_data_from_pickle_with_labels(file_path: str):
    for chunk in get_data_from_pickle(file_path):
        chunk["labels"] = chunk["input_ids"]
        yield chunk


def get_data_sets(train_path: str, test_path: str):
    train_set = list(get_data_from_pickle_with_labels(train_path))
    test_set = list(get_data_from_pickle_with_labels(test_path))
    return train_set, test_set


def make_adapter() -> None:

    # Get the datasets.
    train_set, test_set = get_data_sets("train_chunks.pkl", "test_chunks.pkl")

    # Set up the model.
    config = GPTJConfig.from_pretrained(
        "EleutherAI/gpt-j-6B", cache_dir=None, use_cache=False
    )
    model = GPTJAdapterModel.from_pretrained(
        "EleutherAI/gpt-j-6B", config=config, cache_dir=None
    )
    model.add_adapter("example")
    model.add_causal_lm_head("example")
    model.train_adapter("example")
    model = model.half()

    # Set up the trainer.
    trainer = AdapterTrainer(
        model,
        train_dataset=train_set,
        eval_dataset=test_set,
        args=training_args
    )

    # Do the training. (Comment out to exhibit full adapter save.)
    trainer.train()

    # Save the model.
    model.save_adapter("out/", "example")


if __name__ == "__main__":
    make_adapter()

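For completeness: the pickle files are just a stream of pickled dicts with at least an input_ids key, which is all the script above relies on. Something along these lines produces compatible files; it is illustrative only, not the exact preprocessing I used:

# Illustrative only: write train_chunks.pkl / test_chunks.pkl as a stream of
# pickled dicts with "input_ids" (the attention_mask key is an assumption).
import pickle
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")


def write_chunks(texts, file_path: str, chunk_len: int = 256) -> None:
    with open(file_path, "wb") as f:
        for text in texts:
            ids = tokenizer(text)["input_ids"]
            # Split the token ids into fixed-length chunks and pickle each one.
            for i in range(0, len(ids) - chunk_len + 1, chunk_len):
                chunk = ids[i:i + chunk_len]
                pickle.dump(
                    {"input_ids": chunk, "attention_mask": [1] * chunk_len}, f
                )


write_chunks(["some long training text ..."], "train_chunks.pkl")
write_chunks(["some long evaluation text ..."], "test_chunks.pkl")
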
The script to load the resulting adapter (run as python use-adapter.py):

from transformers import AutoTokenizer, GPTJConfig, pipeline
from transformers.adapters import GPTJAdapterModel


def use_adapter() -> None:
    config = GPTJConfig.from_pretrained(
        "EleutherAI/gpt-j-6B", cache_dir=None, use_cache=False
    )
    model = GPTJAdapterModel.from_pretrained(
        "EleutherAI/gpt-j-6B",
        config=config,
        cache_dir=None
    )
    adapter_name = model.load_adapter("out/")
    model.add_causal_lm_head("finetune")
    model.set_active_adapters([adapter_name])
    model = model.half()
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    prompt = "When it comes to Transformers, Adapters are an exciting"
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0
    )
    out = pipe(
        prompt,
        num_return_sequences=4,
        max_new_tokens=5,
        pad_token_id=tokenizer.eos_token_id
    )
    for result in out:
        gen = result['generated_text'][len(prompt):].strip()
        print(prompt, "[", gen, "]")


if __name__ == "__main__":
    use_adapter()

jdwx added the bug label on Feb 7, 2023

adapter-hub-bert (Member) commented:

This issue has been automatically marked as stale because it has been without activity for 90 days. This issue will be closed in 14 days unless you comment or remove the stale label.

lenglaender added the do-not-stale label and removed the Stale label on May 9, 2023
calpt added the external-dependency label and removed the do-not-stale label on Apr 27, 2024