train.py
# Created by: Mateus Gonçalves Machado
# Based on: https://docs.cleanrl.dev/ (by Shengyi Huang)
import time

import gymnasium as gym
import numpy as np

from methods.sac import SAC, SACStrat
from utils.experiment import get_experiment, make_env, parse_args, setup_run
from utils.logger import SACLogger
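

# Trains a SAC agent: builds vectorized environments, fills the replay buffer
# with environment interactions, and updates the networks on a fixed schedule.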
def train(args, exp_name, logger: SACLogger):
    envs = gym.vector.AsyncVectorEnv(
        [make_env(args, i, exp_name) for i in range(args.num_envs)]
    )
    if args.stratified:
        agent = SACStrat(
            args,
            envs.single_observation_space,
            envs.single_action_space,
        )
    else:
        agent = SAC(args, envs.single_observation_space, envs.single_action_space)
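
    # Main loop: sample random actions before learning_starts, then act with the policy.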
    obs, _ = envs.reset()
    for global_step in range(args.total_timesteps):
        if global_step < args.learning_starts:
            actions = np.array(
                [envs.single_action_space.sample() for _ in range(args.num_envs)]
            )
        else:
            actions = agent.get_action(obs)
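
        # Step every environment and record finished-episode statistics.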
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)
        logger.log_episode(infos, rewards)

        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        agent.replay_buffer.add(obs, actions, rewards, real_next_obs, terminations)

        obs = next_obs
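
        # When DyLam is enabled, track episode rewards and update the lambda weights.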
        if args.dylam:
            agent.add_episode_rewards(rewards, terminations, truncations)
            agent.update_lambdas()

        # ALGO LOGIC: training.
        if (
            global_step > args.learning_starts
            and global_step % args.update_frequency == 0
        ):
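            # update_actor gates the policy update so the actor trains only every policy_frequency steps.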
            update_actor = global_step % args.policy_frequency == 0
            losses = agent.update(args.batch_size, update_actor)

            if global_step % args.target_network_frequency == 0:
                agent.critic_target.sync(args.tau)
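
            # Log losses (and lambdas when DyLam is on) every 100 steps.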
            if global_step % 100 == 0:
                loss_dict = {
                    "policy_loss": losses[0],
                    "qf1_loss": losses[1],
                    "qf2_loss": losses[2],
                    "alpha": agent.alpha,
                    "alpha_loss": losses[3],
                }
                logger.log_losses(loss_dict)
                if args.dylam:
                    logger.log_lambdas(agent.lambdas)

        logger.push(global_step)
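        # Periodically checkpoint the model and log it as an artifact.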
        if global_step % 9999 == 0:
            agent.save(f"models/{exp_name}/")
            logger.log_artifact()

    envs.close()
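

# Builds the experiment name from the gym id and setup, prepares the run, and starts training.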
def main(params):
    gym_name = params.gym_id.split("-")[1]
    exp_name = f"{gym_name}-{params.setup}_{int(time.time())}"
    logger = SACLogger(exp_name, params)
    setup_run(params)
    train(params, exp_name, logger)


if __name__ == "__main__":
    args = parse_args()
    params = get_experiment(args)
    main(params)