-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathjoint_train.sh
More file actions
97 lines (84 loc) · 2.75 KB
/
joint_train.sh
File metadata and controls
97 lines (84 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/bin/bash
#
# Multi-node joint-training launcher for G2VLM (Slurm + torchrun).
#
# Must be run from the repository root (the trainer script and dataset config
# below are relative paths) inside a Slurm allocation, e.g.:
#   sbatch --nodes=8 --gpus-per-node=8 joint_train.sh
#
# One torchrun is started per node via `srun`; each torchrun spawns
# GPUS_PER_NODE workers and takes its node rank from that node's SLURM_NODEID.
#
# Required (set by Slurm): SLURM_JOB_NODELIST, SLURM_JOB_ID.
# Optional overrides:      NNODES, GPUS_PER_NODE, CPUS_PER_TASK (default 8/8/16).
set -x -e -o pipefail

export NCCL_DEBUG=INFO
# NOTE(review): confirm NCCL_TIMEOUT is honored by the NCCL/PyTorch versions in use.
export NCCL_TIMEOUT=18000000
export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True'
export MKL_THREADING_LAYER='GNU'
export HYDRA_FULL_ERROR='1'
# Older PyTorch reads NCCL_ASYNC_ERROR_HANDLING; newer releases renamed it to
# TORCH_NCCL_ASYNC_ERROR_HANDLING. Set both so the behavior is version-independent.
export NCCL_ASYNC_ERROR_HANDLING='1'
export TORCH_NCCL_ASYNC_ERROR_HANDLING='1'

# Cluster shape (overridable from the environment; defaults match the original run).
NNODES=${NNODES:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-16}
WORLD_SIZE=$((NNODES * GPUS_PER_NODE))
# Passed to the trainer with --sharding_strategy HYBRID_SHARD; presumably this
# shards within a node and replicates across nodes — confirm in the trainer.
NUM_REPLICATE=$NNODES
NUM_SHARD=$GPUS_PER_NODE

# Fail fast with a clear message instead of an empty MASTER_ADDR / job name.
: "${SLURM_JOB_NODELIST:?must be run inside a Slurm allocation (sbatch/salloc)}"
: "${SLURM_JOB_ID:?must be run inside a Slurm allocation (sbatch/salloc)}"
# A node-count mismatch would make the torchrun rendezvous wait for nodes that never join.
if [[ -n "${SLURM_JOB_NUM_NODES:-}" ]] && ((SLURM_JOB_NUM_NODES != NNODES)); then
  echo "error: NNODES=${NNODES} but Slurm allocated ${SLURM_JOB_NUM_NODES} node(s)" >&2
  exit 1
fi

# First host of the allocation acts as the rendezvous master. `sed -n 1p`
# consumes all input, so scontrol cannot hit SIGPIPE under pipefail.
MASTER_ADDR=$(scontrol show hostname "${SLURM_JOB_NODELIST}" | sed -n '1p')
if [[ -z "${MASTER_ADDR}" ]]; then
  echo "error: could not resolve a host from SLURM_JOB_NODELIST=${SLURM_JOB_NODELIST}" >&2
  exit 1
fi
# Pseudo-random port in [25199, 25299] to reduce collisions between jobs.
MASTER_PORT=$((RANDOM % 101 + 25199))
echo "MASTER_ADDR=$MASTER_ADDR"
echo "MASTER_PORT=$MASTER_PORT"
echo "NNODES=$NNODES, GPUS_PER_NODE=$GPUS_PER_NODE, WORLD_SIZE=$WORLD_SIZE"

job_id=${SLURM_JOB_ID}
name="g2vlm_joint_train_${WORLD_SIZE}g_${job_id}"
export MODEL_PATH="InternRobotics/G2VLM-Qwen2-VL-2B"
export PRETRAINED_CHECKPOINT="InternRobotics/G2VLM-2B-MoT"
# Both names resolve to the same directory; kept separate because they feed
# different trainer flags (--results_dir / --checkpoint_dir).
export output_dir="./checkpoints/${name}/"
export checkpoint_dir="./checkpoints/${name}"
mkdir -p -- "${output_dir}" "${checkpoint_dir}"
# export WANDB_MODE=offline
# export WANDB_API_KEY="your key"
# export CUDA_LAUNCH_BLOCKING=1
# Assign and export separately so a failing `date` is not masked (SC2155).
current_time=$(date +%Y%m%d_%H%M%S)
export current_time
export wandb_name=$name
# export PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.6,max_split_size_mb:128,expandable_segments:True

# torchrun launcher options shared by every node (node_rank is added per node below).
torchrun_args=(
  --nnodes="${NNODES}"
  --nproc_per_node="${GPUS_PER_NODE}"
  --master_addr="${MASTER_ADDR}"
  --master_port="${MASTER_PORT}"
)

# Arguments for train/joint_train_unified_model.py, in their original order.
# NOTE(review): --resume-model-only / --finetune-from-ema use hyphens while all
# other flags use underscores — confirm the trainer's parser accepts both forms.
train_args=(
  --dataset_config_file data/configs/joint_train.yaml
  --layer_module Qwen2VLMoTDecoderLayer
  --vit_path "${MODEL_PATH}"
  --dino_path facebook/dinov2-with-registers-large
  --llm_path "${MODEL_PATH}"
  --model_path "${MODEL_PATH}"
  --use_flex True
  --expected_num_tokens 40960
  --max_num_tokens 40960
  --max_num_tokens_per_sample 40960
  --wandb_project G2VLM
  --wandb_name "${wandb_name}"
  --wandb_offline True
  --wandb_resume allow
  --checkpoint_dir "${checkpoint_dir}"
  --llm_qk_norm True
  --finetune_from_hf True
  --auto_resume False
  --resume-model-only True
  --finetune-from-ema False
  --enable_ema_model False
  --resume_from "${PRETRAINED_CHECKPOINT}"
  --finetune_dino_from_hf False
  --copy_init_moe False
  --visual_und True
  --visual_recon True
  --freeze_dino True
  --freeze_vit True
  --freeze_und False
  --freeze_recon False
  --joint_train_recon False
  --pretrain_train_recon False
  --results_dir "${output_dir}"
  --save_every 500
  --total_steps 12000
  --warmup_steps 400
  --log_every 1
  --sharding_strategy HYBRID_SHARD
  --cpu_offload True
  --num_replicate="${NUM_REPLICATE}"
  --lr 2e-5
  --lr_scheduler cosine
  --num_shard="${NUM_SHARD}"
  --num_workers 4
)

# One task per node. SLURM_NODEID must be expanded inside each task, not here
# on the batch host, so the inner script is single-quoted and every other
# argument is forwarded verbatim through "$@". --kill-on-bad-exit tears down
# the remaining nodes if one torchrun fails, instead of leaving them hanging.
# shellcheck disable=SC2016  # ${SLURM_NODEID} is intentionally expanded by the per-node bash
srun \
  --nodes="${NNODES}" \
  --ntasks-per-node=1 \
  --cpus-per-task="${CPUS_PER_TASK}" \
  --kill-on-bad-exit=1 \
  bash -c 'exec torchrun --node_rank="${SLURM_NODEID}" "$@"' _ \
  "${torchrun_args[@]}" \
  train/joint_train_unified_model.py \
  "${train_args[@]}"