Skip to content

Commit 9a6e46f

Browse files
authored
feature(pu): add ddp config of dqn and onppo (opendilab#842)
* feature(pu): add pong and cartpole ddp config of dqn and onppo * fix(pu): fix atari_ppo_ddp.py * polish(pu): polish atari_dqn_ddp.py and atari_ppo_ddp.py * polish(pu): polish atari ddp configs
1 parent 580ea65 commit 9a6e46f

11 files changed

+304
-14
lines changed

ding/entry/serial_entry_onpolicy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ding.config import read_config, compile_config
1313
from ding.policy import create_policy, PolicyFactory
1414
from ding.reward_model import create_reward_model
15-
from ding.utils import set_pkg_seed
15+
from ding.utils import set_pkg_seed, get_rank
1616

1717

1818
def serial_pipeline_onpolicy(
@@ -68,7 +68,7 @@ def serial_pipeline_onpolicy(
6868
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
6969

7070
# Create worker components: learner, collector, evaluator, replay buffer, commander.
71-
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
71+
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
7272
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
7373
collector = create_serial_collector(
7474
cfg.policy.collect.collector,

ding/worker/collector/interaction_serial_evaluator.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ def eval(
204204
'''
205205
# evaluator only work on rank0
206206
stop_flag = False
207+
episode_info = None # Initialize to ensure it's defined in all ranks
208+
207209
if get_rank() == 0:
208210
if n_episode is None:
209211
n_episode = self._default_n_episode
@@ -317,5 +319,7 @@ def eval(
317319
broadcast_object_list(objects, src=0)
318320
stop_flag, episode_info = objects
319321

320-
episode_info = to_item(episode_info)
322+
# Ensure episode_info is converted to the correct format
323+
episode_info = to_item(episode_info) if episode_info is not None else {}
324+
321325
return stop_flag, episode_info

ding/worker/collector/sample_serial_collector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from ding.envs import BaseEnvManager
99
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
10-
broadcast_object_list, allreduce_data
10+
allreduce_data
1111
from ding.torch_utils import to_tensor, to_ndarray
1212
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions
1313

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
from easydict import EasyDict

# DQN on Atari Pong under multi-GPU DistributedDataParallel (DDP) training.
# ``multi_gpu=True`` makes the learner all-reduce gradients across ranks;
# launch with torch.distributed (see the command in ``__main__`` below).
pong_dqn_config = EasyDict(
    dict(
        exp_name='data_pong/pong_dqn_ddp_seed0',
        env=dict(
            collector_env_num=4,
            evaluator_env_num=4,
            n_evaluator_episode=8,
            stop_value=20,
            env_id='PongNoFrameskip-v4',
            # 'ALE/Pong-v5' is also available, but it needs special settings after gym make.
            frame_stack=4,
        ),
        policy=dict(
            multi_gpu=True,  # enable gradient all-reduce across DDP ranks
            cuda=True,
            priority=False,
            model=dict(
                obs_shape=[4, 84, 84],
                action_shape=6,
                encoder_hidden_size_list=[128, 128, 512],
            ),
            nstep=3,
            discount_factor=0.99,
            learn=dict(
                update_per_collect=10,
                batch_size=32,
                learning_rate=0.0001,
                target_update_freq=500,
            ),
            collect=dict(n_sample=96, ),
            eval=dict(evaluator=dict(eval_freq=4000, )),
            other=dict(
                # Epsilon-greedy exploration: exponential decay from 1.0 to 0.05.
                eps=dict(
                    type='exp',
                    start=1.,
                    end=0.05,
                    decay=250000,
                ),
                replay_buffer=dict(replay_buffer_size=100000, ),
            ),
        ),
    )
)
main_config = pong_dqn_config

pong_dqn_create_config = EasyDict(
    dict(
        env=dict(
            type='atari',
            import_names=['dizoo.atari.envs.atari_env'],
        ),
        env_manager=dict(type='subprocess'),
        policy=dict(type='dqn'),
    )
)
create_config = pong_dqn_create_config

if __name__ == '__main__':
    """
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
        python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/config/serial/pong/pong_dqn_ddp_config.py
    """
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline
    with DDPContext():
        serial_pipeline((main_config, create_config), seed=0, max_env_step=int(3e6))

dizoo/atari/config/serial/pong/pong_onppo_config.py renamed to dizoo/atari/config/serial/pong/pong_ppo_config.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from easydict import EasyDict
22

3-
pong_onppo_config = dict(
3+
pong_ppo_config = dict(
44
env=dict(
55
collector_env_num=8,
66
evaluator_env_num=8,
@@ -49,19 +49,19 @@
4949
eval=dict(evaluator=dict(eval_freq=5000, )),
5050
),
5151
)
52-
main_config = EasyDict(pong_onppo_config)
52+
main_config = EasyDict(pong_ppo_config)
5353

54-
pong_onppo_create_config = dict(
54+
pong_ppo_create_config = dict(
5555
env=dict(
5656
type='atari',
5757
import_names=['dizoo.atari.envs.atari_env'],
5858
),
5959
env_manager=dict(type='subprocess'),
6060
policy=dict(type='ppo'),
6161
)
62-
create_config = EasyDict(pong_onppo_create_config)
62+
create_config = EasyDict(pong_ppo_create_config)
6363

6464
if __name__ == "__main__":
65-
# or you can enter `ding -m serial_onpolicy -c pong_onppo_config.py -s 0`
65+
# or you can enter `ding -m serial_onpolicy -c pong_ppo_config.py -s 0`
6666
from ding.entry import serial_pipeline_onpolicy
6767
serial_pipeline_onpolicy((main_config, create_config), seed=0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
from easydict import EasyDict

# On-policy PPO on Atari Pong under multi-GPU DistributedDataParallel (DDP)
# training. ``multi_gpu=True`` makes the learner all-reduce gradients across
# ranks; launch with torch.distributed (see the command in ``__main__`` below).
pong_ppo_config = dict(
    exp_name='data_pong/pong_ppo_ddp_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=8,
        n_evaluator_episode=8,
        stop_value=20,
        env_id='PongNoFrameskip-v4',
        # 'ALE/Pong-v5' is also available, but it needs special settings after gym make.
        frame_stack=4,
    ),
    policy=dict(
        multi_gpu=True,  # enable gradient all-reduce across DDP ranks
        cuda=True,
        recompute_adv=True,
        action_space='discrete',
        model=dict(
            obs_shape=[4, 84, 84],
            action_shape=6,
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            actor_head_hidden_size=128,
            critic_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=10,
            update_per_collect=1,
            batch_size=320,
            learning_rate=3e-4,
            value_weight=0.5,
            entropy_weight=0.001,
            clip_ratio=0.2,
            adv_norm=True,
            value_norm=True,
            # For PPO, when we recompute adv, we need the key done in data to split traj, so we must
            # use ignore_done=False here,
            # but when we add key traj_flag in data as the backup for key done, we could choose to use ignore_done=True
            # for halfcheetah, the length=1000
            ignore_done=False,
            grad_clip_type='clip_norm',
            grad_clip_value=0.5,
        ),
        collect=dict(
            n_sample=3200,
            unroll_len=1,
            discount_factor=0.99,
            gae_lambda=0.95,
        ),
        eval=dict(evaluator=dict(eval_freq=1000, )),
    ),
)
main_config = EasyDict(pong_ppo_config)

pong_ppo_create_config = dict(
    env=dict(
        type='atari',
        import_names=['dizoo.atari.envs.atari_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='ppo'),
)
create_config = EasyDict(pong_ppo_create_config)

if __name__ == "__main__":
    """
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
        python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/config/serial/pong/pong_ppo_ddp_config.py
    """
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline_onpolicy
    with DDPContext():
        serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(3e6))

dizoo/atari/example/atari_dqn_ddp.py

+6
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,10 @@ def main():
5656

5757

5858
if __name__ == "__main__":
59+
"""
60+
Overview:
61+
This script should be executed with <nproc_per_node> GPUs.
62+
Run the following command to launch the script:
63+
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/example/atari_dqn_ddp.py
64+
"""
5965
main()

dizoo/atari/example/atari_ppo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
gae_estimator, termination_checker
1212
from ding.utils import set_pkg_seed
1313
from dizoo.atari.envs.atari_env import AtariEnv
14-
from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
14+
from dizoo.atari.config.serial.pong.pong_ppo_config import main_config, create_config
1515

1616

1717
def main():

dizoo/atari/example/atari_ppo_ddp.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
from ding.framework.context import OnlineRLContext
1010
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
1111
gae_estimator, ddp_termination_checker, online_logger
12-
from ding.utils import set_pkg_seed, DistContext, get_rank, get_world_size
12+
from ding.utils import set_pkg_seed, DDPContext, get_rank, get_world_size
1313
from dizoo.atari.envs.atari_env import AtariEnv
14-
from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
14+
from dizoo.atari.config.serial.pong.pong_ppo_config import main_config, create_config
1515

1616

1717
def main():
1818
logging.getLogger().setLevel(logging.INFO)
19-
with DistContext():
19+
with DDPContext():
2020
rank, world_size = get_rank(), get_world_size()
2121
main_config.example = 'pong_ppo_seed0_ddp_avgsplit'
2222
main_config.policy.multi_gpu = True
@@ -45,12 +45,19 @@ def main():
4545
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
4646
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
4747
task.use(gae_estimator(cfg, policy.collect_mode))
48-
task.use(multistep_trainer(cfg, policy.learn_mode))
48+
task.use(multistep_trainer(policy.learn_mode))
4949
if rank == 0:
5050
task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
51+
task.use(online_logger(record_train_iter=True))
5152
task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank))
5253
task.run()
5354

5455

5556
if __name__ == "__main__":
57+
"""
58+
Overview:
59+
This script should be executed with <nproc_per_node> GPUs.
60+
Run the following command to launch the script:
61+
python -m torch.distributed.launch --nproc_per_node=2 ./dizoo/atari/example/atari_ppo_ddp.py
62+
"""
5663
main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
from easydict import EasyDict

# DQN on CartPole under multi-GPU DistributedDataParallel (DDP) training.
# ``multi_gpu=True`` makes the learner all-reduce gradients across ranks;
# launch with torch.distributed (see the command in ``__main__`` below).
cartpole_dqn_config = dict(
    # Fix: this is the DDP config, so use a distinct experiment name
    # ('..._ddp_seed0', matching the pong DDP configs) instead of reusing
    # 'cartpole_dqn_seed0' — otherwise logs/checkpoints collide with the
    # non-DDP cartpole_dqn run.
    exp_name='cartpole_dqn_ddp_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        stop_value=195,
        replay_path='cartpole_dqn_ddp_seed0/video',
    ),
    policy=dict(
        multi_gpu=True,  # enable gradient all-reduce across DDP ranks
        cuda=True,
        model=dict(
            obs_shape=4,
            action_shape=2,
            encoder_hidden_size_list=[128, 128, 64],
            dueling=True,
            # dropout=0.1,
        ),
        nstep=1,
        discount_factor=0.97,
        learn=dict(
            update_per_collect=5,
            batch_size=64,
            learning_rate=0.001,
        ),
        collect=dict(n_sample=8),
        eval=dict(evaluator=dict(eval_freq=40, )),
        other=dict(
            # Epsilon-greedy exploration: exponential decay from 0.95 to 0.1.
            eps=dict(
                type='exp',
                start=0.95,
                end=0.1,
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=20000, ),
        ),
    ),
)
cartpole_dqn_config = EasyDict(cartpole_dqn_config)
main_config = cartpole_dqn_config
cartpole_dqn_create_config = dict(
    env=dict(
        type='cartpole',
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='dqn'),
)
cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config)
create_config = cartpole_dqn_create_config

if __name__ == "__main__":
    """
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
        python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./dizoo/classic_control/cartpole/config/cartpole_dqn_ddp_config.py
    """
    from ding.utils import DDPContext
    from ding.entry import serial_pipeline
    with DDPContext():
        serial_pipeline((main_config, create_config), seed=0)

0 commit comments

Comments
 (0)