Commit 59b3c65

Hyperparameter Optimization with Ray Tune and Weights and Biases (#76)
1 parent 3900999 commit 59b3c65

12 files changed: +650 -11 lines

.gitignore (+2)

@@ -149,3 +149,5 @@ OUT/

 examples/experiments/grounded_program_synthesis/dataset
 ckpts/
+
+ray_results/

configs/ppo_config.yml (+1 -1)

@@ -6,7 +6,7 @@ model:

 train:
   seq_length: 48  # Size of LM context
-  epochs: 1000  # Train for max(epochs, total_steps)
+  epochs: 1000  # Train for max(epochs, total_steps)
   total_steps: 10000  # Train for max(epochs, total_steps)
   batch_size: 128  # batch size

configs/ray_tune_configs/ppo_config.yml (+68, new file)

@@ -0,0 +1,68 @@
+tune_config:
+  mode: "max"
+  metric: "mean_reward"
+  search_alg: "bohb"  # random
+  scheduler: "hyperbandforbohb"  # fifo
+  num_samples: 15
+  max_concurrent_trials: null
+  time_budget_s: null
+  reuse_actors: null
+
+model:
+  model_path: "lvwerra/gpt2-imdb"  # Name of hf model to load
+  tokenizer_path: "gpt2"  # Name of hf tokenizer to load
+  model_type: "AcceleratePPOModel"  # Name of accelerate model type to load
+  num_layers_unfrozen:  # Number of bottom layers to freeze during training
+    strategy: "choice"
+    values: [2, 3]
+
+train:
+  seq_length:  # Size of LM context
+    strategy: "choice"
+    values: [36, 48, 52]
+  epochs:  # Train for max(epochs, total_steps)
+    strategy: "choice"
+    values: [80, 100, 120]
+  total_steps: 10000  # Train for max(epochs, total_steps)
+  batch_size: 128  # batch size
+
+  lr_init: 1.412e-4  # init learning rate
+  lr_target: 1.412e-4  # target final learning rate
+  opt_betas: [0.9, 0.95]  # adam betas
+  opt_eps: 1.0e-8  # adam eps
+  weight_decay: 1.0e-6  # weight decay param
+
+  checkpoint_interval: 10000  # checkpoint interval
+  eval_interval: 4  # eval interval
+
+  pipeline: "PPOPipeline"  # prompt pipeline to load
+  orchestrator: "PPOOrchestrator"  # orchestrator to load
+  project_name: "trlx-hyperopt-bohb"
+
+method:
+  name: 'ppoconfig'  # Name of RL method config
+  num_rollouts:  # Number of rollouts to collect per epoch
+    strategy: "choice"
+    values: [96, 128]
+  chunk_size: 128  # Number of rollouts to collect in one loop of orchestrator
+  ppo_epochs:  # Number of ppo epochs
+    strategy: "randint"
+    values: [3, 6]  # 3 is inclusive, 6 is exclusive
+  init_kl_coef:  # init kl coefficient
+    strategy: "quniform"
+    values: [0.1, 0.3, 0.1]
+  target: 6  # target kl coefficient, set None for fixed kl coef
+  horizon: 10000  # PPO horizon
+  gamma: 1  # PPO discount
+  lam:  # PPO lambda
+    strategy: "uniform"
+    values: [0.93, 0.98]
+  cliprange: 0.2  # clip range
+  cliprange_value: 0.2  # clip range
+  vf_coef: 2.3  # value term weight
+  gen_kwargs:
+    max_length: 48  # LM max sample gen length
+    min_length: 48  # LM min sample gen length
+    top_k: 0.0  # top k
+    top_p: 1.0  # top p
+    do_sample: True  # sample
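
The strategy/values pairs in this file describe a search space rather than fixed hyperparameters. The mapping onto Ray Tune samplers is done by get_param_space in trlx/ray_tune, which is not among the files shown here; the helper below is only a hypothetical sketch of how such a spec could translate into Ray Tune's built-in sampling primitives.

from ray import tune

def to_tune_sampler(spec: dict):
    # Hypothetical sketch: map a {"strategy": ..., "values": ...} entry from the
    # YAML above onto Ray Tune's samplers. The real get_param_space may differ.
    strategy, values = spec["strategy"], spec["values"]
    if strategy == "choice":
        return tune.choice(values)  # pick one of the listed values
    if strategy == "randint":
        return tune.randint(values[0], values[1])  # lower inclusive, upper exclusive
    if strategy == "uniform":
        return tune.uniform(values[0], values[1])  # continuous uniform over [low, high]
    if strategy == "quniform":
        return tune.quniform(values[0], values[1], values[2])  # uniform rounded to a quantum
    raise ValueError(f"Unknown strategy: {strategy}")

# Example: ppo_epochs from the config above samples 3, 4 or 5 (6 is exclusive).
ppo_epochs = to_tune_sampler({"strategy": "randint", "values": [3, 6]})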

setup.cfg (+2)

@@ -20,6 +20,8 @@ install_requires =
     transformers>=4.21.2
     tqdm
     wandb
+    ray>=2.0.1
+    tabulate>=0.9.0

 [options.extras_require]
 dev =

train_sweep.py (+96, new file)

@@ -0,0 +1,96 @@
+# Usage: python train_sweep.py --config configs/ray_tune_configs/ppo_config.yml --example-name ppo_sentiments
+import wandb
+import argparse
+from pathlib import Path
+
+import ray
+from ray.air import session
+from ray import tune
+
+import trlx
+from trlx.ray_tune import load_ray_yaml
+from trlx.ray_tune import get_param_space
+from trlx.ray_tune import get_tune_config
+from trlx.ray_tune import get_train_function
+from trlx.ray_tune.wandb import log_trials, create_report
+
+from ray.tune.logger import JsonLoggerCallback
+from ray.tune.logger import CSVLoggerCallback
+
+
+def tune_function(train_function, param_space: dict, tune_config: dict, resources: dict):
+    tuner = tune.Tuner(
+        tune.with_resources(train_function, resources=resources),
+        param_space=param_space,
+        tune_config=tune.TuneConfig(**tune_config),
+        run_config=ray.air.RunConfig(
+            local_dir="ray_results",
+            callbacks=[CSVLoggerCallback()]
+        ),
+    )
+
+    results = tuner.fit()
+
+    log_trials(
+        tuner._local_tuner.get_experiment_checkpoint_dir(),
+        param_space["train"]["project_name"]
+    )
+
+    create_report(
+        param_space,
+        tune_config,
+        Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem,
+        results.get_best_result().config
+    )
+
+    print("Best hyperparameters found were: ", results.get_best_result().config)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--example-name", type=str, default="ppo_sentiments", help="Name of the example"
+    )
+    parser.add_argument(
+        "--config", type=str, default=None, required=True, help="The config file defining the param_space."
+    )
+    parser.add_argument(
+        "--num-cpus", type=int, default=4, help="Number of CPUs to use per exp."
+    )
+    parser.add_argument(
+        "--num-gpus", type=int, default=1, help="Number of GPUs to use per exp."
+    )
+    parser.add_argument(
+        "--server-address",
+        type=str,
+        default=None,
+        required=False,
+        help="The address of server to connect to if using Ray Client.",
+    )
+
+    args, _ = parser.parse_known_args()
+
+    # Read config and parse it
+    config = load_ray_yaml(args.config)
+    tune_config = get_tune_config(config)
+    param_space = get_param_space(config)
+
+    # Initialize Ray.
+    if args.server_address:
+        ray.init(address=f"ray://{args.server_address}")
+    else:
+        ray.init()
+
+    resources = {
+        "cpu": args.num_cpus,
+        "gpu": args.num_gpus,
+    }
+
+    # Register the training function that will be used for training the model.
+    train_function = get_train_function(args.example_name)
+    tune.register_trainable("train_function", train_function)
+
+    tune_function(train_function, param_space, tune_config, resources)
+
+    # Shut down Ray.
+    ray.shutdown()
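
tune.TuneConfig(**tune_config) expects search_alg and scheduler to be objects, while the YAML stores them as the strings "bohb" and "hyperbandforbohb", so get_tune_config has to resolve the names before tune_function runs. That resolution lives in trlx/ray_tune and is not shown in this commit; the snippet below is a hedged sketch of one way it could look on Ray 2.0+, where resolve_tune_config is a hypothetical name.

from ray.tune.search.bohb import TuneBOHB          # requires ConfigSpace and hpbandster
from ray.tune.schedulers import HyperBandForBOHB

def resolve_tune_config(tune_config: dict) -> dict:
    # Hypothetical sketch: turn the string names from the YAML tune_config block
    # into Ray Tune objects before they are passed to tune.TuneConfig(**tune_config).
    resolved = dict(tune_config)
    if resolved.get("search_alg") == "bohb":
        resolved["search_alg"] = TuneBOHB()  # Bayesian optimization combined with HyperBand
    if resolved.get("scheduler") == "hyperbandforbohb":
        resolved["scheduler"] = HyperBandForBOHB(time_attr="training_iteration")
    # Keys left as null in the YAML (max_concurrent_trials, time_budget_s, reuse_actors)
    # can be passed through as None, which matches TuneConfig's defaults.
    return resolved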

trlx/data/configs.py (+17 -3)

@@ -142,7 +142,21 @@ def to_dict(self):
         """
         Convert TRLConfig to dictionary.
         """
-        data = self.model.__dict__.copy()
-        data.update(self.train.__dict__)
-        data.update(self.method.__dict__)
+        data = {
+            "model": self.model.__dict__,
+            "train": self.train.__dict__,
+            "method": self.method.__dict__,
+        }
+
         return data
+
+    @classmethod
+    def from_dict(cls, config_dict: dict):
+        """
+        Convert dictionary to TRLConfig.
+        """
+        return cls(
+            ModelConfig.from_dict(config_dict["model"]),
+            TrainConfig.from_dict(config_dict["train"]),
+            get_method(config_dict["method"]["name"]).from_dict(config_dict["method"]),
+        )
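
to_dict now returns a nested {"model": ..., "train": ..., "method": ...} dictionary, which matches the shape of the param_space that Ray Tune samples for each trial, and from_dict rebuilds a TRLConfig from the same shape. A small usage sketch, assuming the existing TRLConfig.load_yaml entry point and a local configs/ppo_config.yml:

from trlx.data.configs import TRLConfig

# Hedged sketch: round-trip a config through a plain dict, the shape that
# Ray Tune hands to each trial after sampling the param_space.
config = TRLConfig.load_yaml("configs/ppo_config.yml")
config_dict = config.to_dict()   # {"model": {...}, "train": {...}, "method": {...}}
restored = TRLConfig.from_dict(config_dict)
assert restored.train.batch_size == config.train.batch_size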

trlx/model/accelerate_base_model.py (+22 -5)

@@ -18,6 +18,17 @@
 else:
     from tqdm import tqdm

+import ray
+from ray.air import session
+
+
+def parse_results_for_session(results: dict):
+    for k, v in results.items():
+        if isinstance(v, torch.Tensor):
+            results[k] = float(v)
+
+    return results
+

 @register_model
 class AccelerateRLModel(BaseRLModel):
@@ -63,7 +74,7 @@ def __init__(self, config, train_mode=True):
         for m in gpt_blocks_to_freeze:
             m.requires_grad_(False)

-        if self.accelerator.is_main_process:
+        if self.accelerator.is_main_process and not ray.is_initialized():
             self.accelerator.init_trackers(
                 project_name=self.config.train.project_name,
                 config=self.config.to_dict(),
@@ -199,9 +210,8 @@ def evaluate(self):
             columns_data.append(values)

         rows = list(zip(*columns_data))
-        stats["samples"] = wandb.Table(columns=columns, rows=rows)
-
-        print(rows[0])
+        if not ray.is_initialized():
+            stats["samples"] = wandb.Table(columns=columns, rows=rows)

         return stats

@@ -246,7 +256,14 @@ def learn(self):
                     "backward_time": backward_time,
                 }
             )
-            self.accelerator.log(results)
+
+            if not ray.is_initialized():
+                self.accelerator.log(results)
+
+            # Report the metrics to Ray Tune.
+            if ray.is_initialized():
+                tmp_results = parse_results_for_session(results)
+                session.report(tmp_results)

             desc = ", ".join(f"{k}: {v:.2f}" for k, v in stats.items())
             tbar.set_description(desc)
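
parse_results_for_session exists because session.report expects plain Python scalars while several of the logged metrics are torch tensors. A minimal illustration of the same conversion, with made-up metric names; session.report is only valid inside a running Ray Tune trial:

import torch
from ray.air import session

# Illustrative only: cast tensor metrics to floats before reporting, mirroring
# parse_results_for_session above.
results = {"mean_reward": torch.tensor(0.87), "loss/total": 0.42}
results = {k: float(v) if isinstance(v, torch.Tensor) else v for k, v in results.items()}
session.report(results)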

trlx/orchestrator/ppo_orchestrator.py (+6 -1)

@@ -10,6 +10,9 @@
 from trlx.utils import Clock
 from trlx.utils.modeling import logprobs_from_logits

+import ray
+from ray.air import session
+

 @register_orchestrator
 class PPOOrchestrator(Orchestrator):
@@ -124,7 +127,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
             ppo_rl_elements += new_ppo_rl_elements

         stats = {"exp_time": exp_time}
-        self.rl_model.accelerator.log(stats, step=iter_count)
+
+        if not ray.is_initialized():
+            self.rl_model.accelerator.log(stats, step=iter_count)

         # Push samples and rewards to model's rollout storage
         self.rl_model.push_to_store(ppo_rl_elements)
