
Commit 51832a6

cat-state authored
NeMo PPO (#472)
This PR adds NeMoPPOTrainer, a trainer for fine-tuning NeMo Megatron models with PPO, supporting tensor parallelism and reference model offloading. It supports 1.3b and 20b NeMo models. Thanks to @maxreciprocate for helping debug issues and reviewing the huge PR!

Co-authored-by: cat-state <cat@meow>
1 parent 404217b commit 51832a6
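
As a rough orientation (not part of the diff): the new trainer is selected through trlX's train config, in the same way as the existing NeMo ILQL path. The sketch below is a minimal, hypothetical launch script; the default-config helper, the trainer_kwargs field contents, the reward function, and the .nemo checkpoint path are assumptions rather than something this commit specifies.

# Hypothetical launch sketch for the NeMo PPO trainer added in this PR.
# Assumes trlX's default_ppo_config helper and train.trainer / train.trainer_kwargs
# fields; paths and the reward function are placeholders.
import trlx
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
config.train.trainer = "NeMoPPOTrainer"  # trainer class added by this PR
config.train.trainer_kwargs = dict(
    megatron_cfg="megatron_gpt_1.3b.yaml",                # e.g. one of the NeMo configs below
    pretrained_model="/path/to/megatron_gpt_1.3b.nemo",   # placeholder checkpoint path
)

def reward_fn(samples, **kwargs):
    # Placeholder reward: favor longer completions.
    return [float(len(s)) for s in samples]

trlx.train(
    reward_fn=reward_fn,
    prompts=["The movie was"] * 64,
    config=config,
)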

21 files changed: +2241 -102 lines
+178 (new file: megatron_gpt_1.3b NeMo config)
@@ -0,0 +1,178 @@
name: megatron_gpt_1.3b
restore_from_path: null # used when starting from a .nemo file

trainer:
  devices: 8
  num_nodes: 1
  accelerator: gpu
  precision: bf16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  replace_sampler_ddp: False
  max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
  max_steps: 200 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 1
  val_check_interval: 20
  # check_val_every_n_epoch: null
  limit_val_batches: 2
  limit_test_batches: 0
  accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models
  gradient_clip_val: 1.0
  benchmark: False

exp_manager:
  # set this to save checkpoints
  explicit_log_dir: ppo_sentiments_logs
  exp_dir: null
  name: megatron_gpt_1.3b_ppo_sentiments
  create_tensorboard_logger: False
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: trlxnemo
    name: megatron_gpt_1.3b_ppo_sentiments
  resume_if_exists: False
  resume_ignore_no_checkpoint: True
  # set this to save checkpoints
  create_checkpoint_callback: False
  checkpoint_callback_params:
    monitor: reduced_train_loss
    save_top_k: 1
    mode: min
    always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
    save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits
    filename: 'megatron_gpt-{reduced_train_loss:.2f}-{step}-{consumed_samples}'
    model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
  log_step_timing: True
  step_timing_kwargs:
    sync_cuda: True
    buffer_size: 5

model:
  global_batch_size: 256
  micro_batch_size: 32
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  encoder_seq_length: 2048
  max_position_embeddings: 2048
  num_layers: 24
  hidden_size: 2048
  ffn_hidden_size: 3072
  num_attention_heads: 16
  init_method_std: 0.015
  hidden_dropout: 0.1
  kv_channels: null
  apply_query_key_layer_scaling: true
  layernorm_epsilon: 1.0e-05
  make_vocab_size_divisible_by: 128
  pre_process: true
  post_process: true
  tokenizer:
    library: megatron
    type: GPT2BPETokenizer
    model: null
    vocab_file: /artifacts/vocab.json
    merge_file: /artifacts/merges.txt
  native_amp_init_scale: 4294967296
  native_amp_growth_interval: 1000
  fp32_residual_connection: false
  fp16_lm_cross_entropy: false

  megatron_amp_O2: True
  sync_batch_comm: False

  seed: 1234
  use_cpu_initialization: false
  onnx_safe: false
  activations_checkpoint_method: null
  activations_checkpoint_num_layers: 1

  gradient_as_bucket_view: True
  resume_from_checkpoint: null
  sequence_parallel: True

  data:
    data_prefix:
    - 0.0333
    - /preproc_data/my-gpt3_00_text_document
    - 0.0333
    - /preproc_data/my-gpt3_01_text_document
    - 0.0333
    - /preproc_data/my-gpt3_02_text_document
    - 0.0333
    - /preproc_data/my-gpt3_03_text_document
    - 0.0333
    - /preproc_data/my-gpt3_04_text_document
    - 0.0333
    - /preproc_data/my-gpt3_05_text_document
    - 0.0333
    - /preproc_data/my-gpt3_06_text_document
    - 0.0333
    - /preproc_data/my-gpt3_07_text_document
    - 0.0333
    - /preproc_data/my-gpt3_08_text_document
    - 0.0333
    - /preproc_data/my-gpt3_09_text_document
    - 0.0333
    - /preproc_data/my-gpt3_10_text_document
    - 0.0333
    - /preproc_data/my-gpt3_11_text_document
    - 0.0333
    - /preproc_data/my-gpt3_12_text_document
    - 0.0333
    - /preproc_data/my-gpt3_13_text_document
    - 0.0333
    - /preproc_data/my-gpt3_14_text_document
    - 0.0333
    - /preproc_data/my-gpt3_15_text_document
    - 0.0333
    - /preproc_data/my-gpt3_16_text_document
    - 0.0333
    - /preproc_data/my-gpt3_17_text_document
    - 0.0333
    - /preproc_data/my-gpt3_18_text_document
    - 0.0333
    - /preproc_data/my-gpt3_19_text_document
    - 0.0333
    - /preproc_data/my-gpt3_20_text_document
    - 0.0333
    - /preproc_data/my-gpt3_21_text_document
    - 0.0333
    - /preproc_data/my-gpt3_22_text_document
    - 0.0333
    - /preproc_data/my-gpt3_23_text_document
    - 0.0333
    - /preproc_data/my-gpt3_24_text_document
    - 0.0333
    - /preproc_data/my-gpt3_25_text_document
    - 0.0333
    - /preproc_data/my-gpt3_26_text_document
    - 0.0333
    - /preproc_data/my-gpt3_27_text_document
    - 0.0333
    - /preproc_data/my-gpt3_28_text_document
    - 0.0334
    - /preproc_data/my-gpt3_29_text_document
    data_impl: mmap
    splits_string: 99990,8,2
    seq_length: 2048
    skip_warmup: true
    num_workers: 0
    dataloader_type: single
    reset_position_ids: false
    reset_attention_mask: false
    eod_mask_loss: True
  optim:
    name: distributed_fused_adam
    lr: 6e-05
    weight_decay: 1e-06
    betas:
    - 0.9
    - 0.95
    sched:
      name: CosineAnnealing
      warmup_steps: 0
      constant_steps: 100000000
      min_lr: 5e-05
  precision: bf16
  vocab_file: nemo:c4aec99015da48ba8cbcba41b48feb2c_vocab.json
  merges_file: nemo:50284f68eefe440e850c4fb42c4d13e7_merges.txt
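
The comments on max_steps and accumulate_grad_batches above encode the usual Megatron batch arithmetic. A quick sanity check with the values from this config (plain arithmetic, not trlX code):

# Check the parallelism/batch arithmetic implied by the 1.3b config above
# (values copied from the YAML).
devices, num_nodes = 8, 1
tensor_parallel, pipeline_parallel = 1, 1
micro_batch_size, global_batch_size = 32, 256
max_steps = 200

world_size = devices * num_nodes                      # 8 GPUs
model_parallel = tensor_parallel * pipeline_parallel  # the ${multiply:...} resolver in exp_manager
data_parallel = world_size // model_parallel          # 8 replicas
grad_accum = global_batch_size // (micro_batch_size * data_parallel)  # 1 micro-step per optimizer step

# consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
consumed_samples = max_steps * micro_batch_size * data_parallel * grad_accum
print(world_size, model_parallel, data_parallel, grad_accum, consumed_samples)
# -> 8 1 8 1 51200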

configs/nemo_configs/megatron_20b.yaml

+13 -13
@@ -1,11 +1,11 @@
-name: megatron_gpt
+name: megatron_gpt_20b
 restore_from_path: null # used when starting from a .nemo file
 
 trainer:
   devices: 8
   num_nodes: 4
   accelerator: gpu
-  precision: 16
+  precision: bf16
   logger: False # logger provided by exp_manager
   enable_checkpointing: False
   replace_sampler_ddp: False
@@ -22,18 +22,18 @@ trainer:
 
 exp_manager:
   # set this to save checkpoints
-  explicit_log_dir: ilql_sentiments_logs
+  explicit_log_dir: ppo_sentiments_logs
   exp_dir: null
-  name: megatron_gpt_20b_ilql_sentiments
+  name: megatron_gpt_20b_ppo_sentiments
   create_tensorboard_logger: False
-  create_wandb_logger: True
+  create_wandb_logger: False
   wandb_logger_kwargs:
     project: trlxnemo
-    name: megatron_gpt_20b_ilql_sentiments
+    name: megatron_gpt_20b_ppo_sentiments
   resume_if_exists: False
   resume_ignore_no_checkpoint: True
   # set this to save checkpoints
-  create_checkpoint_callback: True
+  create_checkpoint_callback: False
   checkpoint_callback_params:
     monitor: reduced_train_loss
     save_top_k: 1
@@ -48,13 +48,13 @@ exp_manager:
     buffer_size: 5
 
 model:
-  micro_batch_size: 4
-  global_batch_size: 512
+  micro_batch_size: 2
+  global_batch_size: 64
   tensor_model_parallel_size: 4
   pipeline_model_parallel_size: 1
   resume_from_checkpoint: null # manually set the checkpoint file to load from
   # model architecture
-  encoder_seq_length: 1024
+  encoder_seq_length: 2048
   max_position_embeddings: 2048
   num_layers: 44
   hidden_size: 6144
@@ -98,7 +98,6 @@ model:
   fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16
 
   # Megatron O2-style half-precision
-  # TODO: this causes hangs for some reason
   megatron_amp_O2: True # Enable O2-level automatic mixed precision using main parameters
   grad_allreduce_chunk_size_mb: 125
   sync_batch_comm: False
@@ -133,12 +132,13 @@ model:
 
   optim:
     name: distributed_fused_adam
-    lr: 5.0e-5
+    lr: 6.0e-5
     weight_decay: 1.0e-6
     betas:
     - 0.9
     - 0.95
     sched:
       name: CosineAnnealing
-      max_steps: 200
+      warmup_steps: 0
+      constant_steps: 10000000
       min_lr: 5.0e-5
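
For comparison with the 1.3b setup, the updated 20b config shrinks the batch for PPO while keeping tensor parallelism at 4 across 4 nodes. Applying the same arithmetic as above (values taken from this diff; the interpretation follows the standard Megatron convention):

# 20b PPO layout implied by the updated config above.
devices, num_nodes = 8, 4
tensor_parallel, pipeline_parallel = 4, 1
micro_batch_size, global_batch_size = 2, 64

world_size = devices * num_nodes                                     # 32 GPUs
data_parallel = world_size // (tensor_parallel * pipeline_parallel)  # 8 replicas
grad_accum = global_batch_size // (micro_batch_size * data_parallel) # 4 micro-steps per optimizer step
print(world_size, data_parallel, grad_accum)  # -> 32 8 4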
