import math

import torch
import genesis as gs
from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat


def gs_rand_float(lower, upper, shape, device):
    # sample uniformly from [lower, upper) with the given shape on the given device
    return (upper - lower) * torch.rand(size=shape, device=device) + lower


class ZerothEnv:
    def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="mps"):
        self.device = torch.device(device)

        self.num_envs = num_envs
        self.num_obs = obs_cfg["num_obs"]
        self.num_privileged_obs = None
        self.num_actions = env_cfg["num_actions"]
        self.num_commands = command_cfg["num_commands"]

        self.simulate_action_latency = True  # simulate the 1-step action latency of the real robot
        self.dt = 0.02  # the control frequency on the real robot is 50 Hz
        self.max_episode_length = math.ceil(env_cfg["episode_length_s"] / self.dt)

        self.env_cfg = env_cfg
        self.obs_cfg = obs_cfg
        self.reward_cfg = reward_cfg
        self.command_cfg = command_cfg

        self.obs_scales = obs_cfg["obs_scales"]
        self.reward_scales = reward_cfg["reward_scales"]

        # create scene
        self.scene = gs.Scene(
            sim_options=gs.options.SimOptions(dt=self.dt, substeps=2),
            viewer_options=gs.options.ViewerOptions(
                max_FPS=int(0.5 / self.dt),
                camera_pos=(2.0, 0.0, 2.5),
                camera_lookat=(0.0, 0.0, 0.5),
                camera_fov=40,
            ),
            vis_options=gs.options.VisOptions(n_rendered_envs=1),
            rigid_options=gs.options.RigidOptions(
                dt=self.dt,
                constraint_solver=gs.constraint_solver.Newton,
                enable_collision=True,
                enable_joint_limit=True,
            ),
            show_viewer=show_viewer,
        )

        # add ground plane
        self.scene.add_entity(gs.morphs.URDF(file="urdf/plane/plane.urdf", fixed=True))

        # add robot
        self.base_init_pos = torch.tensor(self.env_cfg["base_init_pos"], device=self.device)
        self.base_init_quat = torch.tensor(self.env_cfg["base_init_quat"], device=self.device)
        self.inv_base_init_quat = inv_quat(self.base_init_quat)
        self.robot = self.scene.add_entity(
            gs.morphs.URDF(
                # file="urdf/go2/urdf/go2.urdf",
                file="../resources/stompymicro/robot_fixed.urdf",
                pos=self.base_init_pos.cpu().numpy(),
                quat=self.base_init_quat.cpu().numpy(),
            ),
        )

        # build the scene for all parallel environments
        self.scene.build(n_envs=num_envs)

        # map joint names to local DOF indices
        self.motor_dofs = [self.robot.get_joint(name).dof_idx_local for name in self.env_cfg["dof_names"]]

        # PD control parameters
        self.robot.set_dofs_kp([self.env_cfg["kp"]] * self.num_actions, self.motor_dofs)
        self.robot.set_dofs_kv([self.env_cfg["kd"]] * self.num_actions, self.motor_dofs)

        # prepare reward functions and multiply reward scales by dt
        self.reward_functions, self.episode_sums = dict(), dict()
        for name in self.reward_scales.keys():
            self.reward_scales[name] *= self.dt
            self.reward_functions[name] = getattr(self, "_reward_" + name)
            self.episode_sums[name] = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)

        # initialize buffers
        self.base_lin_vel = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
        self.base_ang_vel = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
        self.projected_gravity = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
        self.global_gravity = torch.tensor([0.0, 0.0, -1.0], device=self.device, dtype=gs.tc_float).repeat(
            self.num_envs, 1
        )
        self.obs_buf = torch.zeros((self.num_envs, self.num_obs), device=self.device, dtype=gs.tc_float)
        self.rew_buf = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
        self.reset_buf = torch.ones((self.num_envs,), device=self.device, dtype=gs.tc_int)
        self.episode_length_buf = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_int)
        self.commands = torch.zeros((self.num_envs, self.num_commands), device=self.device, dtype=gs.tc_float)
        self.commands_scale = torch.tensor(
            [self.obs_scales["lin_vel"], self.obs_scales["lin_vel"], self.obs_scales["ang_vel"]],
            device=self.device,
            dtype=gs.tc_float,
        )
        self.actions = torch.zeros((self.num_envs, self.num_actions), device=self.device, dtype=gs.tc_float)
        self.last_actions = torch.zeros_like(self.actions)
        self.dof_pos = torch.zeros_like(self.actions)
        self.dof_vel = torch.zeros_like(self.actions)
        self.last_dof_vel = torch.zeros_like(self.actions)
        self.base_pos = torch.zeros((self.num_envs, 3), device=self.device, dtype=gs.tc_float)
        self.base_quat = torch.zeros((self.num_envs, 4), device=self.device, dtype=gs.tc_float)
        self.default_dof_pos = torch.tensor(
            [self.env_cfg["default_joint_angles"][name] for name in self.env_cfg["dof_names"]],
            device=self.device,
            dtype=gs.tc_float,
        )
        self.extras = dict()  # extra information for logging

    def _resample_commands(self, envs_idx):
        self.commands[envs_idx, 0] = gs_rand_float(*self.command_cfg["lin_vel_x_range"], (len(envs_idx),), self.device)
        self.commands[envs_idx, 1] = gs_rand_float(*self.command_cfg["lin_vel_y_range"], (len(envs_idx),), self.device)
        self.commands[envs_idx, 2] = gs_rand_float(*self.command_cfg["ang_vel_range"], (len(envs_idx),), self.device)

    def step(self, actions):
        self.actions = torch.clip(actions, -self.env_cfg["clip_actions"], self.env_cfg["clip_actions"])
        exec_actions = self.last_actions if self.simulate_action_latency else self.actions
        target_dof_pos = exec_actions * self.env_cfg["action_scale"] + self.default_dof_pos
        self.robot.control_dofs_position(target_dof_pos, self.motor_dofs)
        self.scene.step()

        # update buffers
        self.episode_length_buf += 1
        self.base_pos[:] = self.robot.get_pos()
        self.base_quat[:] = self.robot.get_quat()
        self.base_euler = quat_to_xyz(
            transform_quat_by_quat(torch.ones_like(self.base_quat) * self.inv_base_init_quat, self.base_quat)
        )
        inv_base_quat = inv_quat(self.base_quat)
        self.base_lin_vel[:] = transform_by_quat(self.robot.get_vel(), inv_base_quat)
        self.base_ang_vel[:] = transform_by_quat(self.robot.get_ang(), inv_base_quat)
        self.projected_gravity = transform_by_quat(self.global_gravity, inv_base_quat)
        self.dof_pos[:] = self.robot.get_dofs_position(self.motor_dofs)
        self.dof_vel[:] = self.robot.get_dofs_velocity(self.motor_dofs)

        # resample commands
        envs_idx = (
            (self.episode_length_buf % int(self.env_cfg["resampling_time_s"] / self.dt) == 0)
            .nonzero(as_tuple=False)
            .flatten()
        )
        self._resample_commands(envs_idx)

        # check termination and reset
        self.reset_buf = self.episode_length_buf > self.max_episode_length
        self.reset_buf |= torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]
        self.reset_buf |= torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]

        time_out_idx = (self.episode_length_buf > self.max_episode_length).nonzero(as_tuple=False).flatten()
        self.extras["time_outs"] = torch.zeros_like(self.reset_buf, device=self.device, dtype=gs.tc_float)
        self.extras["time_outs"][time_out_idx] = 1.0

        self.reset_idx(self.reset_buf.nonzero(as_tuple=False).flatten())

        # compute reward
        self.rew_buf[:] = 0.0
        for name, reward_func in self.reward_functions.items():
            rew = reward_func() * self.reward_scales[name]
            self.rew_buf += rew
            self.episode_sums[name] += rew

        # compute observations
        self.obs_buf = torch.cat(
            [
                self.base_ang_vel * self.obs_scales["ang_vel"],  # 3
                self.projected_gravity,  # 3
                self.commands * self.commands_scale,  # 3
                (self.dof_pos - self.default_dof_pos) * self.obs_scales["dof_pos"],  # num_actions
                self.dof_vel * self.obs_scales["dof_vel"],  # num_actions
                self.actions,  # num_actions
            ],
            dim=-1,
        )

        self.last_actions[:] = self.actions[:]
        self.last_dof_vel[:] = self.dof_vel[:]

        return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras

    def get_observations(self):
        return self.obs_buf

    def get_privileged_observations(self):
        return None

    def reset_idx(self, envs_idx):
        if len(envs_idx) == 0:
            return

        # reset dofs
        self.dof_pos[envs_idx] = self.default_dof_pos
        self.dof_vel[envs_idx] = 0.0
        self.robot.set_dofs_position(
            position=self.dof_pos[envs_idx],
            dofs_idx_local=self.motor_dofs,
            zero_velocity=True,
            envs_idx=envs_idx,
        )

        # reset base
        self.base_pos[envs_idx] = self.base_init_pos
        self.base_quat[envs_idx] = self.base_init_quat.reshape(1, -1)
        self.robot.set_pos(self.base_pos[envs_idx], zero_velocity=False, envs_idx=envs_idx)
        self.robot.set_quat(self.base_quat[envs_idx], zero_velocity=False, envs_idx=envs_idx)
        self.base_lin_vel[envs_idx] = 0
        self.base_ang_vel[envs_idx] = 0
        self.robot.zero_all_dofs_velocity(envs_idx)

        # reset buffers
        self.last_actions[envs_idx] = 0.0
        self.last_dof_vel[envs_idx] = 0.0
        self.episode_length_buf[envs_idx] = 0
        self.reset_buf[envs_idx] = True

        # fill extras
        self.extras["episode"] = {}
        for key in self.episode_sums.keys():
            self.extras["episode"]["rew_" + key] = (
                torch.mean(self.episode_sums[key][envs_idx]).item() / self.env_cfg["episode_length_s"]
            )
            self.episode_sums[key][envs_idx] = 0.0

        self._resample_commands(envs_idx)

    def reset(self):
        self.reset_buf[:] = True
        self.reset_idx(torch.arange(self.num_envs, device=self.device))
        return self.obs_buf, None

    # ------------ reward functions ----------------
    def _reward_tracking_lin_vel(self):
        # Tracking of linear velocity commands (xy axes)
        lin_vel_error = torch.sum(torch.square(self.commands[:, :2] - self.base_lin_vel[:, :2]), dim=1)
        return torch.exp(-lin_vel_error / self.reward_cfg["tracking_sigma"])

    def _reward_tracking_ang_vel(self):
        # Tracking of angular velocity commands (yaw)
        ang_vel_error = torch.square(self.commands[:, 2] - self.base_ang_vel[:, 2])
        return torch.exp(-ang_vel_error / self.reward_cfg["tracking_sigma"])

    def _reward_lin_vel_z(self):
        # Penalize z axis base linear velocity
        return torch.square(self.base_lin_vel[:, 2])

    def _reward_action_rate(self):
        # Penalize changes in actions
        return torch.sum(torch.square(self.last_actions - self.actions), dim=1)

    def _reward_similar_to_default(self):
        # Penalize joint poses far away from the default pose
        return torch.sum(torch.abs(self.dof_pos - self.default_dof_pos), dim=1)

    def _reward_base_height(self):
        # Penalize base height away from target
        return torch.square(self.base_pos[:, 2] - self.reward_cfg["base_height_target"])

    def _reward_gait_symmetry(self):
        # Reward symmetric gait patterns
        left_hip = self.dof_pos[:, self.env_cfg["dof_names"].index("left_hip_pitch")]
        right_hip = self.dof_pos[:, self.env_cfg["dof_names"].index("right_hip_pitch")]
        left_knee = self.dof_pos[:, self.env_cfg["dof_names"].index("left_knee_pitch")]
        right_knee = self.dof_pos[:, self.env_cfg["dof_names"].index("right_knee_pitch")]

        hip_symmetry = torch.abs(left_hip - right_hip)
        knee_symmetry = torch.abs(left_knee - right_knee)

        return torch.exp(-(hip_symmetry + knee_symmetry))

    def _reward_energy_efficiency(self):
        # Reward energy efficiency by penalizing high joint velocities
        return -torch.sum(torch.square(self.dof_vel), dim=1)
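

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, assumption-heavy example of how ZerothEnv might be constructed and
# stepped. Every value below is a placeholder: the joint names must match the
# stompymicro URDF, the gains/ranges/scales come from the training config, and
# the backend passed to gs.init() should suit the machine at hand.
if __name__ == "__main__":
    gs.init(backend=gs.cpu)  # assumption: CPU backend; use a GPU/Metal backend if available

    num_actions = 10  # hypothetical DOF count; must equal len(dof_names)
    dof_names = [f"joint_{i}" for i in range(num_actions)]  # placeholders, not real joint names

    env_cfg = {
        "num_actions": num_actions,
        "dof_names": dof_names,
        "default_joint_angles": {name: 0.0 for name in dof_names},
        "kp": 20.0,
        "kd": 0.5,
        "base_init_pos": [0.0, 0.0, 0.4],
        "base_init_quat": [1.0, 0.0, 0.0, 0.0],  # w, x, y, z
        "episode_length_s": 20.0,
        "resampling_time_s": 4.0,
        "action_scale": 0.25,
        "clip_actions": 100.0,
        "termination_if_roll_greater_than": 10.0,   # hypothetical threshold; units must match quat_to_xyz output
        "termination_if_pitch_greater_than": 10.0,
    }
    obs_cfg = {
        # 3 (ang vel) + 3 (gravity) + 3 (commands) + 3 * num_actions (dof pos, dof vel, actions)
        "num_obs": 9 + 3 * num_actions,
        "obs_scales": {"lin_vel": 2.0, "ang_vel": 0.25, "dof_pos": 1.0, "dof_vel": 0.05},
    }
    reward_cfg = {
        "tracking_sigma": 0.25,
        "base_height_target": 0.3,
        # each key must have a matching _reward_<name> method above
        "reward_scales": {"tracking_lin_vel": 1.0, "tracking_ang_vel": 0.2, "lin_vel_z": -1.0},
    }
    command_cfg = {
        "num_commands": 3,
        "lin_vel_x_range": [-0.5, 0.5],
        "lin_vel_y_range": [-0.5, 0.5],
        "ang_vel_range": [-0.5, 0.5],
    }

    env = ZerothEnv(
        num_envs=4,
        env_cfg=env_cfg,
        obs_cfg=obs_cfg,
        reward_cfg=reward_cfg,
        command_cfg=command_cfg,
        device="cpu",
    )
    obs, _ = env.reset()
    for _ in range(100):
        # zero actions hold the default pose; a policy would supply these instead
        actions = torch.zeros((env.num_envs, env.num_actions), device=env.device)
        obs, _, rew, reset, extras = env.step(actions)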