
Commit 9f9af42

Merge pull request #54 from huggingface/xrsrke/feature_doremi_new_codebase
[Feature] DoReMi
2 parents 53c3064 + 0dd67f7 commit 9f9af42

32 files changed: +2951 -22 lines changed

.github/workflows/3d_parallelism_unit_tests.yaml

+13-2
```diff
@@ -37,7 +37,7 @@ jobs:
           python -c "import torch; print('torch:', torch.__version__, torch)"
           python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
 
-      - name: Instal nanotron
+      - name: Install nanotron's dependencies
         run: |
           python -m pip install --upgrade pip
           pip install packaging
@@ -49,7 +49,7 @@ jobs:
       - name: Show installed libraries and their versions
         run: pip freeze | tee installed.txt
 
-      - name: Run tests
+      - name: Run nanotron tests
         # NOTE: -m "not fa2" will run all the unit tests that don't have the mark
         # "fa2" (these are FA2-related tests, we can't run it on T4)
         run: |
@@ -61,3 +61,14 @@ jobs:
             --ignore tests/fp8 \
             --verbose \
             tests/
+          # NOTE: T4 can't run FA2, DoReMi's LLaMA needs FA2
+      # - name: Run DoReMi tests
+      #   # NOTE: -m "not fa2" will run all the unit tests that don't have the mark
+      #   # "fa2" (these are FA2-related tests, we can't run it on T4)
+      #   run: |
+      #     pip install -r examples/doremi/requirements.txt && \
+      #     pytest \
+      #       --color=yes \
+      #       --durations=0 \
+      #       --verbose \
+      #       examples/doremi/tests/
```

.pre-commit-config.yaml

+14
```diff
@@ -19,3 +19,17 @@ repos:
         args:
           - --fix
           - --exit-non-zero-on-fix
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        args:
+          - --profile=black
+          - --skip-glob=wandb/**/*
+          - --thirdparty=wandb
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.1.0
+    hooks:
+      - id: codespell
+        args:
+          - --ignore-words-list=nd,reacher,thist,ths,magent,ba,fo
```

examples/doremi/README.md

+88
@@ -0,0 +1,88 @@
# DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining

Paper: https://arxiv.org/abs/2305.10429

You might think that the key ways to speed up pretraining are finding more high-quality data, increasing FLOPs, or changing the model architecture, but these are not the only ways. DoReMi shows that, given the same source of training data, a model trained with an optimized data-mixing strategy can outperform its counterpart trained with random sampling on at least 70% of domains (or even all domains) and on downstream evaluations, without any knowledge of the downstream evaluation tasks.

In our implementation, the experimental results show that DoReMi outperforms the baseline on 15 out of 22 domains on the test set and achieves a lower average cross-entropy test loss. Here are comparisons of the training losses between:

- the 280M proxy and reference models [[link]](https://wandb.ai/neuralink/nanotron/reports/-DoReMi-280m-reference-vs-280m-proxy-s-training--Vmlldzo2NzYwNTU1)
- the 2.5B reference and tuned-weight models [[link]](https://wandb.ai/neuralink/nanotron/reports/-DoReMi-2-5B-tuned-weights-vs-2-5B-token-ratio-domain-weights-s-training--Vmlldzo2NzYwNzE2)
- and how the 280M proxy model's domain weights change during training [[link]](https://wandb.ai/neuralink/nanotron/runs/j9ojbso1?workspace=user-neuralink)

![The domains in which we outperform](./assets/outperform.png)

![The domains in which we don't outperform](./assets/not_outperform.png)

![Domain weights comparison](./assets/domain_weights.png)

**Notes**: The graphs above show test losses, not validation losses (this is a typo 🫠). The x-axis doesn't mean anything; each point simply corresponds to sampling another batch of evaluation data from the same final checkpoint.
### How it works
- Step 0: Preprocess the data.

- Step 1: Train a small reference model using uniform sampling from each domain: for a given global batch size, you sample an equal number of examples from every domain. In some cases a domain has fewer samples than the others and runs out of samples early, so you can instead enable automatic domain weights based on the token count (see the sketch after the command below).

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 examples/doremi/train_reference.py --config-file examples/doremi/configs/config_280m_llama.yaml
```
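The token-count option above simply weights each domain by its share of the total training tokens. Here is a minimal sketch of that computation (not nanotron's actual implementation; the domain names are a subset of the Pile domains used in this example and the token counts are made-up placeholders):

```python
import torch

# Hypothetical per-domain token counts; replace with the real counts of your tokenized data.
token_counts = {
    "Pile-CC": 120_000_000,
    "Github": 60_000_000,
    "ArXiv": 45_000_000,
}

total_tokens = sum(token_counts.values())
# Each domain's weight is its fraction of all training tokens (its "token ratio").
domain_weights = torch.tensor([count / total_tokens for count in token_counts.values()])

print(dict(zip(token_counts.keys(), [round(w, 4) for w in domain_weights.tolist()])))
```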
- Step 2: Use the reference model trained in step 1 to train an identical model (the proxy), using the reference model's performance to dynamically tune the domain weights during training (a sketch of the weight update follows the command below).

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 examples/doremi/train_doremi.py --config-file examples/doremi/configs/config_280m_llama_proxy.yaml
```
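Conceptually, the proxy run upweights domains where the proxy's loss exceeds the reference model's loss (the "excess loss") with an exponentiated-gradient update, then smooths the result with a uniform distribution. The following is a minimal, self-contained sketch of that update, not the code used by `train_doremi.py`; the function name and the `step_size` and `smoothing` values are illustrative assumptions:

```python
import torch

def update_domain_weights(
    domain_weights: torch.Tensor,  # current weights, shape [num_domains], sums to 1
    proxy_losses: torch.Tensor,    # per-domain loss of the proxy model, shape [num_domains]
    ref_losses: torch.Tensor,      # per-domain loss of the reference model, shape [num_domains]
    step_size: float = 1.0,
    smoothing: float = 1e-3,
) -> torch.Tensor:
    # Excess loss: how much harder each domain currently is for the proxy than for the reference.
    excess_loss = torch.clamp(proxy_losses - ref_losses, min=0.0)
    # Exponentiated-gradient ascent on the excess loss, renormalized to a probability distribution.
    new_weights = domain_weights * torch.exp(step_size * excess_loss)
    new_weights = new_weights / new_weights.sum()
    # Mix in a uniform distribution so no domain's weight collapses to zero.
    uniform = torch.full_like(new_weights, 1.0 / new_weights.numel())
    return (1 - smoothing) * new_weights + smoothing * uniform

# Example: start from uniform weights over the 22 Pile domains used in this example.
weights = torch.full((22,), 1.0 / 22)
```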
- Step 3: Nanotron saves the domain weights in the model checkpoint. Now, compute the optimal domain weights by averaging the domain weights across all training steps of the proxy run from step 2: $\bar{\alpha}=\frac{1}{T} \sum_{t=1}^T \alpha_t$.

```python
import torch

# Each entry holds the domain weights recorded at one training step of the proxy run.
domain_weights = torch.load("checkpoints/doremi/proxy-280m-llama/doremi_domain_weights_100000.pt")

# Average the per-step weights to obtain the final domain weights.
total_weights = sum(d["domain_weights"] for d in domain_weights)
avg_weights = total_weights / len(domain_weights)
```

Then, set these `avg_weights` in the `doremi` section of the config for the larger run.
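Continuing from the snippet above, the averaged weights can be printed in the comma-separated form used by the (commented-out) `domain_weights` field in the example configs of this PR; this is only a formatting convenience, not part of the training scripts:

```python
# One value per domain, in the same order as `domain_names` in the config.
print(", ".join(f"{weight:.4f}" for weight in avg_weights.tolist()))
```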
- Step 4: Use the optimized domain weights from step 3 to train a larger model (which could be 10x to 30x larger).

```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 examples/doremi/train_reference.py --config-file examples/doremi/configs/config_2.8b_llama_with_tuned_weights.yaml
```

### Dataset

We expect the dataset path to point to a folder that already contains tokenized data, with the following structure:

```
dataset
    domain_0
        ...
    domain_1
        ...
    domain_2
        ...
```

For each tokenized sample, we expect a column named `domain_ids` containing the index of that sample's domain in the dataset. For example, if a sample comes from the third domain, its `domain_ids` should equal 2. The folder names must match the domain names that you provide in the DoReMi config (a sketch of how such a dataset might be prepared follows below).
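A minimal sketch of how such a folder might be produced with the Hugging Face `datasets` library (the raw data paths, tokenizer, and domain names below are placeholders; this is not the preprocessing script shipped with the example):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
domain_names = ["domain_0", "domain_1", "domain_2"]  # must match the names in the DoReMi config

for domain_idx, name in enumerate(domain_names):
    # Placeholder source: in practice each domain is a different subset of your corpus.
    raw = load_dataset("json", data_files=f"raw/{name}.jsonl", split="train")
    tokenized = raw.map(
        lambda batch, idx=domain_idx: {
            **tokenizer(batch["text"]),
            # Tag every sample with the integer index of its domain.
            "domain_ids": [idx] * len(batch["text"]),
        },
        batched=True,
        remove_columns=raw.column_names,
    )
    tokenized.save_to_disk(f"dataset/{name}")
```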
### The Experiment
We first train a small 280M model for 70k steps on the Pile to obtain a reference model. Then we use the reference model to tune the domain weights of an identical model trained from scratch (the proxy training) for 70k steps.

The reference model's performance is used as a baseline to determine how difficult a domain is, so that the DoReMi algorithm can adjust the domain weights accordingly on the fly. Once we obtain the optimized weights, we use them to train a 2.5B model (9x larger than the reference model) for 70k steps, and we train another one based on the token-ratio domain weights (technically the same as random sampling, since the probability of a token occurring in the training data equals its token ratio).

For evaluation, we sample uniformly from the test set to evaluate the 2.5B models trained with the optimized domain weights and with the token-ratio domain weights. For more details on the hyperparameters, please check the config YAMLs. Here are the model checkpoints from the experiment:

- 280M LLaMA reference model: https://huggingface.co/nanotron/doremi-llama-280m-reference
- 280M LLaMA proxy model: https://huggingface.co/nanotron/doremi-llama-280m-proxy
- 2.5B LLaMA reference model: https://huggingface.co/nanotron/doremi-llama-2.5b-reference
- 2.5B LLaMA model trained with the optimized weights: https://huggingface.co/nanotron/doremi-llama-2.5b-optimized-weights

and the dataset: https://huggingface.co/datasets/nanotron/the-pile-for-doremi

examples/doremi/__init__.py

Whitespace-only changes.
Two binary image assets added (407 KB and 523 KB).

examples/doremi/assets/outperform.png

Binary image added (694 KB).
@@ -0,0 +1,121 @@
```yaml
checkpoints:
  checkpoint_interval: 1000
  checkpoints_path: checkpoints/doremi/big-run-02/reference-2.8b-llama
  checkpoints_path_is_shared_file_system: true
  resume_checkpoint_path: checkpoints/doremi/big-run-02/reference-2.8b-llama/70000
  save_initial_state: false

doremi:
  domain_names: Pile-CC, Github, OpenWebText2, StackExchange, Wikipedia (en), PubMed Abstracts, USPTO Backgrounds, FreeLaw, PubMed Central, Enron Emails, HackerNews, NIH ExPorter, Books3, ArXiv, DM Mathematics, OpenSubtitles, Gutenberg (PG-19), Ubuntu IRC, BookCorpus2, EuroParl, YoutubeSubtitles, PhilPapers
  # domain_weights: 0.1500, 0.1213, 0.0872, 0.0631, 0.0340, 0.0240, 0.0281, 0.0594, 0.1599, 0.0015, 0.0058, 0.0021, 0.0605, 0.1136, 0.0209, 0.0154, 0.0202, 0.0037, 0.0065, 0.0100, 0.0093, 0.0036

data:
  dataset:
    dataset_overwrite_cache: false
    dataset_processing_num_proc_per_process: 1
    hf_dataset_config_name: null
    hf_dataset_or_datasets: project_data/doremi/datasets/the_pile_raw/tokenized_data/train
  num_loading_workers: 1
  seed: 42
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: nanotron
  run: train_2.8b_llama_reference
  seed: 42
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 120
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    # NOTE: only change hidden_size, intermediate_size,
    # num_attention_heads, num_key_value_heads and num_hidden_layers
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 24576
    is_llama_config: true
    max_position_embeddings: 256
    num_attention_heads: 32
    num_hidden_layers: 6
    # num_hidden_layers: 1
    num_key_value_heads: 16
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 49152
optimizer:
  accumulate_grad_in_fp32: true
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-08
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_steps: 8
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  # dp: 8
  # # dp: 2
  # pp: 1
  # tp: 8
  # # tp: 2

  # NOTE: for running eval
  dp: 8
  pp: 1
  tp: 2

  pp_engine: 1f1b
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: gpt2
  tokenizer_revision: null
tokens:
  # batch_accumulation_per_replica * micro_batch_size * dp = 4 * 8 * 16 = 512
  # batch_accumulation_per_replica * micro_batch_size * dp = 16 * 8 * 4 = 512
  # 240 * 1024 = 245760
  # the DoReMi paper does ~500k tokens per batch
  # batch_accumulation_per_replica: 16
  # NOTE: some weird bug: if you run with batch_accumulation_per_replica=16,
  # it results in no samples from some domains

  # NOTE: this causes some domain losses to be 0
  # batch_accumulation_per_replica: 8
  # micro_batch_size: 8

  batch_accumulation_per_replica: 1
  micro_batch_size: 64

  limit_test_batches: 0
  # NOTE: this is like the number of microbatches for validation
  limit_val_batches: 1
  sequence_length: 1024
  # train_steps: 1000
  # train_steps: 1579
  train_steps: 70_000
  val_check_interval: 2
```
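As a quick sanity check of the batch-size comments in the `tokens` section above (a back-of-the-envelope calculation using the currently active values `batch_accumulation_per_replica: 1`, `micro_batch_size: 64`, `dp: 8`, and `sequence_length: 1024`; note the config marks this `dp` as the eval setting):

$$
1 \times 64 \times 8 = 512 \ \text{sequences per step}, \qquad 512 \times 1024 = 524{,}288 \approx 500\text{k tokens per step},
$$

which is consistent with the roughly 500k tokens per batch from the DoReMi paper referenced in those comments.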
@@ -0,0 +1,119 @@
```yaml
checkpoints:
  checkpoint_interval: 5000
  checkpoints_path: checkpoints/doremi/big-run-02/reference-2.8b-llama-tuned-weights_with_100k_proxy
  checkpoints_path_is_shared_file_system: true
  resume_checkpoint_path: checkpoints/doremi/big-run-02/reference-2.8b-llama-tuned-weights_with_100k_proxy/70000
  save_initial_state: false

doremi:
  domain_names: Pile-CC, Github, OpenWebText2, StackExchange, Wikipedia (en), PubMed Abstracts, USPTO Backgrounds, FreeLaw, PubMed Central, Enron Emails, HackerNews, NIH ExPorter, Books3, ArXiv, DM Mathematics, OpenSubtitles, Gutenberg (PG-19), Ubuntu IRC, BookCorpus2, EuroParl, YoutubeSubtitles, PhilPapers
  # domain_weights: 0.2333, 0.0700, 0.1154, 0.0528, 0.0665, 0.0670, 0.0366, 0.0571, 0.0451, 0.0036, 0.0087, 0.0078, 0.0708, 0.0656, 0.0034, 0.0048, 0.0222, 0.0084, 0.0038, 0.0186, 0.0149, 0.0235

data:
  dataset:
    dataset_overwrite_cache: false
    dataset_processing_num_proc_per_process: 1
    hf_dataset_config_name: null
    hf_dataset_or_datasets: project_data/doremi/datasets/the_pile_raw/tokenized_data/train
  num_loading_workers: 1
  seed: 42
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: nanotron
  run: train_tuned_2.8b_model
  seed: 42
  step: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 120
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 24576
    is_llama_config: true
    max_position_embeddings: 256
    num_attention_heads: 32
    # num_hidden_layers: 40
    num_hidden_layers: 6
    num_key_value_heads: 16
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 49152
optimizer:
  accumulate_grad_in_fp32: true
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-08
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_steps: 8
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  # dp: 8
  # pp: 1
  # tp: 8
  # tp: 2

  # NOTE: for running eval
  dp: 1
  pp: 1
  tp: 8
  pp_engine: 1f1b
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: gpt2
  tokenizer_revision: null
tokens:
  # batch_accumulation_per_replica * micro_batch_size * dp = 4 * 8 * 16 = 512
  # batch_accumulation_per_replica * micro_batch_size * dp = 16 * 8 * 4 = 512
  # batch_accumulation_per_replica * micro_batch_size * dp = 8 * 8 * 8 = 512 (this one)
  # 240 * 1024 = 245760
  # the DoReMi paper does ~500k tokens per batch
  # batch_accumulation_per_replica: 16

  # NOTE: some weird bug: if you run with batch_accumulation_per_replica=16,
  # it results in no samples from some domains

  # NOTE: this causes some domain losses to be 0
  # batch_accumulation_per_replica: 8
  # micro_batch_size: 8

  batch_accumulation_per_replica: 1
  micro_batch_size: 64

  limit_test_batches: 0
  limit_val_batches: 1
  sequence_length: 1024
  # train_steps: 1000
  # train_steps: 70_000
  train_steps: 70_010
  val_check_interval: -1
```

0 commit comments
