[Feature] Add support for loading datasets from local Minari cache #3068

Open · wants to merge 7 commits into main
252 changes: 237 additions & 15 deletions test/test_libs.py
Contributor Author:

Note: new dependencies were added for creating custom Minari datasets from Gym environments:

  • ale_py (for Atari)
  • gymnasium_robotics (for D4RL and MuJoCo)
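
A minimal registration sketch (mirroring the register_envs calls used in the test code below; "ALE/Pong-v5" is only an illustrative env id):

import ale_py
import gymnasium
import gymnasium_robotics

gymnasium.register_envs(ale_py)  # exposes the ALE/* Atari env ids
gymnasium.register_envs(gymnasium_robotics)  # exposes the D4RL/MuJoCo robotics env ids

env = gymnasium.make("ALE/Pong-v5")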

@@ -3341,6 +3341,39 @@ def test_d4rl_iteration(self, task, split_trajs):

_MINARI_DATASETS = []

MUJOCO_ENVIRONMENTS = [
Contributor Author:

Explicitly specifies the current set of Minari-supported datasets for integration with Gym environments.

Contributor Author:

The specified datasets are those used for MuJoCo and D4RL.

"Hopper-v5",
"Pusher-v5",
"Humanoid-v5",
"InvertedDoublePendulum-v5",
"HalfCheetah-v5",
"Swimmer-v5",
"Walker2d-v5",
"Ant-v5",
"Reacher-v5",
]

D4RL_ENVIRONMENTS = [
"AntMaze_UMaze-v5",
"AdroitHandPen-v1",
"AntMaze_Medium-v4",
"AntMaze_Large_Diverse_GR-v4",
"AntMaze_Large-v4",
"AntMaze_Medium_Diverse_GR-v4",
"PointMaze_OpenDense-v3",
"PointMaze_UMaze-v3",
"PointMaze_LargeDense-v3",
"PointMaze_Medium-v3",
"PointMaze_UMazeDense-v3",
"PointMaze_MediumDense-v3",
"PointMaze_Large-v3",
"PointMaze_Open-v3",
"FrankaKitchen-v1",
"AdroitHandDoor-v1",
"AdroitHandHammer-v1",
"AdroitHandRelocate-v1",
]


def _minari_init():
"""Initialize Minari datasets list. Returns True if already initialized."""
@@ -3373,30 +3406,155 @@ def _minari_init():
return False


# Initialize with placeholder values for parametrization
# These will be replaced with actual dataset names when the first Minari test runs
_MINARI_DATASETS = [str(i) for i in range(20)]
def get_random_minigrid_datasets():
Contributor Author:

For MiniGrid, only pre-built datasets from the Minari server are used.

This is because the current version of Minari cannot serialize the custom MissionSpace objects used by most MiniGrid environments.

Contributor Author:

This means a custom dataset cannot be created directly from a MiniGrid environment; the MissionSpace would need to be modified (or removed) first.
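
An untested workaround sketch, assuming minigrid's ImgObsWrapper (which drops the mission entry, and with it the MissionSpace, from observations):

import gymnasium
from minari import DataCollector
from minigrid.wrappers import ImgObsWrapper

env = gymnasium.make("MiniGrid-Empty-5x5-v0")
env = ImgObsWrapper(env)  # observations become the image only, so there is no MissionSpace to serialize
collector = DataCollector(env)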

"""
Fetch 5 random Minigrid datasets from the Minari server.
"""
import minari

all_minigrid = [
dataset
for dataset in minari.list_remote_datasets(
latest_version=True, compatible_minari_version=True
).keys()
if dataset.startswith("minigrid/")
]

if len(all_minigrid) < 5:
raise RuntimeError("Not enough minigrid datasets found on Minari server.")
indices = torch.randperm(len(all_minigrid))[:5]
return [all_minigrid[idx] for idx in indices]


def get_random_atari_envs():
"""
Fetch 10 random Atari environments using ale_py and torch.
"""
import ale_py
import gymnasium as gym

gym.register_envs(ale_py)

env_specs = gym.envs.registry.values()
all_env_ids = [env_spec.id for env_spec in env_specs]
atari_env_ids = [env_id for env_id in all_env_ids if env_id.startswith("ALE")]
if len(atari_env_ids) < 10:
raise RuntimeError("Not enough Atari environments found.")
indices = torch.randperm(len(atari_env_ids))[:10]
return [atari_env_ids[idx] for idx in indices]


def custom_minari_init(custom_envs, num_episodes=5):
"""
Initialize custom Minari datasets for the given environments.
"""
import gymnasium
import gymnasium_robotics
from minari import DataCollector

gymnasium.register_envs(gymnasium_robotics)

custom_dataset_ids = []
for env_id in custom_envs:
dataset_id = f"{env_id.lower()}/test-custom-local-v1"
env = gymnasium.make(env_id)
collector = DataCollector(env)

for ep in range(num_episodes):
collector.reset(seed=123 + ep)

while True:
action = collector.action_space.sample()
_, _, terminated, truncated, _ = collector.step(action)
if terminated or truncated:
break

collector.create_dataset(
dataset_id=dataset_id,
algorithm_name="RandomPolicy",
code_permalink="https://github.com/Farama-Foundation/Minari",
author="Farama",
author_email="[email protected]",
eval_env=env_id,
)
custom_dataset_ids.append(dataset_id)

return custom_dataset_ids


@pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found")
@pytest.mark.slow
class TestMinari:
@pytest.mark.parametrize("split", [False, True])
@pytest.mark.parametrize("dataset_idx", range(20))
@pytest.mark.parametrize(
"dataset_idx",
# Only use a static upper bound; do not call any function that imports minari globally.
range(50)
)
def test_load(self, dataset_idx, split):
# Initialize Minari datasets if not already done
if not _minari_init():
pytest.skip("Failed to initialize Minari datasets")
"""
Test loading from custom datasets for Mujoco and D4RL,
Minari remote datasets for Minigrid, and random Atari environments.
"""
import minari

# Get the actual dataset name from the initialized list
if dataset_idx >= len(_MINARI_DATASETS):
pytest.skip(f"Dataset index {dataset_idx} out of range")
custom_envs = MUJOCO_ENVIRONMENTS + D4RL_ENVIRONMENTS
num_custom = len(custom_envs)
try:
minigrid_datasets = get_random_minigrid_datasets()
except Exception:
minigrid_datasets = []
num_minigrid = len(minigrid_datasets)
try:
atari_envs = get_random_atari_envs()
except Exception:
atari_envs = []
num_atari = len(atari_envs)
total_datasets = num_custom + num_minigrid + num_atari

if dataset_idx >= total_datasets:
pytest.skip("Index out of range for available datasets")

if dataset_idx < num_custom:
# Custom dataset for Mujoco/D4RL
custom_dataset_ids = custom_minari_init(
[custom_envs[dataset_idx]], num_episodes=5
)
dataset_id = custom_dataset_ids[0]
data = MinariExperienceReplay(
dataset_id=dataset_id,
split_trajs=split,
batch_size=32,
load_from_local_minari=True,
)
cleanup_needed = True

elif dataset_idx < num_custom + num_minigrid:
# Minigrid datasets from Minari server
minigrid_idx = dataset_idx - num_custom
dataset_id = minigrid_datasets[minigrid_idx]
data = MinariExperienceReplay(
dataset_id=dataset_id,
batch_size=32,
split_trajs=split,
download="force",
)
cleanup_needed = False

else:
# Atari environment datasets
Contributor Author:

For Atari, a random subset of environments is selected.

atari_idx = dataset_idx - num_custom - num_minigrid
env_id = atari_envs[atari_idx]
custom_dataset_ids = custom_minari_init([env_id], num_episodes=5)
dataset_id = custom_dataset_ids[0]
data = MinariExperienceReplay(
dataset_id=dataset_id,
split_trajs=split,
batch_size=32,
load_from_local_minari=True,
)
cleanup_needed = True

selected_dataset = _MINARI_DATASETS[dataset_idx]
torchrl_logger.info(f"dataset {selected_dataset}")
data = MinariExperienceReplay(
selected_dataset, batch_size=32, split_trajs=split
)
t0 = time.time()
for i, sample in enumerate(data):
t1 = time.time()
@@ -3407,6 +3565,10 @@ def test_load(self, dataset_idx, split):
if i == 10:
break

# Clean up custom datasets after running local dataset tests
if cleanup_needed:
minari.delete_dataset(dataset_id=dataset_id)

def test_minari_preproc(self, tmpdir):
dataset = MinariExperienceReplay(
"D4RL/pointmaze/large-v2",
@@ -3453,6 +3615,66 @@ def fn(data):
assert sample["data"].shape == torch.Size([32, 8])
assert sample["next", "data"].shape == torch.Size([32, 8])

@pytest.mark.skipif(
not _has_minari or not _has_gymnasium, reason="Minari or Gym not available"
)
def test_local_minari_dataset_loading(self):
import minari
from minari import DataCollector

if not _minari_init():
pytest.skip("Failed to initialize Minari datasets")

dataset_id = "cartpole/test-local-v1"

# Create dataset using Gym + DataCollector
Copy link
Contributor Author

@Ibinarriaga8 Ibinarriaga8 Jul 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Custom Minari dataset creation from a Gymnasium environment.

env = gymnasium.make("CartPole-v1")
env = DataCollector(env, record_infos=True)
for _ in range(50):
env.reset(seed=123)
while True:
action = env.action_space.sample()
obs, rew, terminated, truncated, info = env.step(action)
if terminated or truncated:
break

env.create_dataset(
dataset_id=dataset_id,
algorithm_name="RandomPolicy",
code_permalink="https://github.com/Farama-Foundation/Minari",
author="Farama",
author_email="[email protected]",
eval_env="CartPole-v1",
)

# Load from local cache
data = MinariExperienceReplay(
dataset_id=dataset_id,
split_trajs=False,
batch_size=32,
download=False,
sampler=SamplerWithoutReplacement(drop_last=True),
prefetch=2,
load_from_local_minari=True,
)

t0 = time.time()
for i, sample in enumerate(data):
t1 = time.time()
torchrl_logger.info(
f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms"
)
assert data.metadata["action_space"].is_in(
sample["action"]
), "Invalid action sample"
assert data.metadata["observation_space"].is_in(
sample["observation"]
), "Invalid observation sample"
t0 = time.time()
if i == 10:
break

minari.delete_dataset(dataset_id="cartpole/test-local-v1")

@pytest.mark.slow
class TestRoboset:
52 changes: 46 additions & 6 deletions torchrl/data/datasets/minari_data.py
@@ -88,6 +88,14 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay):
it is assumed that any ``truncated`` or ``terminated`` signal is
equivalent to the end of a trajectory.
Defaults to ``False``.
load_from_local_minari (bool, optional): if ``True``, the dataset will be loaded directly
from the local Minari cache (typically located at ``~/.minari/datasets``),
bypassing any remote download. This is useful when working with custom
Minari datasets previously generated and stored locally, or when network
access should be avoided. If the dataset is not found in the expected
cache directory, a ``FileNotFoundError`` will be raised.
Defaults to ``False``.
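
For example, a minimal usage sketch (dataset id taken from the tests in this PR; assumes the dataset was already created locally):

>>> data = MinariExperienceReplay(
...     dataset_id="cartpole/test-local-v1",
...     batch_size=32,
...     load_from_local_minari=True,
... )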


Attributes:
available_datasets: a list of accepted entries to be downloaded.
@@ -167,6 +175,7 @@ def __init__(
prefetch: int | None = None,
transform: torchrl.envs.Transform | None = None, # noqa-F821
split_trajs: bool = False,
load_from_local_minari: bool = False,
):
self.dataset_id = dataset_id
if root is None:
@@ -175,7 +184,13 @@
self.root = root
self.split_trajs = split_trajs
self.download = download
if self.download == "force" or (self.download and not self._is_downloaded()):
self.load_from_local_minari = load_from_local_minari

if (
self.download == "force"
or (self.download and not self._is_downloaded())
or self.load_from_local_minari
):
if self.download == "force":
try:
if os.path.exists(self.data_path_root):
@@ -240,13 +255,38 @@ def _download_and_preproc(self):

with tempfile.TemporaryDirectory() as tmpdir:
os.environ["MINARI_DATASETS_PATH"] = tmpdir
minari.download_dataset(dataset_id=self.dataset_id)
parent_dir = Path(tmpdir) / self.dataset_id / "data"

td_data = TensorDict()
total_steps = 0
torchrl_logger.info("first read through data to create data structure...")
h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5")
td_data = TensorDict()

if self.load_from_local_minari:
# Load the Minari dataset from the user's local Minari cache

minari_cache_dir = os.path.expanduser("~/.minari/datasets")
os.environ["MINARI_DATASETS_PATH"] = minari_cache_dir
parent_dir = Path(minari_cache_dir) / self.dataset_id / "data"
h5_path = parent_dir / "main_data.hdf5"

if not h5_path.exists():
raise FileNotFoundError(
f"{h5_path} does not exist in local Minari cache!"
)

torchrl_logger.info(
f"loading dataset from local Minari cache at {h5_path}"
)
h5_data = PersistentTensorDict.from_h5(h5_path)

else:
minari.download_dataset(dataset_id=self.dataset_id)

parent_dir = Path(tmpdir) / self.dataset_id / "data"

torchrl_logger.info(
"first read through data to create data structure..."
)
h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5")

# populate the tensordict
episode_dict = {}
for i, (episode_key, episode) in enumerate(h5_data.items()):