- 
                Notifications
    
You must be signed in to change notification settings  - Fork 2.1k
 
Open
Description
System Info
- `Accelerate` version: 1.10.1
- Platform: Linux-5.15.0-157-generic-x86_64-with-glibc2.39
- `accelerate` bash location: /opt/uv/venv/bin/accelerate
- Python version: 3.12.11
- Numpy version: 2.2.6
- PyTorch version: 2.9.0.dev20250825+cu128
- PyTorch accelerator: CUDA
- System RAM: 1003.13 GB
- GPU type: NVIDIA H100 80GB HBM3
- `Accelerate` default config: Not found
- peft version: 0.15.2
- transformers version: 4.56.1
 
Who can help?
Reproduction
"""Based on peft/examples/sft/run_peft_qlora_fsdp.sh
Launch configuration (a VS Code debugpy `launch.json` entry that runs `accelerate launch`):
{
    "name": "Accelerate Launch - Minimal FSDP QLoRA Training",
    "type": "debugpy",
    "request": "launch",
    "module": "accelerate.commands.launch",
    "args": [
        "--config_file",
        "scripts/fsdp_config_qlora.yaml",
        "--num_processes",
        "2",
        "scripts/20251008_fsdp_qlora_sft_custom.py",
        "--seed",
        "100",
        "--model_name_or_path",
        "meta-llama/Llama-3.1-8B-Instruct",
        "--dataset_name",
        "smangrul/ultrachat-10k-chatml",
        "--add_special_tokens",
        "False",
        "--append_concat_token",
        "False",
        "--splits",
        "train,test",
        "--max_seq_len",
        "2048",
        "--num_train_epochs",
        "1",
        "--logging_steps",
        "5",
        "--log_level",
        "info",
        "--logging_strategy",
        "steps",
        "--learning_rate",
        "1e-4",
        "--lr_scheduler_type",
        "cosine",
        "--weight_decay",
        "1e-4",
        "--warmup_ratio",
        "0.0",
        "--max_grad_norm",
        "1.0",
        "--output_dir",
        "llama-sft-qlora-fsdp",
        "--per_device_train_batch_size",
        "2",
        "--per_device_eval_batch_size",
        "2",
        "--gradient_accumulation_steps",
        "2",
        "--gradient_checkpointing",
        "True",
        "--lora_r",
        "8",
        "--lora_alpha",
        "16",
        "--lora_dropout",
        "0.1",
        "--lora_target_modules",
        "all-linear",
        "--max_steps",
        "2",
    ],
    "console": "integratedTerminal",
    "justMyCode": false,
    "cwd": "${workspaceFolder}"
}
"""
import os
import sys
from dataclasses import dataclass, field
import torch
from accelerate import Accelerator
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from peft import LoraConfig, PeftConfig, PeftModel
from peft.utils.other import fsdp_auto_wrap_policy
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    PreTrainedModel,
    PreTrainedTokenizer,
    TrainingArguments,
    get_scheduler,
    set_seed,
)
from transformers.data.data_collator import DataCollatorWithPadding
class MinimalSFTTrainer:
    """Minimal FSDP + QLoRA setup used to reproduce a hang in ``Accelerator.prepare``.

    The constructor wraps the (quantized) base model in a ``PeftModel`` that holds two
    copies of the same Hub LoRA adapter, named ``"reference"`` and ``"policy"``. It then:

    - configures the accelerate FSDP plugin (if one is active) for PEFT/QLoRA;
    - builds a tokenized dataloader, an AdamW optimizer and an LR scheduler;
    - passes all of them to ``Accelerator.prepare``.

    No training loop is defined here.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        peft_config: PeftConfig,
        train_dataset,
        args: TrainingArguments,
    ):
        """Set up the PEFT model, data, optimizer and scheduler, then prepare them.

        Args:
            model: Base causal LM. It is wrapped by ``PeftModel.from_pretrained``. Its
                ``hf_quantizer`` is read to pick the FSDP mixed-precision dtype.
            tokenizer: Used for the chat template, tokenization and padding.
            peft_config: Not used; the adapters are loaded from the Hub instead.
            train_dataset: Dataset with a ``"messages"`` column, rendered through the
                tokenizer's chat template via ``.map``.
            args: Training hyper-parameters (batch size, LR, schedule, steps, ...).
        """
        self.args = args
        self.train_dataset = train_dataset
        self.tokenizer = tokenizer
        # The FSDP plugin (if any) comes from the ``accelerate launch`` config file.
        self.accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision="bf16",
        )
        if args.gradient_checkpointing:
            model.gradient_checkpointing_enable()
            model.enable_input_require_grads()

        # Load the same Hub adapter twice, under two different adapter names.
        adapter_repo = "AlignmentResearch/Llama-3.1-8B-Instruct-gsm8k-lora-reference"
        self.model = PeftModel.from_pretrained(
            model,
            adapter_repo,
            autocast_adapter_dtype=False,
            adapter_name="reference",
        )
        self.model.load_adapter(adapter_repo, adapter_name="policy", autocast_adapter_dtype=False)

        # Critical: Update FSDP plugin for QLORA
        if self.accelerator.state.fsdp_plugin is not None:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            # Set auto wrap policy for PEFT
            fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
            quant_storage = model.hf_quantizer.quantization_config.bnb_4bit_quant_storage
            if quant_storage.is_floating_point:
                fsdp_plugin.set_mixed_precision(quant_storage, override=True)

        # Create dataloader. Both lambdas close over the local ``tokenizer`` rather than
        # ``self``, so ``datasets`` does not have to hash (serialize) the whole trainer,
        # including the model, when fingerprinting the map functions.
        formatted_ds = self.train_dataset.map(
            lambda x: {"content": tokenizer.apply_chat_template(x["messages"], tokenize=False)},
            batched=False,
            remove_columns=self.train_dataset.column_names,
        )
        tokenized_ds = formatted_ds.map(
            lambda x: tokenizer(x["content"], truncation=True), batched=True, remove_columns=formatted_ds.column_names
        )
        self.train_dataloader = DataLoader(
            tokenized_ds,
            batch_size=args.per_device_train_batch_size,
            collate_fn=DataCollatorWithPadding(tokenizer),
            shuffle=True,
        )

        # Activate both adapters BEFORE collecting trainable parameters.
        # NOTE(review): adapters loaded without ``is_trainable=True`` are presumably frozen
        # until ``set_adapter`` re-enables gradients for the active ones. The dump below
        # shows ``requires_grad=True`` after this call. Confirm against the installed peft.
        self.model.base_model.set_adapter(["reference", "policy"])
        self.accelerator.print(f"Active adapters: {self.model.active_adapters}")
        # Debug dump from every rank (plain print, not accelerator.print).
        for name, param in self.model.named_parameters():
            if "layers.0.self_attn.q_proj" in name:
                print(f"{name} {param.shape} {param.device} {param.dtype} {param.requires_grad}")

        # Create optimizer over the PEFT model's trainable parameters, after set_adapter.
        optimizer_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            optimizer_params,
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

        # Calculate training steps. Floor at 1 so a dataloader shorter than one
        # accumulation cycle still gives a positive step count.
        num_update_steps_per_epoch = max(len(self.train_dataloader) // args.gradient_accumulation_steps, 1)
        max_steps = args.max_steps if args.max_steps > 0 else int(args.num_train_epochs * num_update_steps_per_epoch)
        self.lr_scheduler = get_scheduler(
            args.lr_scheduler_type,
            optimizer=self.optimizer,
            # Uses --warmup_steps if > 0, otherwise derives warmup from --warmup_ratio.
            num_warmup_steps=args.get_warmup_steps(max_steps),
            num_training_steps=max_steps,
        )

        # N.B. the below will hang unless peft.tuners.tuner_utils.py::BaseTunerLayer._move_adapter_to_device_of_base_layer is
        # overridden to remove the special meta device handling
        self.accelerator.print("Preparing everything with accelerator")
        self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.train_dataloader, self.lr_scheduler
        )
        self.accelerator.print("Everything prepared with accelerator")
        self.global_step = 0
        self.max_steps = max_steps
def create_dataset(tokenizer, data_args, max_samples: int | None = 8):
    """Load the configured dataset splits and return the (truncated) training split.

    Each comma-separated entry of ``data_args.splits`` is first loaded with
    ``load_dataset`` (Hub). If that raises ``DatasetGenerationError``, it is loaded from
    ``<dataset_name>/<split>`` on disk with ``load_from_disk`` instead. Each split is cut
    to at most ``max_samples`` rows to keep the reproduction fast.

    Args:
        tokenizer: Unused; kept so existing callers keep working.
        data_args: Object with ``dataset_name`` and ``splits`` attributes.
        max_samples: Maximum rows kept per split (``None`` keeps all rows). Splits with
            fewer rows are kept whole instead of raising.

    Returns:
        The training split as a ``datasets.Dataset``.

    Raises:
        TypeError: If a split does not load as a ``datasets.Dataset``.
        ValueError: If a split name contains neither "train" nor "test", or if no
            training split was requested.
    """
    raw_datasets = DatasetDict()
    for split_name in (s.strip() for s in data_args.splits.split(",")):
        try:
            # Try first if dataset on a Hub repo
            dataset = load_dataset(data_args.dataset_name, split=split_name)
        except DatasetGenerationError:
            # If not, check local dataset
            dataset = load_from_disk(os.path.join(data_args.dataset_name, split_name))
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(dataset, Dataset):
            raise TypeError(
                f"Split {split_name!r} loaded as {type(dataset).__name__}, expected a datasets.Dataset."
            )
        if max_samples is not None:
            # Clamp so a split shorter than max_samples does not raise IndexError.
            dataset = dataset.select(range(min(max_samples, len(dataset))))
        if "train" in split_name:
            raw_datasets["train"] = dataset
        elif "test" in split_name:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(f"Split type {split_name} not recognized as one of test or train.")
    if "train" not in raw_datasets:
        raise ValueError(f"No training split found in splits={data_args.splits!r}.")
    train_data = raw_datasets["train"]
    print(f"Size of the train set: {len(train_data)}")
    print(f"A sample of train dataset: {train_data[0]}")
    return train_data
def create_and_prepare_model(args):
    """Load the 4-bit (NF4, double-quantized) causal LM, a LoRA config and the tokenizer.

    Args:
        args: Object providing ``model_name_or_path``, ``lora_alpha``, ``lora_dropout``,
            ``lora_r`` and ``lora_target_modules`` (comma-separated, or "all-linear").

    Returns:
        A ``(model, peft_config, tokenizer)`` tuple. The tokenizer's pad token is set to
        its EOS token.
    """
    storage_dtype = torch.bfloat16
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=storage_dtype,
    )
    # Load the non-quantized weights in the same dtype as the 4-bit storage when that
    # storage dtype is floating point; otherwise fall back to float32.
    if storage_dtype and storage_dtype.is_floating_point:
        load_dtype = storage_dtype
    else:
        load_dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=load_dtype,
        attn_implementation="flash_attention_2",
        quantization_config=quantization,
        use_cache=False,
    )

    # "all-linear" is passed through as-is; anything else is a comma-separated list.
    if args.lora_target_modules == "all-linear":
        target_modules = args.lora_target_modules
    else:
        target_modules = args.lora_target_modules.split(",")
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    return model, peft_config, tokenizer
# Define and parse arguments.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    max_seq_length: int | None = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."},
    )
    lora_alpha: int | None = field(default=16)
    lora_dropout: float | None = field(default=0.1)
    lora_r: int | None = field(default=64)
    lora_target_modules: str | None = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
    )
@dataclass
class DataTrainingArguments:
    dataset_name: str | None = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    append_concat_token: bool | None = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )
    add_special_tokens: bool | None = field(
        default=False,
        metadata={"help": "If True, tokenizers adds special tokens to each sample being packed."},
    )
    splits: str | None = field(
        default="train,test",
        metadata={"help": "Comma separate list of the splits to use from the dataset."},
    )
def main(model_args, data_args, training_args):
    """Build the model, tokenizer and data, construct the trainer and print the prepared model.

    Args:
        model_args: Model/LoRA options (``ModelArguments``).
        data_args: Dataset options (``DataTrainingArguments``).
        training_args: ``TrainingArguments``. Its ``seed`` is applied, and a
            ``dataset_kwargs`` dict is attached to it.
    """
    set_seed(training_args.seed)

    model, peft_config, tokenizer = create_and_prepare_model(model_args)
    # Kept from the PEFT SFT example this script is based on; MinimalSFTTrainer
    # does not read ``dataset_kwargs``.
    training_args.dataset_kwargs = dict(
        append_concat_token=data_args.append_concat_token,
        add_special_tokens=data_args.add_special_tokens,
    )

    trainer = MinimalSFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=create_dataset(tokenizer, data_args),
        peft_config=peft_config,
    )

    trainer.accelerator.print(f"{trainer.model}")
    prepared = trainer.model
    if hasattr(prepared, "print_trainable_parameters"):
        prepared.print_trainable_parameters()
if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    cli_args = sys.argv[1:]
    # A lone ``*.json`` argument is treated as a config file holding all three
    # argument groups; otherwise parse the regular command-line flags.
    if len(cli_args) == 1 and cli_args[0].endswith(".json"):
        parsed = parser.parse_json_file(json_file=os.path.abspath(cli_args[0]))
    else:
        parsed = parser.parse_args_into_dataclasses()
    model_args, data_args, training_args = parsed
    main(model_args, data_args, training_args)
Expected behavior
`accelerator.prepare` should not hang. I would also expect all tensors on device 1 (rank 1) to be on the meta device. Instead, the second adapter (`policy`) is redundantly materialized on CPU in float32, as the parameter dump below shows:
base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight torch.Size([4194304, 1]) meta torch.bfloat16 False
base_model.model.model.layers.0.self_attn.q_proj.lora_A.reference.weight torch.Size([64, 4096]) meta torch.bfloat16 True
base_model.model.model.layers.0.self_attn.q_proj.lora_A.policy.weight torch.Size([64, 4096]) cpu torch.float32 True
base_model.model.model.layers.0.self_attn.q_proj.lora_B.reference.weight torch.Size([4096, 64]) meta torch.bfloat16 True
base_model.model.model.layers.0.self_attn.q_proj.lora_B.policy.weight torch.Size([4096, 64]) cpu torch.float32 True
Metadata
Metadata
Assignees
Labels
No labels