Conversation

@KuriUni commented Aug 14, 2025

layer.output_adapters = BottleneckLayer("output_adapter", is_layer_hooked=True)
ln_2_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_2, None)
layer_output_proj.register_forward_hook(partial(hook_fn, layer.output_adapters, ln_2_get_fn))

This code causes layer.output_adapters on cuda:n to always point to the layer.output_adapters on cuda:0 during multi-GPU training with the Hugging Face Trainer's default distributed settings, even though the model itself is properly distributed across the GPUs. I suspect it is due to partial, so I tried saving variables such as layer.xxx and layer in the context so that it can run on multiple GPUs.

During debugging, variables like residual and the hidden state are both shown on cuda:1, but layer is shown on cuda:0. I printed the addresses of the layer variable on the two GPUs, and the address of layer on cuda:1 is the same as on cuda:0. Since my GPU can't handle models like Qwen, and it's not easy to provide data for my own model, could you please test whether this problem occurs in multi-GPU training? Thank you! I followed the adapters-for-any-transformer process.
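To illustrate what I suspect, here is a minimal CPU-only sketch (TinyLayer, hook_fn and the plain Linear adapters are hypothetical stand-ins, not the actual adapters code). A forward hook created with functools.partial keeps whatever adapter object was bound at registration time, so a replica produced by shallow-copying the module (roughly what DataParallel's replicate() does per device) still calls the original adapter instead of its own copy:

import copy
from functools import partial

import torch
import torch.nn as nn


class TinyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.output_adapter = nn.Linear(4, 4)  # stand-in for the bottleneck adapter


def hook_fn(adapter, module, inputs, output):
    # Whatever object was bound into the partial is what the hook uses.
    print("hook uses adapter id:", id(adapter))
    return adapter(output)


layer = TinyLayer()
# The problematic pattern: bind the adapter instance at registration time.
layer.proj.register_forward_hook(partial(hook_fn, layer.output_adapter))

# Roughly mimic per-device replication: shallow-copy the module, give it its
# own child table, but keep the already-registered hook unchanged.
replica = copy.copy(layer)
replica._modules = layer._modules.copy()
replica.output_adapter = nn.Linear(4, 4)  # the replica's "own" adapter

print("replica adapter id:", id(replica.output_adapter))
replica.proj(torch.randn(1, 4))  # the hook still reports the original adapter's id

If I read Module._replicate_for_data_parallel correctly, each DataParallel replica gets its own submodule copies but shares the original submodules' _forward_hooks dict, so every replica's hook still holds the cuda:0 adapter that was baked into the partial.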

@KuriUni (Author) commented Aug 14, 2025

I accidentally clicked on "Abandon Submission"😭😭😭😭

@lenglaender self-requested a review August 23, 2025 09:10
@lenglaender (Member) commented Aug 23, 2025

Hi @KuriUni, thanks for the PR!

Could you provide a script to reproduce your issue? I was unable to reproduce the multi-GPU problem you described. I ran the script below to fine-tune a Qwen model on 2 GPUs with both DataParallel and DistributedDataParallel, using a bottleneck adapter, and it completed without any errors.

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
import numpy as np
import evaluate
import os
from adapters import BnConfig
import adapters


os.environ["WANDB_DISABLED"] = "true"

def main():
    model_name = "Qwen/Qwen3-0.6B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
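    # Add adapter support to the plain Hugging Face model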
    adapters.init(model)
    config = BnConfig(mh_adapter=True, output_adapter=True, reduction_factor=16, non_linearity="relu")
    model.add_adapter("test_adapter_name", config=config)
    model.set_active_adapters("test_adapter_name")
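    # Freeze the base model; only the adapter weights remain trainable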
    model.train_adapter("test_adapter_name")
    
    print(model.adapter_summary())

    dataset = load_dataset("timdettmers/openassistant-guanaco")
    def tokenize(element):
        return tokenizer(
            element["text"],
            truncation=True,
            max_length=512, # can set to longer values such as 2048
            add_special_tokens=False,
            padding="max_length",
        )

    dataset_tokenized = dataset.map(
        tokenize, 
        batched=True, 
        num_proc=os.cpu_count(),    # multithreaded
        remove_columns=["text"]     # don't need this anymore, we have tokens from here on
    )
    args = TrainingArguments(
        output_dir="outputs.log",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        logging_steps=10,
        save_steps=500,
        eval_steps=100,
        save_total_limit=3,
        gradient_accumulation_steps=16,
        max_steps=1875,
        lr_scheduler_type="constant",
        learning_rate=0.0002,
        group_by_length=True,
        warmup_ratio=0.03,
        max_grad_norm=0.3,
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=args,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        train_dataset=dataset_tokenized["train"],
        eval_dataset=dataset_tokenized["test"],
    )
    
    print(f"🖥️ Trainer is configured to use {trainer.args.n_gpu} GPUs.")

    trainer.train()

if __name__ == "__main__":
    main()

For DataParallel, I started the script like this: python3 script.py
and for Distributed Data Parallel like this: python3 -m torch.distributed.launch --nproc_per_node=2 script.py

Both times (DP & DDP), it ran through without issues. To help us investigate, could you please provide a minimal script that reproduces the problem?

@lenglaender self-assigned this Sep 10, 2025
@lenglaender (Member) commented Sep 10, 2025

Hi @KuriUni,
Did you get a chance to follow up on this? It would be immensely helpful for the community if we could clarify whether multi-GPU training is indeed a problem or not 🙌
