|
| 1 | +from enum import Enum |
| 2 | +from typing import List |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | + |
| 7 | +from sim.utils.helpers import draw_vector |
| 8 | + |
| 9 | + |
| 10 | +class CommandMode(Enum): |
| 11 | + FIXED = "fixed" |
| 12 | + OSCILLATING = "oscillating" |
| 13 | + KEYBOARD = "keyboard" |
| 14 | + RANDOM = "random" |
| 15 | + |
| 16 | + |
| 17 | +class CommandManager: |
| 18 | + """Manages robot commands""" |
| 19 | + |
| 20 | + def __init__( |
| 21 | + self, |
| 22 | + num_envs: int = 1, |
| 23 | + mode: str = "fixed", |
| 24 | + default_cmd: List[float] = [0.3, 0.0, 0.0, 0.0], |
| 25 | + device="cpu", |
| 26 | + env_cfg=None, |
| 27 | + ): |
| 28 | + self.num_envs = num_envs |
| 29 | + self.mode = CommandMode(mode) |
| 30 | + self.device = device |
| 31 | + self.default_cmd = torch.tensor(default_cmd, device=self.device) |
| 32 | + self.commands = self.default_cmd.repeat(num_envs, 1) |
| 33 | + self.time = 0 |
| 34 | + self.env_cfg = env_cfg |
| 35 | + |
| 36 | + # Mode-specific parameters |
| 37 | + if self.mode == CommandMode.OSCILLATING: |
| 38 | + self.osc_period = 5.0 # secs |
| 39 | + self.min_x_vel = env_cfg.commands.ranges.lin_vel_x[0] if env_cfg else 0.0 |
| 40 | + self.max_x_vel = env_cfg.commands.ranges.lin_vel_x[1] if env_cfg else 0.3 |
| 41 | + self.osc_amplitude = (self.max_x_vel - self.min_x_vel) / 2 |
| 42 | + self.osc_offset = (self.max_x_vel + self.min_x_vel) / 2 |
| 43 | + elif self.mode == CommandMode.RANDOM: |
| 44 | + self.cmd_ranges = { |
| 45 | + 'lin_vel_x': env_cfg.commands.ranges.lin_vel_x, |
| 46 | + 'lin_vel_y': env_cfg.commands.ranges.lin_vel_y, |
| 47 | + 'ang_vel_yaw': env_cfg.commands.ranges.ang_vel_yaw, |
| 48 | + 'heading': env_cfg.commands.ranges.heading |
| 49 | + } if env_cfg else { |
| 50 | + 'lin_vel_x': [-0.05, 0.23], |
| 51 | + 'lin_vel_y': [-0.05, 0.05], |
| 52 | + 'ang_vel_yaw': [-0.5, 0.5], |
| 53 | + 'heading': [-np.pi, np.pi] |
| 54 | + } |
| 55 | + self.resampling_time = env_cfg.commands.resampling_time if env_cfg else 8.0 |
| 56 | + self.last_sample_time = 0.0 |
| 57 | + elif self.mode == CommandMode.KEYBOARD: |
| 58 | + try: |
| 59 | + import pygame |
| 60 | + pygame.init() |
| 61 | + pygame.display.set_mode((100, 100)) |
| 62 | + self.x_vel_cmd = 0.0 |
| 63 | + self.y_vel_cmd = 0.0 |
| 64 | + self.yaw_vel_cmd = 0.0 |
| 65 | + except ImportError: |
| 66 | + print("WARNING: pygame not found, falling back to fixed commands") |
| 67 | + self.mode = CommandMode.FIXED |
| 68 | + |
| 69 | + def update(self, dt: float) -> torch.Tensor: |
| 70 | + """Updates and returns commands based on current mode.""" |
| 71 | + self.time += dt |
| 72 | + |
| 73 | + if self.mode == CommandMode.FIXED: |
| 74 | + return self.commands |
| 75 | + elif self.mode == CommandMode.OSCILLATING: |
| 76 | + # Oscillate x velocity between min and max |
| 77 | + x_vel = self.osc_offset + self.osc_amplitude * torch.sin( |
| 78 | + torch.tensor(2 * np.pi * self.time / self.osc_period) |
| 79 | + ) |
| 80 | + self.commands[:, 0] = x_vel.to(self.device) |
| 81 | + elif self.mode == CommandMode.RANDOM: |
| 82 | + if self.time - self.last_sample_time >= self.resampling_time: |
| 83 | + self.last_sample_time = self.time |
| 84 | + # Generate random commands within training ranges |
| 85 | + new_commands = torch.tensor([ |
| 86 | + np.random.uniform(*self.cmd_ranges['lin_vel_x']), |
| 87 | + np.random.uniform(*self.cmd_ranges['lin_vel_y']), |
| 88 | + 0.0, |
| 89 | + np.random.uniform(*self.cmd_ranges['heading']) |
| 90 | + ], device=self.device) if self.env_cfg and self.env_cfg.commands.heading_command else torch.tensor([ |
| 91 | + np.random.uniform(*self.cmd_ranges['lin_vel_x']), |
| 92 | + np.random.uniform(*self.cmd_ranges['lin_vel_y']), |
| 93 | + np.random.uniform(*self.cmd_ranges['ang_vel_yaw']), |
| 94 | + 0.0 |
| 95 | + ], device=self.device) |
| 96 | + self.commands = new_commands.repeat(self.num_envs, 1) |
| 97 | + elif self.mode == CommandMode.KEYBOARD: |
| 98 | + self._handle_keyboard_input() |
| 99 | + self.commands[:, 0] = torch.tensor(self.x_vel_cmd, device=self.device) |
| 100 | + self.commands[:, 1] = torch.tensor(self.y_vel_cmd, device=self.device) |
| 101 | + self.commands[:, 2] = torch.tensor(self.yaw_vel_cmd, device=self.device) |
| 102 | + |
| 103 | + return self.commands |
| 104 | + |
| 105 | + def draw(self, gym, viewer, env_handles, robot_positions, actual_vels) -> None: |
| 106 | + """Draws command and actual velocity arrows for all robots.""" |
| 107 | + if viewer is None: |
| 108 | + return |
| 109 | + |
| 110 | + gym.clear_lines(viewer) |
| 111 | + cmd_vels = self.commands[:, :2].cpu().numpy() |
| 112 | + for env_handle, robot_pos, cmd_vel, actual_vel in zip(env_handles, robot_positions, cmd_vels, actual_vels): |
| 113 | + draw_vector(gym, viewer, env_handle, robot_pos, cmd_vel, (0.0, 1.0, 0.0)) # cmd vector (green) |
| 114 | + draw_vector(gym, viewer, env_handle, robot_pos, actual_vel, (1.0, 0.0, 0.0)) # vel vector (red) |
| 115 | + |
| 116 | + def _handle_keyboard_input(self): |
| 117 | + """Handles keyboard input for command updates.""" |
| 118 | + import pygame |
| 119 | + |
| 120 | + for event in pygame.event.get(): |
| 121 | + if event.type == pygame.QUIT: |
| 122 | + pygame.quit() |
| 123 | + |
| 124 | + keys = pygame.key.get_pressed() |
| 125 | + |
| 126 | + # Update movement commands based on arrow keys |
| 127 | + if keys[pygame.K_UP]: |
| 128 | + self.x_vel_cmd = min(self.x_vel_cmd + 0.0005, 0.5) |
| 129 | + if keys[pygame.K_DOWN]: |
| 130 | + self.x_vel_cmd = max(self.x_vel_cmd - 0.0005, -0.5) |
| 131 | + if keys[pygame.K_LEFT]: |
| 132 | + self.y_vel_cmd = min(self.y_vel_cmd + 0.0005, 0.5) |
| 133 | + if keys[pygame.K_RIGHT]: |
| 134 | + self.y_vel_cmd = max(self.y_vel_cmd - 0.0005, -0.5) |
| 135 | + |
| 136 | + # Yaw control |
| 137 | + if keys[pygame.K_a]: |
| 138 | + self.yaw_vel_cmd = min(self.yaw_vel_cmd + 0.001, 0.5) |
| 139 | + if keys[pygame.K_z]: |
| 140 | + self.yaw_vel_cmd = max(self.yaw_vel_cmd - 0.001, -0.5) |
| 141 | + |
| 142 | + # Reset commands |
| 143 | + if keys[pygame.K_SPACE]: |
| 144 | + self.x_vel_cmd = 0.0 |
| 145 | + self.y_vel_cmd = 0.0 |
| 146 | + self.yaw_vel_cmd = 0.0 |
0 commit comments