1111def get_train_cfg (exp_name , max_iterations ):
1212
1313 train_cfg_dict = {
14+ "num_steps_per_env" : 48 ,
15+ "save_interval" : 10 ,
16+ "empirical_normalization" : True , # helps stabilize training, especially when observation value ranges vary widely
1417 "algorithm" : {
1518 "clip_param" : 0.2 ,
1619 "desired_kl" : 0.01 ,
@@ -24,42 +27,27 @@ def get_train_cfg(exp_name, max_iterations):
2427 "schedule" : "adaptive" ,
2528 "use_clipped_value_loss" : True ,
2629 "value_loss_coef" : 1.0 ,
30+ "class_name" : "PPO" ,
2731 },
28- "init_member_classes" : {},
2932 "policy" : {
3033 "activation" : "elu" ,
3134 "actor_hidden_dims" : [512 , 256 , 128 ],
3235 "critic_hidden_dims" : [512 , 256 , 128 ],
3336 "init_noise_std" : 1.0 ,
37+ "class_name" : "ActorCritic" ,
3438 },
3539 "runner" : {
3640 "algorithm_class_name" : "PPO" ,
37- "checkpoint" : - 1 ,
3841 "experiment_name" : exp_name ,
39- "load_run" : - 1 ,
40- "log_interval" : 1 ,
41- "max_iterations" : max_iterations ,
42- "num_steps_per_env" : 48 ,
43- "policy_class_name" : "ActorCritic" ,
44- "record_interval" : - 1 ,
45- "resume" : False ,
46- "resume_path" : None ,
47- "run_name" : "" ,
48- "runner_class_name" : "OnPolicyRunner" ,
49- "save_interval" : 10 ,
50- },
51- "runner_class_name" : "OnPolicyRunner" ,
52- "seed" : 1 ,
42+ "run_name" : "zeroth-walking" ,
43+ }
5344 }
5445
5546 return train_cfg_dict
5647
5748
5849def get_cfgs ():
59- env_cfg = {
60- "num_actions" : 12 ,
61- # joint/link names
62- "default_joint_angles" : { # [rad]
50+ default_joint_angles = { # [rad]
6351 "left_elbow_yaw" : 3.14 ,
6452 "right_elbow_yaw" : 3.14 ,
6553 "right_hip_pitch" : 0.0 ,
@@ -72,21 +60,12 @@ def get_cfgs():
7260 "left_knee_pitch" : 0.0 ,
7361 "right_ankle_pitch" : 0.0 ,
7462 "left_ankle_pitch" : 0.0 ,
75- },
76- "dof_names" : [
77- "left_elbow_yaw" ,
78- "right_elbow_yaw" ,
79- "right_hip_pitch" ,
80- "left_hip_pitch" ,
81- "right_hip_yaw" ,
82- "left_hip_yaw" ,
83- "right_hip_roll" ,
84- "left_hip_roll" ,
85- "right_knee_pitch" ,
86- "left_knee_pitch" ,
87- "right_ankle_pitch" ,
88- "left_ankle_pitch" ,
89- ],
63+ }
64+ env_cfg = {
65+ "num_actions" : 12 ,
66+ # joint/link names
67+ "default_joint_angles" : default_joint_angles ,
68+ "dof_names" : list (default_joint_angles .keys ()),
9069 # PD
9170 "kp" : 20.0 ,
9271 "kd" : 0.5 ,
@@ -128,8 +107,8 @@ def get_cfgs():
128107 "num_commands" : 3 ,
129108 # "lin_vel_y_range": [-0.5, -0.5], # move forward slowly
130109 "lin_vel_y_range" : [- 0.6 , - 0.6 ], # move faster than above!
131- "lin_vel_x_range" : [0 , 0 ],
132- "ang_vel_range" : [0 , 0 ],
110+ "lin_vel_x_range" : [- 0.01 , 0.01 ],
111+ "ang_vel_range" : [- 0.01 , 0.01 ],
133112 }
134113
135114 return env_cfg , obs_cfg , reward_cfg , command_cfg
@@ -139,7 +118,7 @@ def main():
139118 parser = argparse .ArgumentParser ()
140119 parser .add_argument ("-e" , "--exp_name" , type = str , default = "zeroth-walking" )
141120 parser .add_argument ("-B" , "--num_envs" , type = int , default = 4096 )
142- parser .add_argument ("--max_iterations" , type = int , default = 100 )
121+ parser .add_argument ("--max_iterations" , type = int , default = 101 )
143122 args = parser .parse_args ()
144123
145124 gs .init (logging_level = "warning" )
0 commit comments