diff --git a/run_llama_train.sh b/run_llama_train.sh index a4107806..a69c967a 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -19,6 +19,7 @@ if [ $# -ne 0 ]; then overrides="$*" fi +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ train.py --job.config_file ${CONFIG_FILE} $overrides