How to increase batch size by using multiple gpus? #3207

Open
13918763630 opened this issue Jan 31, 2025 · 8 comments

@13918763630

Hello! My fine-tuned model needs a large batch size to get the best performance. I have multiple GPUs with 40 GB of VRAM each. How can I use them together to enlarge the batch size? Currently I can only set the batch size to 3 per GPU, and it seems the GPUs don't share their data. How can I make the total batch size 24?

@tomaarsen
Collaborator

Hello!

I would recommend reading through #2831, where we discuss sharing negatives across devices (i.e., using multiple GPUs to create one big batch in which the in-batch negatives are shared).
The conclusion there was that although it is possible to create a loss that shares negatives across the batch (there's actually a working snippet in there that does this), it's slower than the alternative: CachedMultipleNegativesRankingLoss. In short, this loss uses a clever gradient caching approach that lets us compute the embeddings in mini-batches while still giving results identical to using one big batch.
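To make the "identical results" claim concrete, here is a minimal, hand-written sketch of the gradient caching idea in plain Python. It is not the Sentence Transformers implementation: it assumes a toy scalar "encoder" (embedding = w * x, with a made-up weight w), plain dot-product similarity with scale 1.0, and the same cross-entropy-over-in-batch-negatives form as MultipleNegativesRankingLoss (which uses scaled cosine similarity instead). It embeds the batch in mini-batches, computes the loss and the gradient with respect to every embedding over the full batch, then revisits each mini-batch to chain those cached gradients into w, and compares the result with a finite-difference gradient of the one-big-batch loss.

```python
import math
import random

# Toy "encoder" with a single scalar weight w: embedding = w * x.
# Illustrates gradient caching only; this is not the Sentence Transformers code.


def embed(w, xs):
    return [w * x for x in xs]


def loss_and_embedding_grads(anchors, positives, scale=1.0):
    """In-batch-negatives cross-entropy over the whole batch.

    Returns the loss and its gradient w.r.t. every anchor/positive embedding.
    """
    n = len(anchors)
    loss = 0.0
    grad_a = [0.0] * n
    grad_p = [0.0] * n
    for i in range(n):
        # Anchor i is scored against every positive; positive i is the target.
        logits = [scale * anchors[i] * p for p in positives]
        m = max(logits)
        exps = [math.exp(s - m) for s in logits]
        total = sum(exps)
        loss += (m + math.log(total) - logits[i]) / n
        for j in range(n):
            # dL/ds_ij = (softmax_ij - [i == j]) / n
            g = (exps[j] / total - (1.0 if i == j else 0.0)) / n
            grad_a[i] += g * scale * positives[j]
            grad_p[j] += g * scale * anchors[i]
    return loss, grad_a, grad_p


def cached_gradient(w, xs, ys, mini_batch_size):
    """Gradient of the full-batch loss w.r.t. w, computed in mini-batches."""
    chunks = range(0, len(xs), mini_batch_size)
    # Step 1: embed mini-batch by mini-batch (no gradients needed yet).
    anchors, positives = [], []
    for start in chunks:
        anchors += embed(w, xs[start:start + mini_batch_size])
        positives += embed(w, ys[start:start + mini_batch_size])
    # Step 2: loss over the full batch; cache dL/d(embedding).
    loss, grad_a, grad_p = loss_and_embedding_grads(anchors, positives)
    # Step 3: revisit each mini-batch and chain the cached gradients into w.
    # For embedding = w * x, d(embedding)/dw = x.
    grad_w = 0.0
    for start in chunks:
        end = start + mini_batch_size
        grad_w += sum(g * x for g, x in zip(grad_a[start:end], xs[start:end]))
        grad_w += sum(g * y for g, y in zip(grad_p[start:end], ys[start:end]))
    return loss, grad_w


def full_batch_loss(w, xs, ys):
    return loss_and_embedding_grads(embed(w, xs), embed(w, ys))[0]


random.seed(0)
xs = [random.uniform(-1, 1) for _ in range(32)]  # 32 toy "anchors"
ys = [x + random.gauss(0, 0.1) for x in xs]      # matching toy "positives"
w = 1.5

loss, grad_w = cached_gradient(w, xs, ys, mini_batch_size=3)
# Reference: central finite difference of the one-big-batch loss.
eps = 1e-6
numeric = (full_batch_loss(w + eps, xs, ys) - full_batch_loss(w - eps, xs, ys)) / (2 * eps)
print(f"loss={loss:.6f}  cached grad={grad_w:.8f}  finite-difference grad={numeric:.8f}")
```

The two gradients agree, and changing mini_batch_size does not change them. In a real PyTorch implementation of gradient caching, step 1 would run under torch.no_grad() and step 3 would re-run the encoder with gradients enabled, so only one mini-batch of activations needs to be kept in memory at a time.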

So, you can use SentenceTransformerTrainingArguments(..., per_device_train_batch_size=32) together with CachedMultipleNegativesRankingLoss(..., mini_batch_size=3). Each device will then take its in-batch negatives from the 32 samples in its own batch, but there won't be any (slow!) data sharing between devices.

You can scale this 32 up to any number you wish (larger is generally better, up to a point). Using multiple GPUs should then parallelize this quite well: each large batch stays on its own GPU, so with 8 GPUs you process 8 times as many large batches as with just one GPU.
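As a rough illustration of the arithmetic above, the helper below (batch_layout is hypothetical, not a library function) computes, for (anchor, positive) pairs with no cross-device sharing, how many in-batch negatives each anchor sees, how many mini-batches the cached loss splits each input column into on one device, and how many samples are consumed per optimizer step across all GPUs. It assumes no gradient accumulation.

```python
import math


def batch_layout(per_device_train_batch_size, mini_batch_size, num_gpus):
    """Hypothetical helper: derived quantities for (anchor, positive) pairs,
    assuming no cross-device sharing of negatives and no gradient accumulation."""
    return {
        # Each anchor is scored against every positive on its own device;
        # all of them except its own positive act as negatives.
        "in_batch_negatives_per_anchor": per_device_train_batch_size - 1,
        # Number of mini-batches the cached loss splits each input column into.
        "mini_batches_per_column": math.ceil(per_device_train_batch_size / mini_batch_size),
        # Samples consumed per optimizer step across all GPUs.
        "samples_per_step": per_device_train_batch_size * num_gpus,
    }


# The original setup: 3 samples per GPU on 8 GPUs.
print(batch_layout(per_device_train_batch_size=3, mini_batch_size=3, num_gpus=8))
# The suggested setup: 32 samples per GPU, embedded 3 at a time, on 8 GPUs.
print(batch_layout(per_device_train_batch_size=32, mini_batch_size=3, num_gpus=8))
```

With the original setup, 24 samples are processed per step, but each anchor sees only 2 in-batch negatives. With the suggested setup, each anchor sees 31 negatives while memory use is governed by the mini-batch of 3. Raising gradient_accumulation_steps would increase the samples per optimizer step but not the number of in-batch negatives, because each forward pass still only sees its own per-device batch.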

  • Tom Aarsen

@13918763630
Author

Thank you so much! It works. But there is another big issue: it seems the model is loaded repeatedly, which causes GPU out-of-memory (OOM) errors. Other people have also reported this problem. Some 7B models cannot be trained even on an 80 GB GPU like the A100.

@tomaarsen
Collaborator

Fair enough! For that, you would need FSDP (Fully Sharded Data Parallel). FSDP is partially supported in Sentence Transformers, but it has not been tested extensively. See the documentation here: https://sbert.net/docs/sentence_transformer/training/distributed.html#fsdp
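For a rough sense of why a 7B model can run out of memory even on an 80 GB A100, and why sharding helps, here is a back-of-the-envelope estimate. It assumes the common rule of thumb of about 16 bytes of model state per parameter for mixed-precision Adam training (for example, under PyTorch AMP: fp32 weights 4 + fp32 gradients 4 + two fp32 Adam moments 8), ignores activations and other overhead, and treats full sharding as splitting that state evenly across GPUs. training_memory_gb is a hypothetical helper.

```python
def training_memory_gb(num_params, num_gpus=1, bytes_per_param=16):
    """Hypothetical estimate of per-GPU model-state memory in GB (1 GB = 1e9 bytes).

    bytes_per_param=16 assumes mixed-precision Adam, e.g. under PyTorch AMP:
    fp32 weights (4) + fp32 gradients (4) + fp32 Adam first/second moments (4 + 4).
    Activations, temporary buffers and fragmentation are NOT included.
    """
    total_gb = num_params * bytes_per_param / 1e9
    # Full sharding (as in FSDP's FULL_SHARD strategy) splits this state across GPUs.
    # Plain data parallelism keeps a full copy on every GPU, i.e. num_gpus=1 here.
    return total_gb / num_gpus


seven_b = 7e9
print(f"Full copy on every GPU: {training_memory_gb(seven_b):.0f} GB per GPU")
print(f"Fully sharded over 8 GPUs: {training_memory_gb(seven_b, num_gpus=8):.0f} GB per GPU")
```

Under these assumptions, keeping a full copy of the model state on each GPU needs roughly 112 GB before activations, more than an A100's 80 GB, while sharding it across 8 GPUs brings it to roughly 14 GB per GPU.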

  • Tom Aarsen

@13918763630
Author

Thank you so much! Now it runs properly on 1 GPU. But if I extend it to 8 GPUs, it runs out of memory. When I use torchrun, it seems all of the processes run on GPU 0, and when I use the accelerator, GPU 0 also hits OOM errors. Could you help me figure out this problem? Thank you so much!

@tomaarsen
Collaborator

tomaarsen commented Feb 10, 2025

> But if I extend it to 8, it will lead to OOM.

Did you keep per_device_train_batch_size and mini_batch_size the same? I'm surprised to hear that it runs out of memory.
Are you also using --nproc_per_node=8 with torchrun?

You shouldn't have to set up the accelerator yourself; that should be taken care of for you. The training script should be pretty much the same as for 1 GPU (the only difference is that you have to wrap your main code in if __name__ == "__main__", as is standard for multi-GPU/multi-process code).

Edit: I realise now that the device placement might be too naive. Instead, you should use:

import os

from sentence_transformers import SentenceTransformer


def main():
    # torchrun sets LOCAL_RANK for each process it launches
    local_rank = int(os.environ["LOCAL_RANK"])

    # 1. Load a model to finetune
    model = SentenceTransformer(
        model_name_or_path="Alibaba-NLP/gte-Qwen2-7B-instruct",
        model_kwargs={
            "device_map": "auto",
        },
        tokenizer_kwargs={
            "model_max_length": 512,
            "truncation": True,
        },
        device=f"cuda:{local_rank}",
    )
    # set the max input seq length to 512
    model.max_seq_length = 512


if __name__ == "__main__":
    main()

This uses the device_map option of AutoModel.from_pretrained in transformers, which places the model on the correct device. The device=... argument might not be necessary (if it still breaks, consider removing it), but it might also be required.
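The LOCAL_RANK lookup in the snippet above is what is meant to give each process its own GPU. The sketch below isolates that step in a hypothetical helper, device_for_this_process, so it can be checked without a GPU: torchrun sets LOCAL_RANK (0, 1, 2, ...) in the environment of each process it launches, and the helper turns it into a device string. Falling back to cuda:0 when LOCAL_RANK is absent (for example, a plain python script.py run) is an added assumption, not something from the thread.

```python
import os


def device_for_this_process(env=None):
    """Hypothetical helper: map torchrun's LOCAL_RANK to a CUDA device string.

    torchrun launches one process per GPU and sets LOCAL_RANK in each one.
    Falling back to "0" when LOCAL_RANK is missing (e.g. a plain
    `python script.py` run) is an assumption made for this sketch.
    """
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", "0"))
    return f"cuda:{local_rank}"


# Simulated environments for two of the processes started by
# `torchrun --nproc_per_node=8 script.py`:
print(device_for_this_process({"LOCAL_RANK": "0"}))
print(device_for_this_process({"LOCAL_RANK": "5"}))
```

The resulting string would then be passed as device=... to SentenceTransformer so that each process places its own replica on its own GPU rather than on cuda:0; this is an untested suggestion, not a confirmed fix.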

  • Tom Aarsen

@13918763630
Author

Hi~
There are some errors. When I run torchrun --nproc_per_node=2 script.py with "device_map": "auto", I get:
ValueError: You can't train a model that has been loaded with device_map='auto' in any distributed mode. Please rerun your script specifying --num_processes=1 or by launching with python {{myscript.py}}.

@tomaarsen
Collaborator

tomaarsen commented Feb 18, 2025

Hmm, perhaps the device_map should be avoided with distributed training, then?
My apologies, I haven't run distributed training myself in a few months.

  • Tom Aarsen

@13918763630
Author

13918763630 commented Feb 19, 2025

Case 1:
torchrun --nproc_per_node=2 script.py with "device_map": "auto"
=========> ValueError: You can't train a model that has been loaded with device_map='auto' in any distributed mode. Please rerun your script specifying --num_processes=1 or by launching with python {{myscript.py}}.

Case 2:
python script.py with "device_map": "auto": the tensors are distributed across different GPUs, which leads to
=========> RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

Case 3:
torchrun --nproc_per_node=2 script.py without "device_map": "auto": all tensors are sent to cuda:0 first
==========> cuda:0 OOM!

Case 4:
python script.py without "device_map": "auto"
This is the only case in which the model can be trained, but it only uses one GPU at a time.


This is a really strange problem. I don't know whether it is caused by my script.py.

Here is the code of my script.py:


import ast
import os

from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss  # only used in a commented-out alternative
from sentence_transformers.training_args import BatchSamplers

# CachedInfonce comes from a custom local module, not from Sentence Transformers
from cachedselfloss import CachedInfonce


def main():

    def is_valid_charset(s):
        # Keep only pure-ASCII text
        try:
            s.encode('ascii')
            return True
        except UnicodeEncodeError:
            return False

    def read_txt_files_to_tuple_list():
        tuple_list = []
        question_list = []
        ans_list = []
        for x in ['dataset', 'dataset4', 'dataset2', 'dataset3', 'dataset6', 'dataset5']:
            directory = '/root/work_dir/root/dataset_set/' + x
            for filename in os.listdir(directory):
                if filename.endswith(".txt"):
                    filepath = os.path.join(directory, filename)

                    with open(filepath, 'r') as file:
                        content = file.read().strip()
                        try:
                            # Each file holds a Python literal list of (question, answer) tuples;
                            # ast.literal_eval parses it without executing arbitrary code like eval()
                            tuple_list += ast.literal_eval(content)
                        except Exception:
                            # Skip files that do not contain a valid literal list
                            pass

        for q, a in tuple_list:
            if is_valid_charset(q) and is_valid_charset(a):
                question_list.append(q)
                ans_list.append(a)
        return question_list, ans_list

    # 1. Load a model to finetune
    model = SentenceTransformer(
        'Salesforce/SFR-Embedding-Mistral',
        model_kwargs={
            "device_map": "auto",  # removed in Case 3 and Case 4
        },
        tokenizer_kwargs={
            "model_max_length": 8192,
            "truncation": True,
        },
    )

    question_list, ans_list = read_txt_files_to_tuple_list()
    dataset = Dataset.from_dict({
        "anchor": question_list,
        "positive": ans_list,
    })


    train_test_split_result = dataset.train_test_split(test_size=0.2) 
    train_dataset = train_test_split_result['train']
    temp_dataset = train_test_split_result['test']


    eval_dataset = temp_dataset

    # 4. Define a loss function
    #loss = MultipleNegativesRankingLoss(model)
    #loss = Infonce(model)
    loss = CachedInfonce(model,mini_batch_size=1)

    # 5. (Optional) Specify training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir="models/sfr_model",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=False,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=20,
        save_strategy="steps",
        save_steps=20,
        save_total_limit=2,
        logging_steps=20,
        run_name="sfr_model",  # Will be used in W&B if `wandb` is installed
    )

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
    )
    trainer.train()

    # 8. Save the trained model
    model.save_pretrained("models/aaaa2/final")

    # 9. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub("mpnet-base-all-nli-triplet")



if __name__ == "__main__":
    main()
