Skip to content

Commit

Permalink
update data wrapper for recording videos
Browse files Browse the repository at this point in the history
  • Loading branch information
cremebrule committed Jan 22, 2025
1 parent d49b5f9 commit efe7808
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/demo_collection.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ playback_env = DataPlaybackWrapper.create_from_hdf5(
)

# Playback the entire dataset and record observations
playback_env.playback_dataset(record=True)
playback_env.playback_dataset(record_data=True)

# Save the recorded playback data
playback_env.save_data()
Expand Down
66 changes: 54 additions & 12 deletions omnigibson/envs/data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from pathlib import Path

import h5py
import imageio
import torch as th

import omnigibson as og
import omnigibson.lazy as lazy
from omnigibson.envs.env_wrapper import EnvironmentWrapper
from omnigibson.envs.env_wrapper import EnvironmentWrapper, create_wrapper
from omnigibson.macros import gm
from omnigibson.objects.object_base import BaseObject
from omnigibson.sensors.vision_sensor import VisionSensor
from omnigibson.utils.config_utils import TorchEncoder
from omnigibson.utils.python_utils import create_object_from_init_info, h5py_group_to_torch
from omnigibson.utils.python_utils import create_object_from_init_info, h5py_group_to_torch, assert_valid_key
from omnigibson.utils.ui_utils import create_module_logger

# Create module logger
Expand Down Expand Up @@ -420,13 +421,14 @@ def create_from_hdf5(
cls,
input_path,
output_path,
robot_obs_modalities,
robot_obs_modalities=tuple(),
robot_sensor_config=None,
external_sensors_config=None,
include_sensor_names=None,
exclude_sensor_names=None,
n_render_iterations=5,
only_successes=False,
include_env_wrapper=False,
):
"""
Create a DataPlaybackWrapper environment instance form the recorded demonstration info
Expand Down Expand Up @@ -458,6 +460,7 @@ def create_from_hdf5(
the physical state changes. Increasing this number will improve the rendered quality at the expense of
speed.
only_successes (bool): Whether to only save successful episodes
include_env_wrapper (bool): Whether to include environment wrapper stored in the underlying env config
Returns:
DataPlaybackWrapper: Generated playback environment
Expand All @@ -483,18 +486,27 @@ def create_from_hdf5(

# Set observation modalities and update sensor config
for robot_cfg in config["robots"]:
robot_cfg["obs_modalities"] = robot_obs_modalities
robot_cfg["obs_modalities"] = list(robot_obs_modalities)
robot_cfg["include_sensor_names"] = include_sensor_names
robot_cfg["exclude_sensor_names"] = exclude_sensor_names

if robot_sensor_config is not None:
robot_cfg["sensor_config"] = robot_sensor_config
if external_sensors_config is not None:
config["env"]["external_sensors"] = external_sensors_config
# config["env"]["external_sensors"] = merge_nested_dicts(
# base_dict=config["env"]["external_sensors"],
# extra_dict=external_sensors_config,
# inplace=True,
# )

# Load env
env = og.Environment(configs=config)

# Optionally include the desired environment wrapper specified in the config
if include_env_wrapper:
env = create_wrapper(env=env)

# Wrap and return env
return cls(
env=env,
Expand Down Expand Up @@ -548,14 +560,17 @@ def _parse_step_data(self, action, obs, reward, terminated, truncated, info):
step_data["truncated"] = truncated
return step_data

def playback_episode(self, episode_id, record=True):
def playback_episode(self, episode_id, record_data=True, video_writer=None, video_rgb_key=None):
"""
Playback episode @episode_id, and optionally record observation data if @record is True
Args:
episode_id (int): Episode to playback. This should be a valid demo ID number from the inputted collected
data hdf5 file
record (bool): Whether to record data during playback or not
record_data (bool): Whether to record data during playback or not
video_writer (None or imageio.Writer): If specified, writer object that RGB frames will be written to
video_rgb_key (None or str): If specified, observation key representing the RGB frames to write to video.
If @video_writer is specified, this must also be specified!
"""
data_grp = self.input_hdf5["data"]
assert f"demo_{episode_id}" in data_grp, f"No valid episode with ID {episode_id} found!"
Expand All @@ -579,7 +594,7 @@ def playback_episode(self, episode_id, record=True):
og.sim.load_state(state[0, : int(state_size[0])], serialized=True)

# If record, record initial observations
if record:
if record_data:
init_obs, _, _, _, init_info = self.env.step(action=action[0], n_render_iterations=self.n_render_iterations)
step_data = {"obs": self._process_obs(obs=init_obs, info=init_info)}
self.current_traj_history.append(step_data)
Expand Down Expand Up @@ -611,7 +626,7 @@ def playback_episode(self, episode_id, record=True):
self.current_obs, _, _, _, info = self.env.step(action=a, n_render_iterations=self.n_render_iterations)

# If recording, record data
if record:
if record_data:
step_data = self._parse_step_data(
action=a,
obs=self.current_obs,
Expand All @@ -622,17 +637,44 @@ def playback_episode(self, episode_id, record=True):
)
self.current_traj_history.append(step_data)

# If writing to video, write desired frame
if video_writer is not None:
assert_valid_key(video_rgb_key, self.current_obs.keys(), "video_rgb_key")
video_writer.append_data(self.current_obs[video_rgb_key][:, :, :3].numpy())

self.step_count += 1

if record:
if record_data:
self.flush_current_traj()

def playback_dataset(self, record=True):
def playback_dataset(self, record_data=False, video_writer=None, video_rgb_key=None):
"""
Playback all episodes from the input HDF5 file, and optionally record observation data if @record is True
Args:
record (bool): Whether to record data during playback or not
record_data (bool): Whether to record data during playback or not
video_writer (None or imageio.Writer): If specified, writer object that RGB frames will be written to
video_rgb_key (None or str): If specified, observation key representing the RGB frames to write to video.
If @video_writer is specified, this must also be specified!
"""
for episode_id in range(self.input_hdf5["data"].attrs["n_episodes"]):
self.playback_episode(episode_id=episode_id, record=record)
self.playback_episode(
episode_id=episode_id,
record_data=record_data,
video_writer=video_writer,
video_rgb_key=video_rgb_key,
)

def create_video_writer(self, fpath, fps=30):
"""
Creates a video writer to write video frames to when playing back the dataset
Args:
fpath (str): Absolute path that the generated video writer will write to. Should end in .mp4
fps (int): Desired frames per second when generating video
Returns:
imageio.Writer: Generated video writer
"""
assert fpath.endswith(".mp4"), f"Video writer fpath must end with .mp4! Got: {fpath}"
return imageio.get_writer(fpath, fps=fps)
2 changes: 1 addition & 1 deletion tests/test_data_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,5 @@ def test_data_collect_and_playback():
n_render_iterations=1,
only_successes=False,
)
env.playback_dataset(record=True)
env.playback_dataset(record_data=True)
env.save_data()

0 comments on commit efe7808

Please sign in to comment.