-
Notifications
You must be signed in to change notification settings - Fork 391
[Feature] Add support for loading datasets from local Minari cache #3068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
26aa0ff
a56c508
e90f4d0
cc43f9a
23aecec
9321adb
a4993c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3341,6 +3341,39 @@ def test_d4rl_iteration(self, task, split_trajs): | |
|
||
_MINARI_DATASETS = [] | ||
|
||
MUJOCO_ENVIRONMENTS = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Explicitly specified current set of Minari supported datasets for integration with Gym environments There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Specified datasets are those used in Mujoco and D4RL |
||
"Hopper-v5", | ||
"Pusher-v5", | ||
"Humanoid-v5", | ||
"InvertedDoublePendulum-v5", | ||
"HalfCheetah-v5", | ||
"Swimmer-v5", | ||
"Walker2d-v5", | ||
"Ant-v5", | ||
"Reacher-v5", | ||
] | ||
|
||
D4RL_ENVIRONMENTS = [ | ||
"AntMaze_UMaze-v5", | ||
"AdroitHandPen-v1", | ||
"AntMaze_Medium-v4", | ||
"AntMaze_Large_Diverse_GR-v4", | ||
"AntMaze_Large-v4", | ||
"AntMaze_Medium_Diverse_GR-v4", | ||
"PointMaze_OpenDense-v3", | ||
"PointMaze_UMaze-v3", | ||
"PointMaze_LargeDense-v3", | ||
"PointMaze_Medium-v3", | ||
"PointMaze_UMazeDense-v3", | ||
"PointMaze_MediumDense-v3", | ||
"PointMaze_Large-v3", | ||
"PointMaze_Open-v3", | ||
"FrankaKitchen-v1", | ||
"AdroitHandDoor-v1", | ||
"AdroitHandHammer-v1", | ||
"AdroitHandRelocate-v1", | ||
] | ||
|
||
|
||
def _minari_init(): | ||
"""Initialize Minari datasets list. Returns True if already initialized.""" | ||
|
@@ -3373,30 +3406,155 @@ def _minari_init(): | |
return False | ||
|
||
|
||
# Initialize with placeholder values for parametrization | ||
# These will be replaced with actual dataset names when the first Minari test runs | ||
_MINARI_DATASETS = [str(i) for i in range(20)] | ||
def get_random_minigrid_datasets(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only use Minari datasets from Minigrid This is because current version of Minari cannot serialize custom MissionSpace objects, which are used in most Minigrid environments There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This means you cannot create your custom dataset from a minigrid environment directly; you will need to modify Mission space |
||
""" | ||
Fetch 5 random Minigrid datasets from the Minari server. | ||
""" | ||
import minari | ||
|
||
all_minigrid = [ | ||
dataset | ||
for dataset in minari.list_remote_datasets( | ||
latest_version=True, compatible_minari_version=True | ||
).keys() | ||
if dataset.startswith("minigrid/") | ||
] | ||
|
||
if len(all_minigrid) < 5: | ||
raise RuntimeError("Not enough minigrid datasets found on Minari server.") | ||
indices = torch.randperm(len(all_minigrid))[:5] | ||
return [all_minigrid[idx] for idx in indices] | ||
|
||
|
||
def get_random_atari_envs(): | ||
""" | ||
Fetch 10 random Atari environments using ale_py and torch. | ||
""" | ||
import ale_py | ||
import gymnasium as gym | ||
|
||
gym.register_envs(ale_py) | ||
|
||
env_specs = gym.envs.registry.values() | ||
all_env_ids = [env_spec.id for env_spec in env_specs] | ||
atari_env_ids = [env_id for env_id in all_env_ids if env_id.startswith("ALE")] | ||
if len(atari_env_ids) < 10: | ||
raise RuntimeError("Not enough Atari environments found.") | ||
indices = torch.randperm(len(atari_env_ids))[:10] | ||
return [atari_env_ids[idx] for idx in indices] | ||
|
||
|
||
def custom_minari_init(custom_envs, num_episodes=5): | ||
""" | ||
Initialize custom Minari datasets for the given environments. | ||
""" | ||
import gymnasium | ||
import gymnasium_robotics | ||
from minari import DataCollector | ||
|
||
gymnasium.register_envs(gymnasium_robotics) | ||
|
||
custom_dataset_ids = [] | ||
for env_id in custom_envs: | ||
dataset_id = f"{env_id.lower()}/test-custom-local-v1" | ||
env = gymnasium.make(env_id) | ||
collector = DataCollector(env) | ||
|
||
for ep in range(num_episodes): | ||
collector.reset(seed=123 + ep) | ||
|
||
while True: | ||
action = collector.action_space.sample() | ||
_, _, terminated, truncated, _ = collector.step(action) | ||
if terminated or truncated: | ||
break | ||
|
||
collector.create_dataset( | ||
dataset_id=dataset_id, | ||
algorithm_name="RandomPolicy", | ||
code_permalink="https://github.com/Farama-Foundation/Minari", | ||
author="Farama", | ||
author_email="[email protected]", | ||
eval_env=env_id, | ||
) | ||
custom_dataset_ids.append(dataset_id) | ||
|
||
return custom_dataset_ids | ||
|
||
|
||
@pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found") | ||
@pytest.mark.slow | ||
class TestMinari: | ||
@pytest.mark.parametrize("split", [False, True]) | ||
@pytest.mark.parametrize("dataset_idx", range(20)) | ||
@pytest.mark.parametrize( | ||
"dataset_idx", | ||
# Only use a static upper bound; do not call any function that imports minari globally. | ||
range(50) | ||
) | ||
def test_load(self, dataset_idx, split): | ||
# Initialize Minari datasets if not already done | ||
if not _minari_init(): | ||
pytest.skip("Failed to initialize Minari datasets") | ||
""" | ||
Test loading from custom datasets for Mujoco and D4RL, | ||
Minari remote datasets for Minigrid, and random Atari environments. | ||
""" | ||
import minari | ||
|
||
# Get the actual dataset name from the initialized list | ||
if dataset_idx >= len(_MINARI_DATASETS): | ||
pytest.skip(f"Dataset index {dataset_idx} out of range") | ||
custom_envs = MUJOCO_ENVIRONMENTS + D4RL_ENVIRONMENTS | ||
num_custom = len(custom_envs) | ||
try: | ||
minigrid_datasets = get_random_minigrid_datasets() | ||
except Exception: | ||
minigrid_datasets = [] | ||
num_minigrid = len(minigrid_datasets) | ||
try: | ||
atari_envs = get_random_atari_envs() | ||
except Exception: | ||
atari_envs = [] | ||
num_atari = len(atari_envs) | ||
total_datasets = num_custom + num_minigrid + num_atari | ||
|
||
if dataset_idx >= total_datasets: | ||
pytest.skip("Index out of range for available datasets") | ||
|
||
if dataset_idx < num_custom: | ||
# Custom dataset for Mujoco/D4RL | ||
custom_dataset_ids = custom_minari_init( | ||
[custom_envs[dataset_idx]], num_episodes=5 | ||
) | ||
dataset_id = custom_dataset_ids[0] | ||
data = MinariExperienceReplay( | ||
dataset_id=dataset_id, | ||
split_trajs=split, | ||
batch_size=32, | ||
load_from_local_minari=True, | ||
) | ||
cleanup_needed = True | ||
|
||
elif dataset_idx < num_custom + num_minigrid: | ||
# Minigrid datasets from Minari server | ||
minigrid_idx = dataset_idx - num_custom | ||
dataset_id = minigrid_datasets[minigrid_idx] | ||
data = MinariExperienceReplay( | ||
dataset_id=dataset_id, | ||
batch_size=32, | ||
split_trajs=split, | ||
download="force", | ||
) | ||
cleanup_needed = False | ||
|
||
else: | ||
# Atari environment datasets | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For Atari environments, select a random subset of environments. |
||
atari_idx = dataset_idx - num_custom - num_minigrid | ||
env_id = atari_envs[atari_idx] | ||
custom_dataset_ids = custom_minari_init([env_id], num_episodes=5) | ||
dataset_id = custom_dataset_ids[0] | ||
data = MinariExperienceReplay( | ||
dataset_id=dataset_id, | ||
split_trajs=split, | ||
batch_size=32, | ||
load_from_local_minari=True, | ||
) | ||
cleanup_needed = True | ||
|
||
selected_dataset = _MINARI_DATASETS[dataset_idx] | ||
torchrl_logger.info(f"dataset {selected_dataset}") | ||
data = MinariExperienceReplay( | ||
selected_dataset, batch_size=32, split_trajs=split | ||
) | ||
t0 = time.time() | ||
for i, sample in enumerate(data): | ||
t1 = time.time() | ||
|
@@ -3407,6 +3565,10 @@ def test_load(self, dataset_idx, split): | |
if i == 10: | ||
break | ||
|
||
# Clean up custom datasets after running local dataset tests | ||
if cleanup_needed: | ||
minari.delete_dataset(dataset_id=dataset_id) | ||
|
||
def test_minari_preproc(self, tmpdir): | ||
dataset = MinariExperienceReplay( | ||
"D4RL/pointmaze/large-v2", | ||
|
@@ -3453,6 +3615,66 @@ def fn(data): | |
assert sample["data"].shape == torch.Size([32, 8]) | ||
assert sample["next", "data"].shape == torch.Size([32, 8]) | ||
|
||
@pytest.mark.skipif( | ||
not _has_minari or not _has_gymnasium, reason="Minari or Gym not available" | ||
) | ||
def test_local_minari_dataset_loading(self): | ||
import minari | ||
from minari import DataCollector | ||
|
||
if not _minari_init(): | ||
pytest.skip("Failed to initialize Minari datasets") | ||
|
||
dataset_id = "cartpole/test-local-v1" | ||
|
||
# Create dataset using Gym + DataCollector | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Custom minari dataset creation from a gymnasium environment |
||
env = gymnasium.make("CartPole-v1") | ||
env = DataCollector(env, record_infos=True) | ||
for _ in range(50): | ||
env.reset(seed=123) | ||
while True: | ||
action = env.action_space.sample() | ||
obs, rew, terminated, truncated, info = env.step(action) | ||
if terminated or truncated: | ||
break | ||
|
||
env.create_dataset( | ||
dataset_id=dataset_id, | ||
algorithm_name="RandomPolicy", | ||
code_permalink="https://github.com/Farama-Foundation/Minari", | ||
author="Farama", | ||
author_email="[email protected]", | ||
eval_env="CartPole-v1", | ||
) | ||
|
||
# Load from local cache | ||
data = MinariExperienceReplay( | ||
dataset_id=dataset_id, | ||
split_trajs=False, | ||
batch_size=32, | ||
download=False, | ||
sampler=SamplerWithoutReplacement(drop_last=True), | ||
prefetch=2, | ||
load_from_local_minari=True, | ||
) | ||
|
||
t0 = time.time() | ||
for i, sample in enumerate(data): | ||
t1 = time.time() | ||
torchrl_logger.info( | ||
f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms" | ||
) | ||
assert data.metadata["action_space"].is_in( | ||
sample["action"] | ||
), "Invalid action sample" | ||
assert data.metadata["observation_space"].is_in( | ||
sample["observation"] | ||
), "Invalid observation sample" | ||
t0 = time.time() | ||
if i == 10: | ||
break | ||
|
||
minari.delete_dataset(dataset_id="cartpole/test-local-v1") | ||
|
||
@pytest.mark.slow | ||
class TestRoboset: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: new dependencies were added for creating custom Minari datasets from gym environments: