diff --git a/omnigibson/action_primitives/__init__.py b/omnigibson/action_primitives/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/omnigibson/action_primitives/pick_place_semantic_action_primitives.py b/omnigibson/action_primitives/pick_place_semantic_action_primitives.py new file mode 100644 index 000000000..6a4f380a2 --- /dev/null +++ b/omnigibson/action_primitives/pick_place_semantic_action_primitives.py @@ -0,0 +1,256 @@ +import numpy as np +import torch as th +from aenum import IntEnum, auto + +import omnigibson as og +import omnigibson.utils.transform_utils as T +from omnigibson.action_primitives.starter_semantic_action_primitives import StarterSemanticActionPrimitives +from omnigibson.controllers.controller_base import ControlType +from omnigibson.macros import create_module_macros + +m = create_module_macros(module_path=__file__) + +m.DEFAULT_BODY_OFFSET_FROM_FLOOR = 0.05 +m.MAX_CARTESIAN_HAND_STEP = 0.07 +m.MAX_STEPS_FOR_HAND_MOVE_IK = 50 +m.JOINT_POS_DIFF_THRESHOLD = 0.005 +m.MOVE_HAND_POS_THRESHOLD = 0.02 + + +class PickPlaceSemanticActionPrimitives(StarterSemanticActionPrimitives): + def __init__( + self, + env, + vid_logger, + add_context=False, + enable_head_tracking=True, + always_track_eef=False, + task_relevant_objects_only=False, + ): + super().__init__(env) + self.vid_logger = vid_logger + self.reset_state_info() + + def _update_macros(self): + # update superclass macros with m from this file. + self.m.update(m) + + def reset_state_info(self): + """ + Called upon each env.reset(). + TODO: This stuff should probably be moved to another Semantic Env-side class + """ + self.state_info = dict( + gripper_closed=False, + ) + + def _grasp(self, obj_name): + # Set grasp pose + # quat is hardcoded (a somewhat top-down pose) + org_quat = th.tensor([0.79082719, -0.20438075, -0.55453328, -0.15910284]) + obj = self.env.scene.object_registry("name", obj_name) + org_pos = obj.get_position_orientation()[0] + print("pos, ori:", org_pos, org_quat) + + im = og.sim.viewer_camera._get_obs()[0]["rgb"][:, :, :3] + obs, obs_info = self.env.get_obs() + im_robot_view = obs["robot0"]["robot0:eyes:Camera:0"]["rgb"][:, :, :3] + + # If want to keep the original target pose + new_pos, new_quat = org_pos, org_quat + print("obj pos", obj.get_position_orientation()[0]) + + # 1. Move to pregrasp pose + pre_grasp_pose = (th.tensor(new_pos) + th.tensor([0.0, 0.0, 0.2]), th.tensor(new_quat)) + self.state_info["gripper_closed"] = self.execute_controller( + self._move_hand_linearly_cartesian( + pre_grasp_pose, + stop_if_stuck=False, + ignore_failure=True, + gripper_closed=self.state_info["gripper_closed"], + move_hand_pos_thresh=m.MOVE_HAND_POS_THRESHOLD, + ), + self.state_info["gripper_closed"], + self.vid_logger, + ) + + init_obj_pos = obj.get_position_orientation()[0] + + # 2. Move to grasp pose + grasp_pose = (th.tensor(new_pos), th.tensor(new_quat)) + self.state_info["gripper_closed"] = self.execute_controller( + self._move_hand_linearly_cartesian( + grasp_pose, + stop_if_stuck=False, + ignore_failure=True, + gripper_closed=self.state_info["gripper_closed"], + move_hand_pos_thresh=m.MOVE_HAND_POS_THRESHOLD, + ), + self.state_info["gripper_closed"], + self.vid_logger, + ) + + # 3. 
Perform grasp + self.state_info["gripper_closed"] = True + action = self._empty_action() + action[20] = -1 + _ = self.execute_controller( + [action], + self.state_info["gripper_closed"], + self.vid_logger, + ) + + # step the simulator a few steps to let the gripper close completely + for _ in range(40): + og.sim.step() + self.vid_logger.save_im_text() + + action_to_add = np.concatenate((np.array([0.0, 0.0, 0.0]), np.array(action[12:19]))) + + # 4. Move to a random pose in a neighbourhood + x, y = org_pos[:2] + z = org_pos[2] + 0.15 + neighbourhood_pose = (th.tensor([x, y, z]), grasp_pose[1]) + self.state_info["gripper_closed"] = self.execute_controller( + self._move_hand_linearly_cartesian( + neighbourhood_pose, + stop_if_stuck=False, + ignore_failure=True, + gripper_closed=self.state_info["gripper_closed"], + move_hand_pos_thresh=m.MOVE_HAND_POS_THRESHOLD, + ), + self.state_info["gripper_closed"], + self.vid_logger, + ) + + final_obj_pos = obj.get_position_orientation()[0] + # Adding all 0 action for the last step + success = (final_obj_pos[2] - init_obj_pos[2]) > 0.02 + + return success + + def _place_on_top(self, obj_name, dest_obj_name): + """ + obj_name (str): success depends on this obj being on the target location + dest_obj_name (str): obj to place obj_name on + xyz_pos (torch.tensor): point the gripper should move to to open. + """ + dest_obj = self.env.scene.object_registry("name", dest_obj_name) + + # 1. Move to a drop point + obj_place_loc = dest_obj.get_position_orientation()[0] + xyz_pos = th.tensor(obj_place_loc + np.array([0.0, 0.0, 0.2])) + quat = th.tensor([0.79082719, -0.20438075, -0.55453328, -0.15910284]) + open_gripper_pose = (xyz_pos, quat) + self.state_info["gripper_closed"] = self.execute_controller( + self._move_hand_linearly_cartesian( + open_gripper_pose, + stop_if_stuck=False, + ignore_failure=True, + gripper_closed=self.state_info["gripper_closed"], + move_hand_pos_thresh=m.MOVE_HAND_POS_THRESHOLD, + ), + self.state_info["gripper_closed"], + self.vid_logger, + ) + + # 2. 
Open Gripper + self.state_info["gripper_closed"] = False + action = self._empty_action() + # action[20] = 1 + _ = self.execute_controller( + [action], + self.state_info["gripper_closed"], + self.vid_logger, + ) + + # step the simulator a few steps to let the gripper open completely + for _ in range(40): + og.sim.step() + self.vid_logger.save_im_text() + + obj = self.env.scene.object_registry("name", obj_name) + obj_pos = obj.get_position_orientation()[0] + obj_place_loc = dest_obj.get_position_orientation()[0] + obj_z_dist = th.norm(obj_pos[2] - obj_place_loc[2]) + obj_xy_dist = th.norm(obj_pos[:2] - obj_place_loc[:2]) + print(f"obj_xy_dist, obj_z_dist: {obj_xy_dist} {obj_z_dist}") + success = bool((obj_z_dist <= 0.07).item() and (obj_xy_dist <= 0.06).item()) + return success + + def execute_controller(self, ctrl_gen, gripper_closed, arr=None): + actions = [] + counter = 0 + for action in ctrl_gen: + if action == "Done": + obs, obs_info = self.env.get_obs() + + proprio = self.robot._get_proprioception_dict() + # add eef pose and base pose to proprio + proprio["left_eef_pos"], proprio["left_eef_orn"] = self.robot.get_relative_eef_pose(arm="left") + proprio["right_eef_pos"], proprio["right_eef_orn"] = self.robot.get_relative_eef_pose(arm="right") + proprio["base_pos"], proprio["base_orn"] = self.robot.get_position_orientation() + + is_contact = detect_robot_collision_in_sim(self.robot) + + continue + + wait = False + + if gripper_closed: + action[20] = -1 + else: + action[20] = 1 + + o, r, te, tr, info = self.env.step(action) + self.vid_logger.save_im_text() + + if wait: + for _ in range(60): + og.sim.step() + + counter += 1 + + print("total steps: ", counter) + return gripper_closed + + def _empty_action(self): + """ + Get a no-op action that allows us to run simulation without changing robot configuration. 
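+        For position-controlled joint controllers this repeats the current joint positions,
+        and for an IK arm controller in pose_absolute_ori mode it repeats the current eef
+        orientation, so stepping with this action leaves the robot configuration unchanged.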
+ + Returns: + np.array or None: Action array for one step for the robot to do nothing + """ + action = th.zeros(self.robot.action_dim) + for name, controller in self.robot._controllers.items(): + joint_idx = controller.dof_idx.long() + action_idx = self.robot.controller_action_idx[name] + if ( + controller.control_type == ControlType.POSITION + and len(joint_idx) == len(action_idx) + and not controller.use_delta_commands + ): + action[action_idx] = self.robot.get_joint_positions()[joint_idx] + elif self.robot._controller_config[name]["name"] == "InverseKinematicsController": + if self.robot._controller_config["arm_" + self.arm]["mode"] == "pose_absolute_ori": + current_quat = self.robot.get_relative_eef_orientation() + current_ori = T.quat2axisangle(current_quat) + control_idx = self.robot.controller_action_idx["arm_" + self.arm] + action[control_idx[3:]] = current_ori + + return action + + def _move_hand_direct_ik( + self, + target_pose, + pos_thresh=0.01, + ori_thresh=0.1, + **kwargs, + ): + # change pos and ori thresh + return super()._move_hand_direct_ik( + target_pose, + pos_thresh=pos_thresh, + ori_thresh=ori_thresh, + **kwargs, + ) diff --git a/omnigibson/action_primitives/starter_semantic_action_primitives.py b/omnigibson/action_primitives/starter_semantic_action_primitives.py index 91e0affeb..6a5d000a2 100644 --- a/omnigibson/action_primitives/starter_semantic_action_primitives.py +++ b/omnigibson/action_primitives/starter_semantic_action_primitives.py @@ -14,9 +14,11 @@ import cv2 import gymnasium as gym +import numpy as np import torch as th from aenum import IntEnum, auto from matplotlib import pyplot as plt +from scipy.spatial.transform import Rotation, Slerp import omnigibson as og import omnigibson.lazy as lazy @@ -187,6 +189,7 @@ def _assemble_robot_copy(self): if link_name in arm_links else self.robot_copy.links_relative_poses[self.robot_copy_type][link_name] ) + link_pose = [th.tensor(arr) for arr in link_pose] mesh_copy_pose = T.pose_transform( *link_pose, *self.robot_copy.relative_poses[self.robot_copy_type][link_name][mesh_name] ) @@ -329,6 +332,13 @@ def __init__( self.robot_copy = self._load_robot_copy() + self.m = m # create a pointer so subclasses can maybe update macros + self._update_macros() + + def _update_macros(self): + """Subclasses use this to update m if needed""" + pass + @property def arm(self): if not isinstance(self.robot, ManipulationRobot): @@ -723,7 +733,7 @@ def _grasp(self, obj): indented_print("Navigating to grasp pose if needed") yield from self._navigate_if_needed(obj, pose_on_obj=grasp_pose) - indented_print("Moving hand to grasp pose") + indented_print(f"Moving to grasp pose {grasp_pose}") yield from self._move_hand(grasp_pose) # We can pre-grasp in sticky grasping mode. @@ -732,7 +742,7 @@ def _grasp(self, obj): # Since the grasp pose is slightly off the object, we want to move towards the object, around 5cm. # It's okay if we can't go all the way because we run into the object. - indented_print("Performing grasp approach") + indented_print(f"Move to grasp approach pose {approach_pose}") yield from self._move_hand_linearly_cartesian(approach_pose, stop_on_contact=True) # Step once to update @@ -1004,7 +1014,7 @@ def _move_hand_joint(self, joint_pos): # Follow the plan to navigate. 
indented_print("Plan has %d steps", len(plan)) for i, joint_pos in enumerate(plan): - indented_print("Executing grasp plan step %d/%d", i + 1, len(plan)) + indented_print(f"Executing grasp plan step {i + 1}/{len(plan)}") yield from self._move_hand_direct_joint(joint_pos, ignore_failure=True) def _move_hand_ik(self, eef_pose, stop_if_stuck=False): @@ -1040,7 +1050,7 @@ def _move_hand_ik(self, eef_pose, stop_if_stuck=False): indented_print("Plan has %d steps", len(plan)) for i, target_pose in enumerate(plan): target_pos = target_pose[:3] - target_quat = T.axisangle2quat(target_pose[3:]) + target_quat = T.axisangle2quat(th.tensor(target_pose[3:])) indented_print("Executing grasp plan step %d/%d", i + 1, len(plan)) yield from self._move_hand_direct_ik( (target_pos, target_quat), ignore_failure=True, in_world_frame=False, stop_if_stuck=stop_if_stuck @@ -1107,7 +1117,7 @@ def _move_hand_direct_joint(self, joint_pos, stop_on_contact=False, ignore_failu if not ignore_failure: raise ActionPrimitiveError( ActionPrimitiveError.Reason.EXECUTION_ERROR, - "Your hand was obstructed from moving to the desired joint position", + "[_move_hand_direct_joint]: Your hand was obstructed from moving to the desired joint position", ) def _move_hand_direct_ik( @@ -1143,9 +1153,14 @@ def _move_hand_direct_ik( assert ( controller_config["name"] == "InverseKinematicsController" ), "Controller must be InverseKinematicsController" - assert controller_config["mode"] == "pose_absolute_ori", "Controller must be in pose_absolute_ori mode" + assert controller_config["mode"] in [ + "pose_absolute_ori", + "pose_delta_ori", + ], "Controller must be in pose_absolute_ori or pose_delta_ori mode" if in_world_frame: target_pose = self._get_pose_in_robot_frame(target_pose) + else: + target_pose = [th.tensor(arr) for arr in target_pose] target_pos = target_pose[0] target_orn = target_pose[1] target_orn_axisangle = T.quat2axisangle(target_pose[1]) @@ -1162,6 +1177,11 @@ def _move_hand_direct_ik( delta_pos = target_pos - current_pos target_pos_diff = th.norm(delta_pos) + if target_pos_diff < 5 * pos_thresh: + delta_pos *= 5 * (pos_thresh / target_pos_diff) + + delta_ori = Rotation.from_quat(target_orn) * Rotation.from_quat(current_orn).inv() + delta_ori = th.tensor(delta_ori.as_rotvec()).float() target_orn_diff = T.get_orientation_diff_in_radian(current_orn, target_orn) reached_goal = target_pos_diff < pos_thresh and target_orn_diff < ori_thresh if reached_goal: @@ -1180,17 +1200,27 @@ def _move_hand_direct_ik( prev_pos = current_pos prev_orn = current_orn - action[control_idx] = th.cat([delta_pos, target_orn_axisangle]) + if self.robot._controller_config["arm_" + self.arm]["mode"] == "pose_absolute_ori": + action[control_idx] = th.cat([delta_pos, target_orn_axisangle]) + elif self.robot._controller_config["arm_" + self.arm]["mode"] == "pose_delta_ori": + action[control_idx] = th.cat([delta_pos, delta_ori]) + yield self._postprocess_action(action) if not ignore_failure: raise ActionPrimitiveError( ActionPrimitiveError.Reason.EXECUTION_ERROR, - "Your hand was obstructed from moving to the desired joint position", + "[_move_hand_direct_ik] Your hand was obstructed from moving to the desired joint position", ) def _move_hand_linearly_cartesian( - self, target_pose, stop_on_contact=False, ignore_failure=False, stop_if_stuck=False + self, + target_pose, + stop_on_contact=False, + ignore_failure=False, + stop_if_stuck=False, + gripper_closed=False, + move_hand_pos_thresh=0.01, ): """ Yields action for the robot to move its arm to reach the 
specified target pose by moving the eef along a line in cartesian @@ -1208,13 +1238,18 @@ def _move_hand_linearly_cartesian( # into 1cm-long pieces start_pos, start_orn = self.robot.eef_links[self.arm].get_position_orientation() travel_distance = th.norm(target_pose[0] - start_pos) - num_poses = th.max([2, int(travel_distance / m.MAX_CARTESIAN_HAND_STEP) + 1]).item() - pos_waypoints = th.linspace(start_pos, target_pose[0], num_poses) + print(f"start_pos, target_pos: {start_pos}, {target_pose[0]}") + num_poses = th.max(th.tensor([2, int(travel_distance / m.MAX_CARTESIAN_HAND_STEP) + 1])).item() + pos_waypoints = th.tensor(np.linspace(start_pos, target_pose[0], num_poses)) # Also interpolate the rotations t_values = th.linspace(0, 1, num_poses) quat_waypoints = [T.quat_slerp(start_orn, target_pose[1], t) for t in t_values] + # remove the first waypoint as it is the starting pose + pos_waypoints = pos_waypoints[1:] + quat_waypoints = quat_waypoints[1:] + controller_config = self.robot._controller_config["arm_" + self.arm] if controller_config["name"] == "InverseKinematicsController": waypoints = list(zip(pos_waypoints, quat_waypoints)) @@ -1230,7 +1265,7 @@ def _move_hand_linearly_cartesian( else: yield from self._move_hand_direct_ik( waypoints[-1], - pos_thresh=0.01, + pos_thresh=move_hand_pos_thresh, ori_thresh=0.1, stop_on_contact=stop_on_contact, ignore_failure=ignore_failure, @@ -1759,8 +1794,8 @@ def _sample_pose_near_object(self, obj, pose_on_obj=None, **kwargs): distance_lo, distance_hi = 0.0, 5.0 distance = (th.rand(1) * (distance_hi - distance_lo) + distance_lo).item() yaw_lo, yaw_hi = -math.pi, math.pi - yaw = (th.rand(1) * (yaw_hi - yaw_lo) + yaw_lo).item() - avg_arm_workspace_range = th.mean(self.robot.arm_workspace_range[self.arm]) + yaw = th.rand(1) * (yaw_hi - yaw_lo) + yaw_lo + avg_arm_workspace_range = th.mean(th.tensor(self.robot.arm_workspace_range[self.arm])) pose_2d = th.tensor( [ pose_on_obj[0][0] + distance * th.cos(yaw), diff --git a/omnigibson/configs/tiago_primitives.yaml b/omnigibson/configs/tiago_primitives.yaml new file mode 100644 index 000000000..96f8687de --- /dev/null +++ b/omnigibson/configs/tiago_primitives.yaml @@ -0,0 +1,69 @@ +env: + action_frequency: 30 # (int): environment executes action at the action_frequency rate + physics_frequency: 120 # (int): physics frequency (1 / physics_timestep for physx) + device: null # (None or str): specifies the device to be used if running on the gpu with torch backend + automatic_reset: false # (bool): whether to automatic reset after an episode finishes + flatten_action_space: false # (bool): whether to flatten the action space as a sinle 1D-array + flatten_obs_space: false # (bool): whether the observation space should be flattened when generated + use_external_obs: false # (bool): Whether to use external observations or not + initial_pos_z_offset: 0.1 + external_sensors: null # (None or list): If specified, list of sensor configurations for external sensors to add. Should specify sensor "type" and any additional kwargs to instantiate the sensor. 
Each entry should be the kwargs passed to @create_sensor, in addition to position, orientation + +render: + viewer_width: 1280 + viewer_height: 720 + +scene: + type: InteractiveTraversableScene + scene_model: Rs_int + trav_map_resolution: 0.1 + default_erosion_radius: 0.0 + trav_map_with_objects: true + num_waypoints: 1 + waypoint_resolution: 0.2 + load_object_categories: null + not_load_object_categories: null + load_room_types: null + load_room_instances: null + load_task_relevant_only: false + seg_map_resolution: 0.1 + scene_source: OG + include_robots: false + +robots: + - type: Tiago + # obs_modalities: [rgb, depth, seg_semantic, normal, seg_instance, seg_instance_id] + obs_modalities: [rgb] + scale: 1.0 + self_collisions: true + action_normalize: false + action_type: continuous + grasping_mode: physical + rigid_trunk: true + default_trunk_offset: 0.15 + default_arm_pose: vertical + default_arm_side: right + sensor_config: + VisionSensor: + sensor_kwargs: + image_height: 128 #720 + image_width: 128 #720 + controller_config: + base: + name: JointController + arm_left: + name: NullJointController + arm_right: + name: InverseKinematicsController + gripper_left: + name: MultiFingerGripperController + mode: binary + gripper_right: + name: MultiFingerGripperController + camera: + name: JointController + +objects: [] + +task: + type: DummyTask diff --git a/omnigibson/controllers/controller_base.py b/omnigibson/controllers/controller_base.py index 84a72bbb4..e0f5706e9 100644 --- a/omnigibson/controllers/controller_base.py +++ b/omnigibson/controllers/controller_base.py @@ -133,8 +133,8 @@ def __init__( ) command_output_limits = ( ( - self._control_limits[self.control_type][0][self.dof_idx], - self._control_limits[self.control_type][1][self.dof_idx], + th.tensor(self._control_limits[self.control_type][0])[self.dof_idx.long()], + th.tensor(self._control_limits[self.control_type][1])[self.dof_idx.long()], ) if type(command_output_limits) == str and command_output_limits == "default" else command_output_limits @@ -281,11 +281,11 @@ def clip_control(self, control): Array[float]: Clipped control signal """ clipped_control = control.clip( - self._control_limits[self.control_type][0][self.dof_idx], - self._control_limits[self.control_type][1][self.dof_idx], + self._control_limits[self.control_type][0][self.dof_idx.long()], + self._control_limits[self.control_type][1][self.dof_idx.long()], ) idx = ( - self._dof_has_limits[self.dof_idx] + self._dof_has_limits[self.dof_idx.long()] if self.control_type == ControlType.POSITION else [True] * self.control_dim ) diff --git a/omnigibson/controllers/ik_controller.py b/omnigibson/controllers/ik_controller.py index c17670458..139c7e22b 100644 --- a/omnigibson/controllers/ik_controller.py +++ b/omnigibson/controllers/ik_controller.py @@ -312,7 +312,7 @@ def compute_control(self, goal_dict, control_dict): target_quat = goal_dict["target_quat"] # Calculate and return IK-backed out joint angles - current_joint_pos = control_dict["joint_position"][self.dof_idx] + current_joint_pos = control_dict["joint_position"][self.dof_idx.long()] # If the delta is really small, we just keep the current joint position. This avoids joint # drift caused by IK solver inaccuracy even when zero delta actions are provided. 
@@ -326,15 +326,15 @@ def compute_control(self, goal_dict, control_dict): err = th.cat([pos_err, ori_err]) # Use the jacobian to compute a local approximation - j_eef = control_dict[f"{self.task_name}_jacobian_relative"][:, self.dof_idx] + j_eef = control_dict[f"{self.task_name}_jacobian_relative"][:, self.dof_idx.long()] j_eef_pinv = th.linalg.pinv(j_eef) delta_j = j_eef_pinv @ err target_joint_pos = current_joint_pos + delta_j # Clip values to be within the joint limits target_joint_pos = target_joint_pos.clamp( - min=self._control_limits[ControlType.get_type("position")][0][self.dof_idx], - max=self._control_limits[ControlType.get_type("position")][1][self.dof_idx], + min=self._control_limits[ControlType.get_type("position")][0][self.dof_idx.long()], + max=self._control_limits[ControlType.get_type("position")][1][self.dof_idx.long()], ) # Optionally pass through smoothing filter for better stability diff --git a/omnigibson/controllers/joint_controller.py b/omnigibson/controllers/joint_controller.py index eca76523b..df85c17b3 100644 --- a/omnigibson/controllers/joint_controller.py +++ b/omnigibson/controllers/joint_controller.py @@ -135,7 +135,7 @@ def __init__( def _update_goal(self, command, control_dict): # Compute the base value for the command - base_value = control_dict[f"joint_{self._motor_type}"][self.dof_idx] + base_value = control_dict[f"joint_{self._motor_type}"][self.dof_idx.long()] # If we're using delta commands, add this value if self._use_delta_commands: @@ -166,8 +166,8 @@ def _update_goal(self, command, control_dict): # Clip the command based on the limits target = target.clip( - self._control_limits[ControlType.get_type(self._motor_type)][0][self.dof_idx], - self._control_limits[ControlType.get_type(self._motor_type)][1][self.dof_idx], + self._control_limits[ControlType.get_type(self._motor_type)][0][self.dof_idx.long()], + self._control_limits[ControlType.get_type(self._motor_type)][1][self.dof_idx.long()], ) return dict(target=target) @@ -189,7 +189,7 @@ def compute_control(self, goal_dict, control_dict): Returns: Array[float]: outputted (non-clipped!) 
control signal to deploy """ - base_value = control_dict[f"joint_{self._motor_type}"][self.dof_idx] + base_value = control_dict[f"joint_{self._motor_type}"][self.dof_idx.long()] target = goal_dict["target"] # Convert control into efforts @@ -197,7 +197,7 @@ def compute_control(self, goal_dict, control_dict): if self._motor_type == "position": # Run impedance controller -- effort = pos_err * kp + vel_err * kd position_error = target - base_value - vel_pos_error = -control_dict[f"joint_velocity"][self.dof_idx] + vel_pos_error = -control_dict[f"joint_velocity"][self.dof_idx.long()] u = position_error * self.kp + vel_pos_error * self.kd elif self._motor_type == "velocity": # Compute command torques via PI velocity controller plus gravity compensation torques @@ -207,16 +207,17 @@ def compute_control(self, goal_dict, control_dict): u = target dof_idxs_mat = th.meshgrid(self.dof_idx, self.dof_idx, indexing="xy") + dof_idxs_mat = tuple(x.long() for x in dof_idxs_mat) mm = control_dict["mass_matrix"][dof_idxs_mat] u = mm @ u # Add gravity compensation if self._use_gravity_compensation: - u += control_dict["gravity_force"][self.dof_idx] + u += control_dict["gravity_force"][self.dof_idx.long()] # Add Coriolis / centrifugal compensation if self._use_cc_compensation: - u += control_dict["cc_force"][self.dof_idx] + u += control_dict["cc_force"][self.dof_idx.long()] else: # Desired is the exact goal @@ -229,7 +230,7 @@ def compute_no_op_goal(self, control_dict): # Compute based on mode if self._motor_type == "position": # Maintain current qpos - target = control_dict[f"joint_{self._motor_type}"][self.dof_idx] + target = control_dict[f"joint_{self._motor_type}"][self.dof_idx.long()] else: # For velocity / effort, directly set to 0 target = th.zeros(self.control_dim) diff --git a/omnigibson/controllers/multi_finger_gripper_controller.py b/omnigibson/controllers/multi_finger_gripper_controller.py index a900e21ed..bb67b01eb 100644 --- a/omnigibson/controllers/multi_finger_gripper_controller.py +++ b/omnigibson/controllers/multi_finger_gripper_controller.py @@ -159,20 +159,20 @@ def compute_control(self, goal_dict, control_dict): Array[float]: outputted (non-clipped!) 
control signal to deploy """ target = goal_dict["target"] - joint_pos = control_dict["joint_position"][self.dof_idx] + joint_pos = control_dict["joint_position"][self.dof_idx.long()] # Choose what to do based on control mode if self._mode == "binary": # Use max control signal should_open = target[0] >= 0.0 if not self._inverted else target[0] > 0.0 if should_open: u = ( - self._control_limits[ControlType.get_type(self._motor_type)][1][self.dof_idx] + self._control_limits[ControlType.get_type(self._motor_type)][1][self.dof_idx.long()] if self._open_qpos is None else self._open_qpos ) else: u = ( - self._control_limits[ControlType.get_type(self._motor_type)][0][self.dof_idx] + self._control_limits[ControlType.get_type(self._motor_type)][0][self.dof_idx.long()] if self._closed_qpos is None else self._closed_qpos ) @@ -183,10 +183,10 @@ def compute_control(self, goal_dict, control_dict): # If we're near the joint limits and we're using velocity / torque control, we zero out the action if self._motor_type in {"velocity", "torque"}: violate_upper_limit = ( - joint_pos > self._control_limits[ControlType.POSITION][1][self.dof_idx] - self._limit_tolerance + joint_pos > self._control_limits[ControlType.POSITION][1][self.dof_idx.long()] - self._limit_tolerance ) violate_lower_limit = ( - joint_pos < self._control_limits[ControlType.POSITION][0][self.dof_idx] + self._limit_tolerance + joint_pos < self._control_limits[ControlType.POSITION][0][self.dof_idx.long()] + self._limit_tolerance ) violation = th.logical_or(violate_upper_limit * (u > 0), violate_lower_limit * (u < 0)) u *= ~violation @@ -228,7 +228,7 @@ def _update_grasping_state(self, control_dict): is_grasping = IsGraspingState.UNKNOWN else: - finger_pos = control_dict["joint_position"][self.dof_idx] + finger_pos = control_dict["joint_position"][self.dof_idx.long()] # For joint position control, if the desired positions are the same as the current positions, is_grasping unknown if self._motor_type == "position" and th.mean(th.abs(finger_pos - self._control)) < m.POS_TOLERANCE: @@ -240,9 +240,9 @@ def _update_grasping_state(self, control_dict): # Otherwise, the last control signal intends to "move" the gripper else: - finger_vel = control_dict["joint_velocity"][self.dof_idx] - min_pos = self._control_limits[ControlType.POSITION][0][self.dof_idx] - max_pos = self._control_limits[ControlType.POSITION][1][self.dof_idx] + finger_vel = control_dict["joint_velocity"][self.dof_idx.long()] + min_pos = self._control_limits[ControlType.POSITION][0][self.dof_idx.long()] + max_pos = self._control_limits[ControlType.POSITION][1][self.dof_idx.long()] # Make sure we don't have any invalid values (i.e.: fingers should be within the limits) finger_pos = th.clip(finger_pos, min_pos, max_pos) diff --git a/omnigibson/examples/action_primitives/pick_place_example.py b/omnigibson/examples/action_primitives/pick_place_example.py new file mode 100644 index 000000000..9d3ddd702 --- /dev/null +++ b/omnigibson/examples/action_primitives/pick_place_example.py @@ -0,0 +1,167 @@ +import os +import time +from argparse import ArgumentParser +from collections import Counter + +import numpy as np +import torch as th +import yaml +from scipy.spatial.transform import Rotation as R + +import omnigibson as og +from omnigibson.action_primitives.pick_place_semantic_action_primitives import PickPlaceSemanticActionPrimitives +from omnigibson.utils.motion_planning_utils import detect_robot_collision_in_sim +from omnigibson.utils.video_logging_utils import VideoLogger + + +def 
custom_reset(env, robot, args, vid_logger): + proprio = robot._get_proprioception_dict() + # curr_right_arm_joints = th.tensor(proprio['arm_right_qpos']) + reset_right_arm_joints = th.tensor([0.85846, -0.44852, 1.81008, 1.63368, 0.43764, -1.32488, -0.68415]) + + noise_1 = np.random.uniform(-0.2, 0.2, 3) + noise_2 = np.random.uniform(-0.01, 0.01, 4) + noise = th.tensor(np.concatenate((noise_1, noise_2))) + right_hand_joints_pos = reset_right_arm_joints + noise + # right_hand_joints_pos = curr_right_arm_joints + noise + + scene_initial_state = env.scene._initial_state + # for manipulation + base_pos = np.array([-0.05, -0.4, 0.0]) + base_x_noise = np.random.uniform(-0.15, 0.15) + base_y_noise = np.random.uniform(-0.15, 0.15) + base_noise = np.array([base_x_noise, base_y_noise, 0.0]) + base_pos += base_noise + scene_initial_state["object_registry"]["robot0"]["root_link"]["pos"] = base_pos + + base_yaw = -120 + base_yaw_noise = np.random.uniform(-15, 15) + base_yaw += base_yaw_noise + r_euler = R.from_euler("z", base_yaw, degrees=True) # or -120 + r_quat = R.as_quat(r_euler) + scene_initial_state["object_registry"]["robot0"]["root_link"]["ori"] = r_quat + + default_head_joints = th.tensor([-0.5031718015670776, -0.9972541332244873]) + noise_1 = np.random.uniform(-0.1, 0.1, 1) + noise_2 = np.random.uniform(-0.1, 0.1, 1) + noise = th.tensor(np.concatenate((noise_1, noise_2))) + head_joints = default_head_joints + noise + + # Reset environment and robot + obs, info = env.reset() + robot.reset(right_hand_joints_pos=right_hand_joints_pos, head_joints_pos=head_joints) + + # Step simulator a few times so that the effects of "reset" take place + for _ in range(10): + og.sim.step() + + obs, obs_info = env.get_obs() + + proprio = robot._get_proprioception_dict() + # add eef pose and base pose to proprio + proprio["left_eef_pos"], proprio["left_eef_orn"] = robot.get_relative_eef_pose(arm="left") + proprio["right_eef_pos"], proprio["right_eef_orn"] = robot.get_relative_eef_pose(arm="right") + proprio["base_pos"], proprio["base_orn"] = robot.get_position_orientation() + + is_contact = detect_robot_collision_in_sim(robot) + + +def main(args): + np.random.seed(6) + # Load the config + config_filename = os.path.join(og.example_config_path, "tiago_primitives.yaml") + config = yaml.load(open(config_filename, "r"), Loader=yaml.FullLoader) + + # Update it to create a custom environment and run some actions + config["scene"]["scene_model"] = "Rs_int" + config["scene"]["load_object_categories"] = ["floors", "ceilings", "walls", "coffee_table"] + + config["objects"] = [ + { + "type": "PrimitiveObject", + "name": "box", + "primitive_type": "Cube", + "manipulable": True, + # ^ Should the robot be allowed to interact w/ this object? 
+ # Used if need to randomly select object in DialogEnvironment + "rgba": [1.0, 0, 0, 1.0], + "scale": [0.1, 0.05, 0.1], + # "size": 0.05, + "position": [-0.5, -0.7, 0.5], + "orientation": [0.0004835024010390043, -0.00029672126402147114, -0.11094563454389572, 0.9938263297080994], + }, + { + "type": "DatasetObject", + "name": "table", + "category": "breakfast_table", + "model": "rjgmmy", + "manipulable": False, + "scale": [0.3, 0.3, 0.3], + "position": [-0.7, 0.5, 0.2], + "orientation": [0, 0, 0, 1], + }, + { + "type": "PrimitiveObject", + "name": "pad", + "primitive_type": "Disk", + "rgba": [0.0, 0, 1.0, 1.0], + "radius": 0.08, + "position": [-0.3, -0.8, 0.5], + }, + ] + + # Load the environment + env = og.Environment(configs=config) + scene = env.scene + robot = env.robots[0] + + vid_logger = VideoLogger(args, env) + action_primitives = PickPlaceSemanticActionPrimitives(env, vid_logger, enable_head_tracking=False) + + subtask_success_counter = Counter() + for i in range(args.num_trajs): + print(f"---------------- Episode {i} ------------------") + start_time = time.time() + + custom_reset(env, robot, args, vid_logger) + + obs, obs_info = env.get_obs() + + for _ in range(50): + og.sim.step() + vid_logger.save_im_text() + + pick_success = action_primitives._grasp("box") + place_success = action_primitives._place_on_top("box", "pad") + + task_success = pick_success and place_success + num_subtasks_completed = [int(pick_success), int(place_success), 0].index(0) + vid_logger.make_video(prefix=f"rew{num_subtasks_completed}") + subtask_success_counter["entire"] += int(task_success) + subtask_success_counter["pick"] += int(pick_success) + subtask_success_counter["place"] += int(place_success) + + print(f"num successes: {subtask_success_counter['entire']} / {i + 1}\n{subtask_success_counter}") + + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Episode {i}: execution time: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + """ + Example usage: + python omnigibson/examples/action_primitives/pick_place_example.py --out-dir /home/albert/scratch/20240924 --num-trajs 1 + """ + start_time = time.time() + parser = ArgumentParser() + parser.add_argument("--out-dir", type=str, required=True) + parser.add_argument("--num-trajs", type=int, required=True) + parser.add_argument("--vid-downscale-factor", type=float, default=2.0) + parser.add_argument("--vid-speedup", type=float, default=2) + args = parser.parse_args() + main(args) + end_time = time.time() + + elapsed_time = end_time - start_time + print(f"Total execution time: {elapsed_time:.2f} seconds") diff --git a/omnigibson/objects/controllable_object.py b/omnigibson/objects/controllable_object.py index 691a7e0ae..c323d07b6 100644 --- a/omnigibson/objects/controllable_object.py +++ b/omnigibson/objects/controllable_object.py @@ -436,7 +436,7 @@ def step(self): # By default, the control type is None and the control value is 0 (th.zeros) - i.e. 
no control applied u_type_vec = th.tensor([ControlType.NONE] * self.n_dof) for group, ctrl in control.items(): - idx = self._controllers[group].dof_idx + idx = self._controllers[group].dof_idx.long() u_vec[idx] = ctrl["value"] u_type_vec[idx] = ctrl["type"] diff --git a/omnigibson/prims/entity_prim.py b/omnigibson/prims/entity_prim.py index fd0a512dc..f51983b63 100644 --- a/omnigibson/prims/entity_prim.py +++ b/omnigibson/prims/entity_prim.py @@ -1018,7 +1018,7 @@ def set_position_orientation(self, position=None, orientation=None, frame: Liter frame (Literal): The frame in which to set the position and orientation. Defaults to world. scene frame sets position relative to the scene. """ - assert frame in ["world", "scene"], f"Invalid frame '{frame}'. Must be 'world', or 'scene'." + assert frame in ["world", "scene", "parent"], f"Invalid frame '{frame}'. Must be 'world', 'scene', or 'parent'" # If kinematic only, clear cache for the root link if self.kinematic_only: diff --git a/omnigibson/prims/xform_prim.py b/omnigibson/prims/xform_prim.py index 27f71f8f0..6fea9561c 100644 --- a/omnigibson/prims/xform_prim.py +++ b/omnigibson/prims/xform_prim.py @@ -344,7 +344,7 @@ def get_local_pose(self): logger.warning( 'get_local_pose is deprecated and will be removed in a future release. Use get_position_orientation(frame="parent") instead' ) - return self.get_position_orientation(self.prim_path, frame="parent") + return self.get_position_orientation(frame="parent") def set_local_pose(self, position=None, orientation=None): """ @@ -359,7 +359,7 @@ def set_local_pose(self, position=None, orientation=None): logger.warning( 'set_local_pose is deprecated and will be removed in a future release. Use set_position_orientation(position=position, orientation=orientation, frame="parent") instead' ) - return self.set_position_orientation(self.prim_path, position, orientation, frame="parent") + return self.set_position_orientation(position, orientation, frame="parent") def get_world_scale(self): """ diff --git a/omnigibson/robots/manipulation_robot.py b/omnigibson/robots/manipulation_robot.py index b4d76f5c7..9960c5d1e 100644 --- a/omnigibson/robots/manipulation_robot.py +++ b/omnigibson/robots/manipulation_robot.py @@ -1306,7 +1306,7 @@ def _handle_assisted_grasping(self): # stays the same across different controllers and control modes (absolute / delta). This way, # a zero action will actually keep the AG setting where it already is. controller = self._controllers[f"gripper_{arm}"] - controlled_joints = controller.dof_idx + controlled_joints = controller.dof_idx.long() threshold = th.mean( th.stack([self.joint_lower_limits[controlled_joints], self.joint_upper_limits[controlled_joints]]), dim=0, diff --git a/omnigibson/robots/tiago.py b/omnigibson/robots/tiago.py index 54c80528d..9a9e9fea1 100644 --- a/omnigibson/robots/tiago.py +++ b/omnigibson/robots/tiago.py @@ -57,6 +57,7 @@ def __init__( default_arm_pose="diagonal15", # Unique to Tiago variant="default", + default_arm_side="left", **kwargs, ): """ @@ -109,6 +110,7 @@ def __init__( {"vertical", "diagonal15", "diagonal30", "diagonal45", "horizontal"} If either reset_joint_pos is not None or default_reset_mode is "tuck", this will be ignored. Otherwise the reset_joint_pos will be initialized to the precomputed joint positions that represents default_arm_pose. + default_arm_side (str): One of {"left", "right"}. Used to set self.default_arm attribute. variant (str): Which variant of the robot should be loaded. 
One of "default", "wrist_cam" kwargs (dict): Additional keyword arguments that are used for other super() calls from subclasses, allowing for flexible compositions of various object subclasses (e.g.: Robot is USDObject + ControllableObject). @@ -116,6 +118,7 @@ def __init__( # Store args assert variant in ("default", "wrist_cam"), f"Invalid Tiago variant specified {variant}!" self._variant = variant + self.default_arm_side = default_arm_side # Run super init super().__init__( @@ -160,6 +163,15 @@ def n_arms(cls): def arm_names(cls): return ["left", "right"] + @property + def default_arm(self): + if self.default_arm_side == "left": + return self.arm_names[0] + elif self.default_arm_side == "right": + return self.arm_names[1] + else: + raise NotImplementedError + @property def tucked_default_joint_pos(self): pos = th.zeros(self.n_dof) @@ -200,6 +212,23 @@ def discrete_action_list(self): def _create_discrete_action_space(self): raise ValueError("Tiago does not support discrete actions!") + def reset(self, right_hand_joints_pos=None, head_joints_pos=None): + """ + Reset should not change the robot base pose. + We need to cache and restore the base joints to the world. + """ + base_joint_positions = self.get_joint_positions()[self.base_idx] + super().reset() + self.set_joint_positions(base_joint_positions, indices=self.base_idx) + + # set the head pose + if head_joints_pos is not None: + self.set_joint_positions(head_joints_pos, indices=self.camera_control_idx) + + # reset the hand joints to a specific position + if right_hand_joints_pos is not None: + self.set_joint_positions(right_hand_joints_pos, indices=self.arm_control_idx["right"]) + def _post_load(self): super()._post_load() # The eef gripper links should be visual-only. They only contain a "ghost" box volume for detecting objects @@ -214,6 +243,17 @@ def _post_load(self): def base_footprint_link_name(self): return "base_footprint" + def _get_proprioception_dict(self): + dic = super()._get_proprioception_dict() + + # Add trunk info + joint_positions = ControllableObjectViewAPI.get_joint_positions(self.articulation_root_path) + joint_velocities = ControllableObjectViewAPI.get_joint_velocities(self.articulation_root_path) + dic["trunk_qpos"] = joint_positions[self.trunk_control_idx] + dic["trunk_qvel"] = joint_velocities[self.trunk_control_idx] + + return dic + @property def controller_order(self): controllers = ["base", "camera"] @@ -235,6 +275,52 @@ def _default_controllers(self): controllers["gripper_{}".format(arm)] = "MultiFingerGripperController" return controllers + @property + def _default_base_controller_configs(self): + dic = { + "name": "JointController", + "control_freq": self._control_freq, + "control_limits": self.control_limits, + "use_delta_commands": False, + "use_impedances": False, + "motor_type": "velocity", + "dof_idx": self.base_control_idx, + } + return dic + + @property + def _default_controller_config(self): + # Grab defaults from super method first + cfg = super()._default_controller_config + + # Get default base controller for omnidirectional Tiago + cfg["base"] = {"JointController": self._default_base_controller_configs} + + for arm in self.arm_names: + for arm_cfg in cfg["arm_{}".format(arm)].values(): + + if arm == "left": + # Need to override joint idx being controlled to include trunk in default arm controller configs + arm_control_idx = th.cat([self.trunk_control_idx, self.arm_control_idx[arm]]) + arm_cfg["dof_idx"] = arm_control_idx + + # Need to modify the default joint positions also if this is a null 
joint controller + if arm_cfg["name"] == "NullJointController": + arm_cfg["default_command"] = self.reset_joint_pos[arm_control_idx] + + # If using rigid trunk, we also clamp its limits + # TODO: How to handle for right arm which has a fixed trunk internally even though the trunk is moving + # via the left arm?? + if self.rigid_trunk: + arm_cfg["control_limits"]["position"][0][self.trunk_control_idx] = self.untucked_default_joint_pos[ + self.trunk_control_idx + ] + arm_cfg["control_limits"]["position"][1][self.trunk_control_idx] = self.untucked_default_joint_pos[ + self.trunk_control_idx + ] + + return cfg + @property def assisted_grasp_start_points(self): return { diff --git a/omnigibson/utils/deprecated_utils.py b/omnigibson/utils/deprecated_utils.py index 3df775f22..9f54f6ddb 100644 --- a/omnigibson/utils/deprecated_utils.py +++ b/omnigibson/utils/deprecated_utils.py @@ -529,9 +529,10 @@ def set_joint_positions( carb.log_warn("ArticulationView needs to be initialized.") return if not omni.timeline.get_timeline_interface().is_stopped() and self._physics_view is not None: + positions = positions.float() indices = self._backend_utils.resolve_indices(indices, self.count, self._device) joint_indices = self._backend_utils.resolve_indices(joint_indices, self.num_dof, self._device) - new_dof_pos = self._physics_view.get_dof_positions() + new_dof_pos = self._physics_view.get_dof_positions().float() new_dof_pos = self._backend_utils.assign( self._backend_utils.move_data(positions, device=self._device), new_dof_pos, diff --git a/omnigibson/utils/motion_planning_utils.py b/omnigibson/utils/motion_planning_utils.py index 3cfe392c9..232e836d4 100644 --- a/omnigibson/utils/motion_planning_utils.py +++ b/omnigibson/utils/motion_planning_utils.py @@ -106,6 +106,7 @@ def get_angle_between_poses(p1, p2): segment = [] segment.append(p2[0] - p1[0]) segment.append(p2[1] - p1[1]) + segment = th.tensor(segment) return th.arctan2(segment[1], segment[0]) def create_state(space, x, y, yaw): @@ -122,7 +123,7 @@ def state_valid_fn(q): x = q.getX() y = q.getY() yaw = q.getYaw() - pose = ([x, y, 0.0], T.euler2quat((0, 0, yaw))) + pose = ([x, y, 0.0], T.euler2quat(th.tensor([0, 0, yaw]))) return not set_base_and_detect_collision(context, pose) def remove_unnecessary_rotations(path): @@ -364,7 +365,7 @@ def state_valid_fn(q): eef_pose = [q[i] for i in range(6)] control_joint_pos = ik_solver.solve( target_pos=eef_pose[:3], - target_quat=T.axisangle2quat(eef_pose[3:]), + target_quat=T.axisangle2quat(th.tensor(eef_pose[3:])), max_iterations=1000, ) @@ -479,6 +480,7 @@ def set_arm_and_detect_collision(context, joint_pos): if link in robot_copy.meshes[robot_copy_type].keys(): for mesh_name, mesh in robot_copy.meshes[robot_copy_type][link].items(): relative_pose = robot_copy.relative_poses[robot_copy_type][link][mesh_name] + pose = [th.tensor(arr) for arr in pose] mesh_pose = T.pose_transform(*pose, *relative_pose) translation = lazy.pxr.Gf.Vec3d(*th.tensor(mesh_pose[0], dtype=th.float32).tolist()) mesh.GetAttribute("xformOp:translate").Set(translation) diff --git a/omnigibson/utils/video_logging_utils.py b/omnigibson/utils/video_logging_utils.py new file mode 100644 index 000000000..ea6ee817d --- /dev/null +++ b/omnigibson/utils/video_logging_utils.py @@ -0,0 +1,192 @@ +import datetime +import os + +import cv2 +import numpy as np +import torch as th +from matplotlib import font_manager +from PIL import Image, ImageDraw, ImageFont + +import omnigibson as og + + +class VideoLogger: + def __init__(self, args, env=None): + 
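+        """
+        args: namespace expected to provide out_dir, vid_downscale_factor, and
+            vid_speedup (e.g. the parsed arguments in pick_place_example.py);
+            vid_speedup must be an integer speedup factor (every vid_speedup-th
+            frame is kept when writing the video).
+        env: only needed when save_im_text() is used (see the comment below).
+        """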
self.vid_downscale_factor = args.vid_downscale_factor + assert isinstance(args.vid_speedup, int) + self.vid_speedup = args.vid_speedup + self.env = env + # ^ only needed if using save_im_text(...) instead of save_obs(...) + now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + self.out_dir = os.path.join(args.out_dir, now) + os.makedirs(self.out_dir, exist_ok=False) + self.clear_ims() + + # Text settings + self.text_font_size = 36 + self.line_spacing = 1.2 + self.num_frames_to_show_text = 45 * args.vid_speedup # num frames to keep text on for the video + self.text_to_num_frames_remaining_map = {} + font = font_manager.FontProperties(family="Ubuntu", style="italic") + italics = ImageFont.truetype(font_manager.findfont(font), self.text_font_size) + font = font_manager.FontProperties(family="Ubuntu") + non_italics = ImageFont.truetype(font_manager.findfont(font), self.text_font_size) + self.fonts = dict( + italics=italics, + non_italics=non_italics, + ) + self.text_align = "top-left" + + def save_im_text(self, text=""): + im = og.sim.viewer_camera._get_obs()[0]["rgb"][:, :, :3] + obs, obs_info = self.env.get_obs() + im_robot_view = obs["robot0"]["robot0:eyes:Camera:0"]["rgb"][:, :, :3] + self.save_obs(im, im_robot_view, text) + + def save_obs(self, im_arr, im_arr_robot_view, text=""): + def resize_im_arr(im_arr, downscale_factor=1): + # assumes im_arr is (H, W, 3) + im_arr = cv2.resize(im_arr, tuple([int(x) for x in (np.array(im_arr.shape[:2][::-1]) // downscale_factor)])) + return im_arr + + if th.is_tensor(im_arr): + im_arr = im_arr.cpu().numpy() + if th.is_tensor(im_arr_robot_view): + im_arr_robot_view = im_arr_robot_view.cpu().numpy() + + im_arr = resize_im_arr(im_arr, self.vid_downscale_factor) + im_arr_robot_view = resize_im_arr(im_arr_robot_view) + im_arr_w_robot_view = self.overlay_robot_view_on_imgs(im_arr, im_arr_robot_view) + im_arr_w_robot_view = self.maybe_add_text(im_arr_w_robot_view, text) + self.ims.append(im_arr) + self.robot_view_ims.append(im_arr_robot_view) + self.ims_w_robot_view.append(im_arr_w_robot_view) + + def save_obs_batch(self, im_arr_list, im_arr_robot_view): + assert len(im_arr_list) == len(im_arr_robot_view) + for im_arr, im_arr_robot_view in zip(im_arr_list, im_arr_robot_view): + self.save_obs(im_arr, im_arr_robot_view) + + def get_textbox_size(self, text, font): + # Get image text size on dummy image + im_dummy = Image.new(mode="P", size=(0, 0)) + draw_dummy = ImageDraw.Draw(im_dummy) + _, _, text_w, text_h = draw_dummy.textbbox((0, 0), text=text, font=self.fonts["italics"]) + return text_w, text_h + + def maybe_add_text(self, im_arr, new_text=""): + if new_text: + self.text_to_num_frames_remaining_map[new_text] = self.num_frames_to_show_text + + im = Image.fromarray(im_arr) + im_w, im_h = im.size + draw = ImageDraw.Draw(im) + + # Draw speedup + speedup_text = f"{self.vid_speedup}x" + text_w, text_h = self.get_textbox_size(speedup_text, self.fonts["non_italics"]) + draw.text( + (0.98 * (im_w - text_w), 0.98 * (im_h - text_h)), speedup_text, font=self.fonts["non_italics"], fill="white" + ) + + # Refresh counters at the end + keys_to_remove = [] + num_active_texts = len( + [ + (text, num_frames_left) + for text, num_frames_left in self.text_to_num_frames_remaining_map.items() + if num_frames_left > 0 + ] + ) + + texts_to_draw = [] + active_text_idx = 0 + for text, num_frames_left in self.text_to_num_frames_remaining_map.items(): + + # No longer an active word; was placed on enough frames already. 
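+            # schedule for removal after the loop; entries cannot be popped
+            # from the dict while iterating over it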
+ if num_frames_left <= 0: + keys_to_remove.append(text) + continue + + # Center the text + # Get image text size on dummy image + text_w, text_h = self.get_textbox_size(text, self.fonts["italics"]) + if self.text_align == "center": + x = 0.5 * (im_w - text_w) + y = 0.5 * (im_h - (num_active_texts - active_text_idx) * self.line_spacing * text_h) + elif self.text_align == "top-left": + x = 0.05 * im_w + y = 0.05 * im_h + active_text_idx * self.line_spacing * text_h + else: + raise NotImplementedError + + color = (0xF4, 0xE5, 0xBB) if "robot" in text.lower() else (0xBB, 0xCA, 0xF4) + # draw.text((x, y), text, font=self.fonts['italics'], fill=color) + texts_to_draw.append((text, (x, y), (text_w, text_h), self.fonts["italics"], color)) + + active_text_idx += 1 + self.text_to_num_frames_remaining_map[text] -= 1 + + if len(texts_to_draw) > 0: + # Draw translucent box behind text + draw = ImageDraw.Draw(im, "RGBA") + pad = 0.2 * self.text_font_size + + top_text_xy_pos = texts_to_draw[0][1] + bottom_text_xy_pos = texts_to_draw[-1][1] + largest_width_text_pos_box_size = max( + [(pos, box_size) for _, pos, box_size, _, _ in texts_to_draw], key=lambda pos_size: pos_size[1][0] + ) + largest_width_text_xy_pos, largest_width_text_xy_size = largest_width_text_pos_box_size + + bottom_text_xy_size = texts_to_draw[-1][2] + min_x, min_y = np.array([largest_width_text_xy_pos[0], top_text_xy_pos[1]]) - pad + max_x = largest_width_text_xy_pos[0] + largest_width_text_xy_size[0] + pad + max_y = bottom_text_xy_pos[1] + bottom_text_xy_size[1] + pad + draw.rectangle(((min_x, min_y), (max_x, max_y)), fill=(0, 0, 0, 64)) + + # Actually draw the text over the box + for text, pos, _, font, color in texts_to_draw: + draw.text(pos, text, font=font, fill=color) + + for key in keys_to_remove: + self.text_to_num_frames_remaining_map.pop(key) + + return np.asarray(im) + + def overlay_robot_view_on_imgs(self, im, robot_view_im): + assert im.shape[:2] > robot_view_im.shape[:2] + im_w_robot_view = np.copy(im) + h, w = robot_view_im.shape[:2] + # Make patch in upper right + im_w_robot_view[:h, -w:, :] = robot_view_im + return im_w_robot_view + + def clear_ims(self): + self.ims = [] + self.robot_view_ims = [] + self.ims_w_robot_view = [] + + def make_video(self, prefix): + imgs = np.array(self.ims_w_robot_view) + self.save_video(imgs, prefix) + self.clear_ims() + + def save_video(self, imgs, prefix): + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + frame_height, frame_width = imgs[0].shape[:2] + now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + out_dir = os.path.join(self.out_dir, prefix) + if not os.path.exists(out_dir): + os.makedirs(out_dir) + out_path = os.path.join(out_dir, f"{prefix}_{now}.mp4") + out = cv2.VideoWriter(out_path, fourcc, 30.0, (frame_width, frame_height)) + for i, frame in enumerate(imgs): + if i % self.vid_speedup != 0: + # Drop the frames to cause speedup. 
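+                # (the writer below always encodes at 30 fps, so keeping only every
+                # vid_speedup-th frame plays back vid_speedup times faster)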
+ continue + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + out.write(frame) + out.release() + cv2.destroyAllWindows() + print(f"saved video to {out_path}") diff --git a/omnigibson/utils/vision_utils.py b/omnigibson/utils/vision_utils.py index 7cb6f54a4..c15497161 100644 --- a/omnigibson/utils/vision_utils.py +++ b/omnigibson/utils/vision_utils.py @@ -129,7 +129,7 @@ def remap(self, old_mapping, new_mapping, image, image_keys=None): self.key_array[key] = new_key # Apply remapping - remapped_img = self.key_array[image] + remapped_img = self.key_array[image.long()] # Make sure all values are correctly remapped and not equal to the default value assert th.all(remapped_img != th.iinfo(th.int32).max), "Not all keys in the image are in the key array!" remapped_labels = {}
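Note on the recurring dof_idx.long() casts across the controller and robot files: dof_idx appears to reach these call sites as a float tensor in some code paths, and PyTorch only accepts integer or boolean tensors as indices. A minimal illustrative sketch of the failure mode the casts avoid (the values below are made up, not taken from the code):

    import torch as th

    limits = th.tensor([0.0, 0.5, 1.0, 1.5])  # e.g. per-dof control limits
    dof_idx = th.tensor([1.0, 3.0])           # float dtype, as it can arrive in practice

    # limits[dof_idx] raises an IndexError (indices must be integer or boolean tensors)
    print(limits[dof_idx.long()])             # tensor([0.5000, 1.5000])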