Skip to content

Commit 07c962e

Browse files
Create Example training scripts to run in Stability cluster (#419)
* Creating slurm submission scripts * removing newline * Adding additional comment
1 parent 6e655a4 commit 07c962e

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

scripts/accelerate_train_example.sh

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/bin/bash
2+
3+
set -exuo pipefail
4+
5+
# HOSTNAMES MASTER_ADDR MASTER_PORT COUNT_NODE are coming from the main script
6+
H=`hostname`
7+
RANK=`echo -e $HOSTNAMES | python3 -c "import sys;[sys.stdout.write(str(i)) for i,line in enumerate(next(sys.stdin).split(' ')) if line.strip() == '$H'.strip()]"`
8+
9+
CONFIG_FILE=${1-configs/deepspeed/zero2-bf16.yaml} # relative to TRLX_DIR
10+
CONDA_DIR=${2:-/admin/home-amuzio/miniconda3}
11+
CONDA_ENV_NAME=${3:-trlx}
12+
13+
# This script assumes the following:
14+
# (1) a conda environment named $CONDA_ENV_NAME
15+
# (2) It is being run from the $TRLX_DIR directory
16+
# If using venv, you can remove the conda stuff and just activate the venv directly
17+
set +x
18+
export PATH="$CONDA_DIR/condabin:$PATH"
19+
source $CONDA_DIR/etc/profile.d/conda.sh
20+
conda activate $CONDA_ENV_NAME
21+
set -x
22+
23+
24+
accelerate launch \
25+
--num_processes $((8 * $COUNT_NODE)) \
26+
--num_machines $COUNT_NODE \
27+
--machine_rank $RANK \
28+
--main_process_ip $MASTER_ADDR \
29+
--config_file $CONFIG_FILE \
30+
examples/ilql_sentiments.py

scripts/slurm_train.sh

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=trlx
3+
#SBATCH --nodes=1
4+
#SBATCH --ntasks-per-node=1
5+
#SBATCH --partition=g40
6+
#SBATCH --mem=0
7+
#SBATCH --output=logs/%x_%j.out
8+
#SBATCH --error=logs/%x_%j.err
9+
#SBATCH --comment=carperai
10+
#SBATCH --exclusive
11+
12+
# Example usage:
13+
# sbatch slurm_train.sh TRLX_DIR
14+
15+
set -exuo pipefail
16+
17+
export LD_LIBRARY_PATH=/opt/aws-ofi-nccl/lib:/opt/amazon/efa/lib64:/usr/local/cuda-11.0/efa/lib:/usr/local/cuda-11.0/lib:/usr/local/cuda-11.0/lib64:/usr/local/cuda-11.0:/opt/nccl/build/lib:/opt/aws-ofi-nccl-install/lib:/opt/aws-ofi-nccl/lib:$LD_LIBRARY_PATH
18+
export PATH=/opt/amazon/efa/bin:/opt/amazon/openmpi/bin:$PATH
19+
20+
export NCCL_DEBUG=WARN
21+
export NCCL_PROTO=simple
22+
export FI_EFA_FORK_SAFE=1
23+
export FI_LOG_LEVEL=1
24+
export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn
25+
export FI_EFA_ENABLE_SHM_TRANSFER=0
26+
export FI_PROVIDER=efa
27+
export FI_EFA_TX_MIN_CRE DITS=64
28+
# export CUDA_LAUNCH_BLOCKING=1
29+
30+
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
31+
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
32+
export MASTER_PORT=1234
33+
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
34+
35+
TRLX_DIR=${1:-/fsx/home-amuzio/trlx}
36+
TRAIN_SCRIPT=${2-scripts/accelerate_train_example.sh} # relative to TRLX_DIR
37+
CONFIG_FILE=${3-configs/accelerate/zero2-bf16.yaml} # relative to TRLX_DIR
38+
CONDA_DIR=${4:-/admin/home-amuzio/miniconda3}
39+
CONDA_ENV_NAME=${5:-trlx}
40+
41+
pushd $TRLX_DIR
42+
srun --comment carperai $TRAIN_SCRIPT \
43+
$CONFIG_FILE \
44+
$CONDA_DIR \
45+
$CONDA_ENV_NAME

0 commit comments

Comments
 (0)