Commit

update rsl_rl to v2.0.1

BigJohnn committed Jan 19, 2025
1 parent 5018186 commit 96cd96b
Showing 6 changed files with 31 additions and 43 deletions.
Binary file modified examples/model_100.pt
Binary file not shown.
2 changes: 1 addition & 1 deletion sim/genesis/README.md
@@ -4,7 +4,7 @@
 # Install rsl_rl.
 ```
 git clone https://github.com/leggedrobotics/rsl_rl
-cd rsl_rl && git checkout v1.0.2 && pip install -e .
+cd rsl_rl && git checkout v2.0.1 && pip install -e .
 ```
 
 # Install TensorBoard.
2 changes: 1 addition & 1 deletion sim/genesis/README_EN.md
@@ -10,7 +10,7 @@ cd sim/genesis
 ### Install rsl_rl
 ```bash
 git clone https://github.com/leggedrobotics/rsl_rl
-cd rsl_rl && git checkout v1.0.2 && pip install -e .
+cd rsl_rl && git checkout v2.0.1 && pip install -e .
 ```
 
 ### Install TensorBoard
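Both READMEs now pin the same tag, so the editable install tracks whatever commit is checked out in the rsl_rl clone. A quick post-install sanity check — a minimal sketch, assuming Python 3.8+ and that the package was installed under its default name `rsl_rl`:

```python
# Confirm the editable install of rsl_rl picked up the pinned 2.0.x tag.
from importlib import metadata

version = metadata.version("rsl_rl")
print("rsl_rl version:", version)
assert version.startswith("2.0"), f"expected rsl_rl 2.0.x, got {version}"
```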
9 changes: 7 additions & 2 deletions sim/genesis/zeroth_env.py
@@ -181,10 +181,15 @@ def step(self, actions):
         self.last_actions[:] = self.actions[:]
         self.last_dof_vel[:] = self.dof_vel[:]
 
-        return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras
+        return self.obs_buf, self.rew_buf, self.reset_buf, {
+            "observations": {
+                "critic": self.obs_buf
+            },
+            **self.extras
+        }
 
     def get_observations(self):
-        return self.obs_buf
+        return self.obs_buf, {"observations": {"critic": self.obs_buf}}
 
     def get_privileged_observations(self):
         return None
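These changes track the rsl_rl 2.x environment contract: `step()` now returns four values, with critic observations carried inside the info dict rather than in a separate privileged-observation slot, and `get_observations()` returns the observation tensor plus an extras dict. A minimal stub of that contract — illustrative only, with made-up buffer sizes; the real `ZerothEnv` fills these buffers from the Genesis simulation:

```python
import torch

# Skeleton of the interface rsl_rl 2.x's OnPolicyRunner drives, mirroring
# the return shapes adopted in this commit (sizes here are arbitrary).
class EnvStub:
    def __init__(self, num_envs=4, num_obs=39):
        self.obs_buf = torch.zeros(num_envs, num_obs)
        self.rew_buf = torch.zeros(num_envs)
        self.reset_buf = torch.zeros(num_envs, dtype=torch.bool)
        self.extras = {}

    def step(self, actions):
        # Four return values; critic observations travel in the info dict.
        infos = {"observations": {"critic": self.obs_buf}, **self.extras}
        return self.obs_buf, self.rew_buf, self.reset_buf, infos

    def get_observations(self):
        # (obs, extras) instead of a bare tensor.
        return self.obs_buf, {"observations": {"critic": self.obs_buf}}
```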
6 changes: 5 additions & 1 deletion sim/genesis/zeroth_eval.py
@@ -11,7 +11,7 @@
 def run_sim(env, policy, obs):
     while True:
         actions = policy(obs)
-        obs, _, rews, dones, infos = env.step(actions)
+        obs, _, _, _ = env.step(actions)
 
 def main():
     parser = argparse.ArgumentParser()
@@ -23,6 +23,10 @@ def main():
 
     log_dir = f"logs/{args.exp_name}"
     env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(open(f"logs/{args.exp_name}/cfgs.pkl", "rb"))
+    # Add missing class_name fields
+    train_cfg["algorithm"]["class_name"] = "PPO"
+    train_cfg["policy"]["class_name"] = "ActorCritic"
+    print("train_cfg:", train_cfg)  # Add debug print
     reward_cfg["reward_scales"] = {}
 
     env = ZerothEnv(
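The backfilled `class_name` fields matter because configs pickled under rsl_rl 1.x named classes in the runner section (`policy_class_name`, `algorithm_class_name`), while 2.x instantiates them from `policy["class_name"]` and `algorithm["class_name"]`. A sketch of loading an old `cfgs.pkl` defensively — the path is a placeholder, and the `setdefault` calls are an assumption for configs that might lack the sub-dicts entirely:

```python
import pickle

# Load a config pickled under rsl_rl 1.x and backfill the fields 2.x needs.
with open("logs/zeroth-walking/cfgs.pkl", "rb") as f:  # placeholder path
    env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(f)

train_cfg.setdefault("algorithm", {})["class_name"] = "PPO"
train_cfg.setdefault("policy", {})["class_name"] = "ActorCritic"
```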
55 changes: 17 additions & 38 deletions sim/genesis/zeroth_train.py
@@ -11,6 +11,9 @@
 def get_train_cfg(exp_name, max_iterations):
 
     train_cfg_dict = {
+        "num_steps_per_env": 48,
+        "save_interval": 10,
+        "empirical_normalization": True,  # This helps stabilize training, especially when observation values vary over a wide range.
         "algorithm": {
             "clip_param": 0.2,
             "desired_kl": 0.01,
@@ -24,42 +27,27 @@
"schedule": "adaptive",
"use_clipped_value_loss": True,
"value_loss_coef": 1.0,
"class_name": "PPO",
},
"init_member_classes": {},
"policy": {
"activation": "elu",
"actor_hidden_dims": [512, 256, 128],
"critic_hidden_dims": [512, 256, 128],
"init_noise_std": 1.0,
"class_name": "ActorCritic",
},
"runner": {
"algorithm_class_name": "PPO",
"checkpoint": -1,
"experiment_name": exp_name,
"load_run": -1,
"log_interval": 1,
"max_iterations": max_iterations,
"num_steps_per_env": 48,
"policy_class_name": "ActorCritic",
"record_interval": -1,
"resume": False,
"resume_path": None,
"run_name": "",
"runner_class_name": "OnPolicyRunner",
"save_interval": 10,
},
"runner_class_name": "OnPolicyRunner",
"seed": 1,
"run_name": "zeroth-walking",
}
}

return train_cfg_dict


def get_cfgs():
env_cfg = {
"num_actions": 12,
# joint/link names
"default_joint_angles": { # [rad]
default_joint_angles={ # [rad]
"left_elbow_yaw": 3.14,
"right_elbow_yaw": 3.14,
"right_hip_pitch": 0.0,
@@ -72,21 +60,12 @@ def get_cfgs():
         "left_knee_pitch": 0.0,
         "right_ankle_pitch": 0.0,
         "left_ankle_pitch": 0.0,
-        },
-        "dof_names": [
-            "left_elbow_yaw",
-            "right_elbow_yaw",
-            "right_hip_pitch",
-            "left_hip_pitch",
-            "right_hip_yaw",
-            "left_hip_yaw",
-            "right_hip_roll",
-            "left_hip_roll",
-            "right_knee_pitch",
-            "left_knee_pitch",
-            "right_ankle_pitch",
-            "left_ankle_pitch",
-        ],
+    }
+    env_cfg = {
+        "num_actions": 12,
+        # joint/link names
+        "default_joint_angles": default_joint_angles,
+        "dof_names": list(default_joint_angles.keys()),
         # PD
         "kp": 20.0,
         "kd": 0.5,
@@ -128,8 +107,8 @@
         "num_commands": 3,
         # "lin_vel_y_range": [-0.5, -0.5],  # move forward slowly
         "lin_vel_y_range": [-0.6, -0.6],  # move faster than above!
-        "lin_vel_x_range": [0, 0],
-        "ang_vel_range": [0, 0],
+        "lin_vel_x_range": [-0.01, 0.01],
+        "ang_vel_range": [-0.01, 0.01],
     }
 
     return env_cfg, obs_cfg, reward_cfg, command_cfg
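Widening `lin_vel_x_range` and `ang_vel_range` from exact zeros to small symmetric intervals lets commands be resampled from a narrow band instead of being pinned to zero, adding mild variation to the training targets. A hypothetical sketch of uniform resampling over these ranges — the helper name is invented, and the env's actual sampling code is not part of this diff:

```python
import torch

def resample_commands(num_envs, command_cfg):
    # Uniform sample in [lo, hi) per command dimension, per environment.
    ranges = [
        command_cfg["lin_vel_x_range"],
        command_cfg["lin_vel_y_range"],
        command_cfg["ang_vel_range"],
    ]
    lo = torch.tensor([r[0] for r in ranges])
    hi = torch.tensor([r[1] for r in ranges])
    return lo + (hi - lo) * torch.rand(num_envs, len(ranges))
```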
@@ -139,7 +118,7 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("-e", "--exp_name", type=str, default="zeroth-walking")
     parser.add_argument("-B", "--num_envs", type=int, default=4096)
-    parser.add_argument("--max_iterations", type=int, default=100)
+    parser.add_argument("--max_iterations", type=int, default=101)
     args = parser.parse_args()
 
     gs.init(logging_level="warning")
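With `class_name` moved into the policy and algorithm sub-dicts, `train_cfg` can be handed directly to rsl_rl 2.x's `OnPolicyRunner`. A minimal sketch of the training entry point under that assumption — the `ZerothEnv` keyword names follow the Genesis examples and are not shown in this diff:

```python
from rsl_rl.runners import OnPolicyRunner

# Build configs as zeroth_train.py does, then drive rsl_rl 2.x directly.
train_cfg = get_train_cfg("zeroth-walking", max_iterations=101)
env_cfg, obs_cfg, reward_cfg, command_cfg = get_cfgs()

env = ZerothEnv(num_envs=4096, env_cfg=env_cfg, obs_cfg=obs_cfg,
                reward_cfg=reward_cfg, command_cfg=command_cfg)
runner = OnPolicyRunner(env, train_cfg, log_dir="logs/zeroth-walking", device="cuda:0")
runner.learn(num_learning_iterations=train_cfg["runner"]["max_iterations"])
```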
