
Commit ee255f9

update rsl_rl to v2.0.1
1 parent 5018186 commit ee255f9

5 files changed: +30 -42 lines

examples/model_100.pt

0 Bytes, binary file not shown.

sim/genesis/README.md

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 # Install rsl_rl.
 ```
 git clone https://github.com/leggedrobotics/rsl_rl
-cd rsl_rl && git checkout v1.0.2 && pip install -e .
+cd rsl_rl && git checkout v2.0.1 && pip install -e .
 ```

 # Install tensorboard.

sim/genesis/zeroth_env.py

Lines changed: 7 additions & 2 deletions
@@ -181,10 +181,15 @@ def step(self, actions):
         self.last_actions[:] = self.actions[:]
         self.last_dof_vel[:] = self.dof_vel[:]

-        return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras
+        return self.obs_buf, self.rew_buf, self.reset_buf, {
+            "observations": {
+                "critic": self.obs_buf
+            },
+            **self.extras
+        }

     def get_observations(self):
-        return self.obs_buf
+        return self.obs_buf, {"observations": {"critic": self.obs_buf}}

     def get_privileged_observations(self):
         return None
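These two return-signature changes track the environment interface expected by rsl_rl 2.x: step() now returns (obs, rewards, dones, infos), and both step() and get_observations() carry the critic observations inside an "observations" dict. A minimal sketch of how a 2.x-style on-policy rollout loop would consume these returns; the names policy.act and the step count are illustrative assumptions, not taken from this repository:

```python
# Sketch only: reading the new return signatures from a rollout loop.
obs, extras = env.get_observations()
critic_obs = extras["observations"].get("critic", obs)

for _ in range(48):  # num_steps_per_env, as set in train_cfg
    actions = policy.act(obs, critic_obs)           # hypothetical policy call
    obs, rewards, dones, infos = env.step(actions)  # new 4-tuple signature
    critic_obs = infos["observations"].get("critic", obs)
```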

sim/genesis/zeroth_eval.py

Lines changed: 5 additions & 1 deletion
@@ -11,7 +11,7 @@
 def run_sim(env, policy, obs):
     while True:
         actions = policy(obs)
-        obs, _, rews, dones, infos = env.step(actions)
+        obs, _, _, _ = env.step(actions)

 def main():
     parser = argparse.ArgumentParser()
@@ -23,6 +23,10 @@ def main():

     log_dir = f"logs/{args.exp_name}"
     env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(open(f"logs/{args.exp_name}/cfgs.pkl", "rb"))
+    # Add missing class_name fields
+    train_cfg["algorithm"]["class_name"] = "PPO"
+    train_cfg["policy"]["class_name"] = "ActorCritic"
+    print("train_cfg:", train_cfg) # Add debug print
     reward_cfg["reward_scales"] = {}

     env = ZerothEnv(
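The patched-in fields matter because an older cfgs.pkl was written for rsl_rl v1.x, which named classes under the runner section (policy_class_name, algorithm_class_name), whereas v2.x looks up a class_name key inside the policy and algorithm dicts. A rough sketch of that lookup pattern, illustrative only and not copied from the library, which resolves the classes internally:

```python
# Illustrative only: resolve the classes named in the patched train_cfg.
from rsl_rl.algorithms import PPO
from rsl_rl.modules import ActorCritic

_CLASSES = {"PPO": PPO, "ActorCritic": ActorCritic}
policy_cls = _CLASSES[train_cfg["policy"]["class_name"]]        # ActorCritic
algorithm_cls = _CLASSES[train_cfg["algorithm"]["class_name"]]  # PPO
```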

sim/genesis/zeroth_train.py

Lines changed: 17 additions & 38 deletions
@@ -11,6 +11,9 @@
 def get_train_cfg(exp_name, max_iterations):

     train_cfg_dict = {
+        "num_steps_per_env": 48,
+        "save_interval": 10,
+        "empirical_normalization": True,  # helps stabilize training, especially when observation values vary over a wide range
         "algorithm": {
             "clip_param": 0.2,
             "desired_kl": 0.01,
@@ -24,42 +27,27 @@ def get_train_cfg(exp_name, max_iterations):
             "schedule": "adaptive",
             "use_clipped_value_loss": True,
             "value_loss_coef": 1.0,
+            "class_name": "PPO",
         },
-        "init_member_classes": {},
         "policy": {
             "activation": "elu",
             "actor_hidden_dims": [512, 256, 128],
             "critic_hidden_dims": [512, 256, 128],
             "init_noise_std": 1.0,
+            "class_name": "ActorCritic",
         },
         "runner": {
             "algorithm_class_name": "PPO",
-            "checkpoint": -1,
             "experiment_name": exp_name,
-            "load_run": -1,
-            "log_interval": 1,
-            "max_iterations": max_iterations,
-            "num_steps_per_env": 48,
-            "policy_class_name": "ActorCritic",
-            "record_interval": -1,
-            "resume": False,
-            "resume_path": None,
-            "run_name": "",
-            "runner_class_name": "OnPolicyRunner",
-            "save_interval": 10,
-        },
-        "runner_class_name": "OnPolicyRunner",
-        "seed": 1,
+            "run_name": "zeroth-walking",
+        }
     }

     return train_cfg_dict


 def get_cfgs():
-    env_cfg = {
-        "num_actions": 12,
-        # joint/link names
-        "default_joint_angles": { # [rad]
+    default_joint_angles={ # [rad]
             "left_elbow_yaw": 3.14,
             "right_elbow_yaw": 3.14,
             "right_hip_pitch": 0.0,
@@ -72,21 +60,12 @@ def get_cfgs():
             "left_knee_pitch": 0.0,
             "right_ankle_pitch": 0.0,
             "left_ankle_pitch": 0.0,
-        },
-        "dof_names": [
-            "left_elbow_yaw",
-            "right_elbow_yaw",
-            "right_hip_pitch",
-            "left_hip_pitch",
-            "right_hip_yaw",
-            "left_hip_yaw",
-            "right_hip_roll",
-            "left_hip_roll",
-            "right_knee_pitch",
-            "left_knee_pitch",
-            "right_ankle_pitch",
-            "left_ankle_pitch",
-        ],
+    }
+    env_cfg = {
+        "num_actions": 12,
+        # joint/link names
+        "default_joint_angles": default_joint_angles,
+        "dof_names": list(default_joint_angles.keys()),
         # PD
         "kp": 20.0,
         "kd": 0.5,
@@ -128,8 +107,8 @@ def get_cfgs():
         "num_commands": 3,
         # "lin_vel_y_range": [-0.5, -0.5], # move forward slowly
         "lin_vel_y_range": [-0.6, -0.6], # move faster than above!
-        "lin_vel_x_range": [0, 0],
-        "ang_vel_range": [0, 0],
+        "lin_vel_x_range": [-0.01, 0.01],
+        "ang_vel_range": [-0.01, 0.01],
     }

     return env_cfg, obs_cfg, reward_cfg, command_cfg
@@ -139,7 +118,7 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("-e", "--exp_name", type=str, default="zeroth-walking")
    parser.add_argument("-B", "--num_envs", type=int, default=4096)
-    parser.add_argument("--max_iterations", type=int, default=100)
+    parser.add_argument("--max_iterations", type=int, default=101)
     args = parser.parse_args()

     gs.init(logging_level="warning")
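With the flattened cfg layout above (num_steps_per_env, save_interval, and empirical_normalization at the top level, class_name inside the algorithm and policy dicts), the cfg is consumed by the usual rsl_rl 2.x runner call. A hedged sketch of that call, since the exact lines in zeroth_train.py are not shown in this diff; the device string and init_at_random_ep_len flag are assumptions:

```python
# Sketch of the typical rsl_rl 2.x training entry point these cfgs feed into.
from rsl_rl.runners import OnPolicyRunner

runner = OnPolicyRunner(env, train_cfg, log_dir, device="cuda:0")
runner.learn(num_learning_iterations=args.max_iterations, init_at_random_ep_len=True)
```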
