diff --git a/setup.py b/setup.py index 73ffd00ac..32a34614d 100644 --- a/setup.py +++ b/setup.py @@ -209,6 +209,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "tensorboard>=1.14", "huggingface_sb3>=2.2.1", "datasets>=2.8.0", + "hydra-core>=1.3.2", ], tests_require=TESTS_REQUIRE, extras_require={ diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index 62b459a0d..8ad7d795b 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -187,6 +187,7 @@ def __init__( if self.demo_batch_size % self.demo_minibatch_size != 0: raise ValueError("Batch size must be a multiple of minibatch size.") self._demo_data_loader = None + self._demonstrations: Optional[base.AnyTransitions] = None self._endless_expert_iterator = None super().__init__( demonstrations=demonstrations, @@ -298,12 +299,16 @@ def reward_test(self) -> reward_nets.RewardNet: """Reward used to train policy at "test" time after adversarial training.""" def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: + self._demonstrations = demonstrations self._demo_data_loader = base.make_data_loader( demonstrations, self.demo_batch_size, ) self._endless_expert_iterator = util.endless_iter(self._demo_data_loader) + def get_demonstrations(self) -> Optional[base.AnyTransitions]: + return self._demonstrations + def _next_expert_batch(self) -> Mapping: assert self._endless_expert_iterator is not None return next(self._endless_expert_iterator) diff --git a/src/imitation_cli/__init__.py b/src/imitation_cli/__init__.py new file mode 100644 index 000000000..8e6132367 --- /dev/null +++ b/src/imitation_cli/__init__.py @@ -0,0 +1 @@ +"""Hydra configurations and scripts that form a CLI for imitation.""" diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py new file mode 100644 index 000000000..d74f4bd5b --- /dev/null +++ b/src/imitation_cli/airl.py @@ -0,0 +1,105 @@ +"""Config and run configuration for AIRL.""" +import dataclasses +import logging +import pathlib +from typing import Any, Dict, Sequence, cast + +import hydra +import torch as th +from hydra.core.config_store import ConfigStore +from hydra.utils import instantiate +from omegaconf import MISSING + +from imitation.policies import serialize +from imitation_cli.algorithm_configurations import airl as airl_cfg +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import ( + policy_evaluation, + randomness, + reward_network, + rl_algorithm, + trajectories, +) + + +@dataclasses.dataclass +class RunConfig: + """Config for running AIRL.""" + + rng: randomness.Config = randomness.Config(seed=0) + total_timesteps: int = int(1e6) + checkpoint_interval: int = 0 + + environment: environment_cfg.Config = MISSING + airl: airl_cfg.Config = MISSING + evaluation: policy_evaluation.Config = MISSING + # This ensures that the working directory is changed + # to the hydra output dir + hydra: Any = dataclasses.field(default_factory=lambda: dict(job=dict(chdir=True))) + + +cs = ConfigStore.instance() +environment_cfg.register_configs("environment", "${rng}") +trajectories.register_configs("airl/demonstrations", "${environment}", "${rng}") +rl_algorithm.register_configs("airl/gen_algo", "${environment}", "${rng.seed}") +reward_network.register_configs("airl/reward_net", "${environment}") +policy_evaluation.register_configs("evaluation", "${environment}", "${rng}") + +cs.store( + name="airl_run_base", + 
node=RunConfig( + airl=airl_cfg.Config( + venv="${environment}", # type: ignore[arg-type] + ), + ), +) + + +@hydra.main( + version_base=None, + config_path="config", + config_name="airl_run", +) +def run_airl(cfg: RunConfig) -> Dict[str, Any]: + from imitation.algorithms.adversarial import airl + from imitation.data import rollout + from imitation.data.types import TrajectoryWithRew + + trainer: airl.AIRL = instantiate(cfg.airl) + + checkpoints_path = pathlib.Path("checkpoints") + + def save(path: str): + """Save discriminator and generator.""" + # We implement this here and not in Trainer since we do not want to actually + # serialize the whole Trainer (including e.g. expert demonstrations). + save_path = checkpoints_path / path + save_path.mkdir(parents=True, exist_ok=True) + + th.save(trainer.reward_train, save_path / "reward_train.pt") + th.save(trainer.reward_test, save_path / "reward_test.pt") + serialize.save_stable_model(save_path / "gen_policy", trainer.gen_algo) + + def callback(round_num: int, /) -> None: + if cfg.checkpoint_interval > 0 and round_num % cfg.checkpoint_interval == 0: + logging.log(logging.INFO, f"Saving checkpoint at round {round_num}") + save(f"{round_num:05d}") + + trainer.train(cfg.total_timesteps, callback) + imit_stats = policy_evaluation.eval_policy(trainer.policy, cfg.evaluation) + + # Save final artifacts. + if cfg.checkpoint_interval >= 0: + logging.log(logging.INFO, "Saving final checkpoint.") + save("final") + + return { + "imit_stats": imit_stats, + "expert_stats": rollout.rollout_stats( + cast(Sequence[TrajectoryWithRew], trainer.get_demonstrations()), + ), + } + + +if __name__ == "__main__": + run_airl() diff --git a/src/imitation_cli/algorithm_configurations/__init__.py b/src/imitation_cli/algorithm_configurations/__init__.py new file mode 100644 index 000000000..dfe338fdc --- /dev/null +++ b/src/imitation_cli/algorithm_configurations/__init__.py @@ -0,0 +1 @@ +"""Structured Hydra configuration for Imitation algorithms.""" diff --git a/src/imitation_cli/algorithm_configurations/airl.py b/src/imitation_cli/algorithm_configurations/airl.py new file mode 100644 index 000000000..30e5309df --- /dev/null +++ b/src/imitation_cli/algorithm_configurations/airl.py @@ -0,0 +1,33 @@ +"""Config for AIRL.""" +import dataclasses +from typing import Optional + +from omegaconf import MISSING + +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import ( + optimizer_class, + reward_network, + rl_algorithm, + trajectories, +) + + +@dataclasses.dataclass +class Config: + """Config for AIRL.""" + + _target_: str = "imitation.algorithms.adversarial.airl.AIRL" + venv: environment_cfg.Config = MISSING + demonstrations: trajectories.Config = MISSING + gen_algo: rl_algorithm.Config = MISSING + reward_net: reward_network.Config = MISSING + demo_batch_size: int = 64 + n_disc_updates_per_round: int = 2 + disc_opt_cls: optimizer_class.Config = optimizer_class.Adam + gen_train_timesteps: Optional[int] = None + gen_replay_buffer_capacity: Optional[int] = None + init_tensorboard: bool = False + init_tensorboard_graph: bool = False + debug_use_ground_truth: bool = False + allow_variable_horizon: bool = False diff --git a/src/imitation_cli/config/airl_optuna.yaml b/src/imitation_cli/config/airl_optuna.yaml new file mode 100644 index 000000000..93224346b --- /dev/null +++ b/src/imitation_cli/config/airl_optuna.yaml @@ -0,0 +1,26 @@ +defaults: + - airl_run_base + - environment: gym_env + - airl/reward_net: shaped + - airl/gen_algo: ppo + - 
evaluation: default_evaluation + - airl/demonstrations: generated + - airl/demonstrations/expert_policy: random + - override hydra/sweeper: optuna + - _self_ + +total_timesteps: 40000 +checkpoint_interval: 1 + +airl: + demo_batch_size: 128 + demonstrations: + total_timesteps: 10 + allow_variable_horizon: true + +hydra: + mode: MULTIRUN + sweeper: + params: + environment: cartpole,pendulum + airl/reward_net: basic,shaped,small_ensemble diff --git a/src/imitation_cli/config/airl_run.yaml b/src/imitation_cli/config/airl_run.yaml new file mode 100644 index 000000000..2bb5364f0 --- /dev/null +++ b/src/imitation_cli/config/airl_run.yaml @@ -0,0 +1,19 @@ +defaults: + - airl_run_base + - environment: cartpole + - airl/reward_net: shaped + - airl/gen_algo: ppo + - evaluation: default_evaluation + - airl/demonstrations: generated + - airl/demonstrations/expert_policy: random +# - environment@airl.reward_net.environment: pendulum # This is how we inject a different environment + - _self_ + +total_timesteps: 40000 +checkpoint_interval: 1 + +airl: + demo_batch_size: 128 + demonstrations: + total_timesteps: 10 + allow_variable_horizon: true diff --git a/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml b/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml new file mode 100644 index 000000000..54ffc74cf --- /dev/null +++ b/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml @@ -0,0 +1,25 @@ +defaults: + - airl_run_base + - environment: gym_env + - airl/reward_net: shaped + - airl/gen_algo: ppo + - evaluation: default_evaluation + - airl/demonstrations: generated + - airl/demonstrations/expert_policy: random + - _self_ + +total_timesteps: 40000 +checkpoint_interval: 1 + +airl: + demo_batch_size: 128 + demonstrations: + total_timesteps: 10 + allow_variable_horizon: true + +hydra: + mode: MULTIRUN + sweeper: + params: + environment: cartpole,pendulum + airl/reward_net: basic,shaped,small_ensemble diff --git a/src/imitation_cli/utils/__init__.py b/src/imitation_cli/utils/__init__.py new file mode 100644 index 000000000..f3dc34d7c --- /dev/null +++ b/src/imitation_cli/utils/__init__.py @@ -0,0 +1 @@ +"""Configurations to be used as ingredient to algorithm configurations.""" diff --git a/src/imitation_cli/utils/activation_function_class.py b/src/imitation_cli/utils/activation_function_class.py new file mode 100644 index 000000000..da51c070c --- /dev/null +++ b/src/imitation_cli/utils/activation_function_class.py @@ -0,0 +1,37 @@ +"""Classes for configuring activation functions.""" +import dataclasses +from enum import Enum + +import torch +from hydra.core.config_store import ConfigStore + + +class ActivationFunctionClass(Enum): + """Enum of activation function classes.""" + + TanH = torch.nn.Tanh + ReLU = torch.nn.ReLU + LeakyReLU = torch.nn.LeakyReLU + + +@dataclasses.dataclass +class Config: + """Base class for activation function configs.""" + + activation_function_class: ActivationFunctionClass + _target_: str = "imitation_cli.utils.activation_function_class.Config.make" + + @staticmethod + def make(activation_function_class: ActivationFunctionClass) -> type: + return activation_function_class.value + + +TanH = Config(ActivationFunctionClass.TanH) +ReLU = Config(ActivationFunctionClass.ReLU) +LeakyReLU = Config(ActivationFunctionClass.LeakyReLU) + + +def register_configs(group: str): + cs = ConfigStore.instance() + for cls in ActivationFunctionClass: + cs.store(group=group, name=cls.name.lower(), node=Config(cls)) diff --git a/src/imitation_cli/utils/environment.py 
b/src/imitation_cli/utils/environment.py new file mode 100644 index 000000000..d824965c9 --- /dev/null +++ b/src/imitation_cli/utils/environment.py @@ -0,0 +1,66 @@ +"""Configuration for Gym environments.""" +from __future__ import annotations + +import dataclasses +import typing +from typing import Optional, Union, cast + +if typing.TYPE_CHECKING: + from stable_baselines3.common.vec_env import VecEnv + +from hydra.core.config_store import ConfigStore +from hydra.utils import instantiate +from omegaconf import MISSING + +from imitation_cli.utils import randomness + + +@dataclasses.dataclass +class Config: + """Configuration for Gym environments.""" + + _target_: str = "imitation_cli.utils.environment.Config.make" + env_name: str = MISSING # The environment to train on + n_envs: int = 8 # number of environments in VecEnv + # TODO: when setting this to true this is really slow for some reason + parallel: bool = False # Use SubprocVecEnv rather than DummyVecEnv + max_episode_steps: int = MISSING # Set to positive int to limit episode horizons + env_make_kwargs: dict = dataclasses.field( + default_factory=dict, + ) # The kwargs passed to `spec.make`. + rng: randomness.Config = MISSING + + @staticmethod + def make(log_dir: Optional[str] = None, **kwargs) -> VecEnv: + from imitation.util import util + + return util.make_vec_env(log_dir=log_dir, **kwargs) + + +def make_rollout_venv(environment_config: Config) -> VecEnv: + from imitation.data import wrappers + + return instantiate( + environment_config, + log_dir=None, + post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)], + ) + + +def register_configs( + group: str, + default_rng: Union[randomness.Config, str] = MISSING, +): + default_rng = cast(randomness.Config, default_rng) + cs = ConfigStore.instance() + cs.store(group=group, name="gym_env", node=Config(rng=default_rng)) + cs.store( + group=group, + name="cartpole", + node=Config(env_name="CartPole-v0", max_episode_steps=500, rng=default_rng), + ) + cs.store( + group=group, + name="pendulum", + node=Config(env_name="Pendulum-v1", max_episode_steps=500, rng=default_rng), + ) diff --git a/src/imitation_cli/utils/feature_extractor_class.py b/src/imitation_cli/utils/feature_extractor_class.py new file mode 100644 index 000000000..2beb0e1e3 --- /dev/null +++ b/src/imitation_cli/utils/feature_extractor_class.py @@ -0,0 +1,35 @@ +"""Register Hydra configs for stable_baselines3 feature extractors.""" +import dataclasses +from enum import Enum + +import stable_baselines3.common.torch_layers as torch_layers +from hydra.core.config_store import ConfigStore + + +class FeatureExtractorClass(Enum): + """Enum of feature extractor classes.""" + + FlattenExtractor = torch_layers.FlattenExtractor + NatureCNN = torch_layers.NatureCNN + + +@dataclasses.dataclass +class Config: + """Base config for stable_baselines3 feature extractors.""" + + feature_extractor_class: FeatureExtractorClass + _target_: str = "imitation_cli.utils.feature_extractor_class.Config.make" + + @staticmethod + def make(feature_extractor_class: FeatureExtractorClass) -> type: + return feature_extractor_class.value + + +FlattenExtractor = Config(FeatureExtractorClass.FlattenExtractor) +NatureCNN = Config(FeatureExtractorClass.NatureCNN) + + +def register_configs(group: str): + cs = ConfigStore.instance() + for cls in FeatureExtractorClass: + cs.store(group=group, name=cls.name.lower(), node=Config(cls)) diff --git a/src/imitation_cli/utils/optimizer_class.py b/src/imitation_cli/utils/optimizer_class.py new file mode 100644 index 
000000000..17a3d01ca --- /dev/null +++ b/src/imitation_cli/utils/optimizer_class.py @@ -0,0 +1,35 @@ +"""Register optimizer classes with Hydra.""" +import dataclasses +from enum import Enum + +import torch +from hydra.core.config_store import ConfigStore + + +class OptimizerClass(Enum): + """Enum of optimizer classes.""" + + Adam = torch.optim.Adam + SGD = torch.optim.SGD + + +@dataclasses.dataclass +class Config: + """Base config for optimizer classes.""" + + optimizer_class: OptimizerClass + _target_: str = "imitation_cli.utils.optimizer_class.Config.make" + + @staticmethod + def make(optimizer_class: OptimizerClass) -> type: + return optimizer_class.value + + +Adam = Config(OptimizerClass.Adam) +SGD = Config(OptimizerClass.SGD) + + +def register_configs(group: str): + cs = ConfigStore.instance() + for cls in OptimizerClass: + cs.store(group=group, name=cls.name.lower(), node=Config(cls)) diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py new file mode 100644 index 000000000..26fe18a3f --- /dev/null +++ b/src/imitation_cli/utils/policy.py @@ -0,0 +1,213 @@ +"""Configurable policies for SB3 Base Policies.""" "" +from __future__ import annotations + +import dataclasses +import pathlib +import typing +from typing import Any, Dict, List, Optional, Union, cast + +if typing.TYPE_CHECKING: + from stable_baselines3.common.vec_env import VecEnv + from stable_baselines3.common.policies import BasePolicy + +from hydra.core.config_store import ConfigStore +from hydra.utils import instantiate +from omegaconf import MISSING + +from imitation_cli.utils import activation_function_class as act_fun_class_cfg +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import feature_extractor_class as feature_extractor_class_cfg +from imitation_cli.utils import optimizer_class as optimizer_class_cfg +from imitation_cli.utils import schedule + + +@dataclasses.dataclass +class Config: + """Base configuration for policies.""" + + _target_: str = MISSING + environment: environment_cfg.Config = MISSING + + +@dataclasses.dataclass +class Random(Config): + """Configuration for a random policy.""" + + _target_: str = "imitation_cli.utils.policy.Random.make" + + @staticmethod + def make(environment: VecEnv) -> BasePolicy: + from imitation.policies import base + + return base.RandomPolicy( + environment.observation_space, + environment.action_space, + ) + + +@dataclasses.dataclass +class ZeroPolicy(Config): + """Configuration for a zero policy.""" + + _target_: str = "imitation_cli.utils.policy.ZeroPolicy.make" + + @staticmethod + def make(environment: VecEnv) -> BasePolicy: + from imitation.policies import base + + return base.ZeroPolicy(environment.observation_space, environment.action_space) + + +@dataclasses.dataclass +class ActorCriticPolicy(Config): + """Configuration for a stable-baselines3 ActorCriticPolicy.""" + + _target_: str = "imitation_cli.utils.policy.ActorCriticPolicy.make" + lr_schedule: schedule.Config = schedule.FixedSchedule(3e-4) + net_arch: Optional[Dict[str, List[int]]] = None + activation_fn: act_fun_class_cfg.Config = act_fun_class_cfg.TanH + ortho_init: bool = True + use_sde: bool = False + log_std_init: float = 0.0 + full_std: bool = True + use_expln: bool = False + squash_output: bool = False + features_extractor_class: feature_extractor_class_cfg.Config = ( + feature_extractor_class_cfg.FlattenExtractor + ) + features_extractor_kwargs: Optional[Dict[str, Any]] = None + share_features_extractor: bool = True + normalize_images: 
bool = True + optimizer_class: optimizer_class_cfg.Config = optimizer_class_cfg.Adam + optimizer_kwargs: Optional[Dict[str, Any]] = None + + @staticmethod + def make_args( + activation_fn: act_fun_class_cfg.Config, + features_extractor_class: feature_extractor_class_cfg.Config, + optimizer_class: optimizer_class_cfg.Config, + **kwargs, + ): + del kwargs["_target_"] + del kwargs["environment"] + + kwargs["activation_fn"] = instantiate(activation_fn) + kwargs["features_extractor_class"] = instantiate(features_extractor_class) + kwargs["optimizer_class"] = instantiate(optimizer_class) + + return dict( + **kwargs, + ) + + @staticmethod + def make( + environment: VecEnv, + **kwargs, + ) -> BasePolicy: + import stable_baselines3 as sb3 + + return sb3.common.policies.ActorCriticPolicy( + observation_space=environment.observation_space, + action_space=environment.action_space, + **kwargs, + ) + + +@dataclasses.dataclass +class Loaded(Config): + """Base configuration for a policy that is loaded from somewhere.""" + + policy_type: str = ( + "PPO" # The SB3 algorithm class. Only SAC and PPO are supported for now + ) + + @staticmethod + def type_to_class(policy_type: str): + import stable_baselines3 as sb3 + + policy_type = policy_type.lower() + if policy_type == "ppo": + return sb3.PPO + if policy_type == "sac": + return sb3.SAC + raise ValueError(f"Unknown policy type {policy_type}") + + +@dataclasses.dataclass +class PolicyOnDisk(Loaded): + """Configuration for a policy that is loaded from a path on disk.""" + + _target_: str = "imitation_cli.utils.policy.PolicyOnDisk.make" + path: pathlib.Path = MISSING + + @staticmethod + def make( + environment: VecEnv, + policy_type: str, + path: pathlib.Path, + ) -> BasePolicy: + from imitation.policies import serialize + + return serialize.load_stable_baselines_model( + Loaded.type_to_class(policy_type), + str(path), + environment, + ).policy + + +@dataclasses.dataclass +class PolicyFromHuggingface(Loaded): + """Configuration for a policy that is loaded from a HuggingFace model.""" + + _target_: str = "imitation_cli.utils.policy.PolicyFromHuggingface.make" + _recursive_: bool = False + organization: str = "HumanCompatibleAI" + + @staticmethod + def make( + environment: environment_cfg.Config, + policy_type: str, + organization: str, + ) -> BasePolicy: + import huggingface_sb3 as hfsb3 + + from imitation.policies import serialize + + model_name = hfsb3.ModelName( + policy_type.lower(), + hfsb3.EnvironmentName(environment.env_name), + ) + repo_id = hfsb3.ModelRepoId(organization, model_name) + filename = hfsb3.load_from_hub(repo_id, model_name.filename) + model = serialize.load_stable_baselines_model( + Loaded.type_to_class(policy_type), + filename, + instantiate(environment), + ) + return model.policy + + +def register_configs( + group: str, + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, +): + default_environment = cast(environment_cfg.Config, default_environment) + cs = ConfigStore.instance() + cs.store(group=group, name="random", node=Random(environment=default_environment)) + cs.store(group=group, name="zero", node=ZeroPolicy(environment=default_environment)) + cs.store( + group=group, + name="on_disk", + node=PolicyOnDisk(environment=default_environment), + ) + cs.store( + group=group, + name="from_huggingface", + node=PolicyFromHuggingface(environment=default_environment), + ) + cs.store( + group=group, + name="actor_critic", + node=ActorCriticPolicy(environment=default_environment), + ) + schedule.register_configs(group=group + 
"/lr_schedule") diff --git a/src/imitation_cli/utils/policy_evaluation.py b/src/imitation_cli/utils/policy_evaluation.py new file mode 100644 index 000000000..192750049 --- /dev/null +++ b/src/imitation_cli/utils/policy_evaluation.py @@ -0,0 +1,101 @@ +"""Code to evaluate trained policies.""" +from __future__ import annotations + +import dataclasses +import typing +from typing import Optional, Union, cast + +from hydra.utils import call +from omegaconf import MISSING + +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import randomness + +if typing.TYPE_CHECKING: + from stable_baselines3.common import base_class, policies + + +@dataclasses.dataclass +class Config: + """Configuration for evaluating a policy.""" + + environment: environment_cfg.Config = MISSING + n_episodes_eval: int = 50 + rng: randomness.Config = MISSING + + +def register_configs( + group: str, + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, + default_rng: Optional[Union[randomness.Config, str]] = MISSING, +) -> None: + from hydra.core.config_store import ConfigStore + + default_environment = cast(environment_cfg.Config, default_environment) + default_rng = cast(randomness.Config, default_rng) + + cs = ConfigStore.instance() + cs.store( + name="default_evaluation", + group=group, + node=Config(environment=default_environment, rng=default_rng), + ) + cs.store( + name="fast_evaluation", + group=group, + node=Config( + environment=default_environment, + rng=default_rng, + n_episodes_eval=2, + ), + ) + + +def eval_policy( + rl_algo: Union[base_class.BaseAlgorithm, policies.BasePolicy], + config: Config, +) -> typing.Mapping[str, float]: + """Evaluation of imitation learned policy. + + Has the side effect of setting `rl_algo`'s environment to `venv` + if it is a `BaseAlgorithm`. + + Args: + rl_algo: Algorithm to evaluate. + config: Configuration for evaluation. + + Returns: + A dictionary with two keys. "imit_stats" gives the return value of + `rollout_stats()` on rollouts test-reward-wrapped environment, using the final + policy (remember that the ground-truth reward can be recovered from the + "monitor_return" key). "expert_stats" gives the return value of + `rollout_stats()` on the expert demonstrations loaded from `rollout_path`. + """ + from stable_baselines3.common import base_class, vec_env + + from imitation.data import rollout + + sample_until_eval = rollout.make_min_episodes(config.n_episodes_eval) + venv = call(config.environment) + rng = call(config.rng) + + if isinstance(rl_algo, base_class.BaseAlgorithm): + # Set RL algorithm's env to venv, removing any cruft wrappers that the RL + # algorithm's environment may have accumulated. + rl_algo.set_env(venv) + # Generate trajectories with the RL algorithm's env - SB3 may apply wrappers + # under the hood to get it to work with the RL algorithm (e.g. transposing + # images, so they can be fed into CNNs). 
+ train_env = rl_algo.get_env() + assert train_env is not None + else: + train_env = venv + + train_env = typing.cast(vec_env.VecEnv, train_env) + trajs = rollout.generate_trajectories( + rl_algo, + train_env, + sample_until=sample_until_eval, + rng=rng, + ) + return rollout.rollout_stats(trajs) diff --git a/src/imitation_cli/utils/randomness.py b/src/imitation_cli/utils/randomness.py new file mode 100644 index 000000000..68c4feeb5 --- /dev/null +++ b/src/imitation_cli/utils/randomness.py @@ -0,0 +1,28 @@ +"""Utilities for seeding random number generators.""" +from __future__ import annotations + +import dataclasses +import typing + +from omegaconf import MISSING + +if typing.TYPE_CHECKING: + import numpy as np + + +@dataclasses.dataclass +class Config: + """Configuration for seeding random number generators.""" + + _target_: str = "imitation_cli.utils.randomness.Config.make" + seed: int = MISSING + + @staticmethod + def make(seed: int) -> np.random.Generator: + import numpy as np + import torch + + np.random.seed(seed) + torch.manual_seed(seed) + + return np.random.default_rng(seed) diff --git a/src/imitation_cli/utils/reward_network.py b/src/imitation_cli/utils/reward_network.py new file mode 100644 index 000000000..43dbfcef4 --- /dev/null +++ b/src/imitation_cli/utils/reward_network.py @@ -0,0 +1,141 @@ +"""Reward network configuration.""" +from __future__ import annotations + +import dataclasses +import typing +from typing import Optional, Union, cast + +if typing.TYPE_CHECKING: + from stable_baselines3.common.vec_env import VecEnv + from imitation.rewards.reward_nets import RewardNet + +from hydra.core.config_store import ConfigStore +from hydra.utils import instantiate +from omegaconf import MISSING + +import imitation_cli.utils.environment as environment_cfg + + +@dataclasses.dataclass +class Config: + """Base configuration for reward networks.""" + + _target_: str = MISSING + environment: environment_cfg.Config = MISSING + + +@dataclasses.dataclass +class BasicRewardNet(Config): + """Configuration for a basic reward network.""" + + _target_: str = "imitation_cli.utils.reward_network.BasicRewardNet.make" + use_state: bool = True + use_action: bool = True + use_next_state: bool = False + use_done: bool = False + normalize_input_layer: bool = True + + @staticmethod + def make(environment: VecEnv, normalize_input_layer: bool, **kwargs) -> RewardNet: + from imitation.rewards import reward_nets + from imitation.util import networks + + reward_net = reward_nets.BasicRewardNet( + environment.observation_space, + environment.action_space, + **kwargs, + ) + if normalize_input_layer: + return reward_nets.NormalizedRewardNet( + reward_net, + networks.RunningNorm, + ) + else: + return reward_net + + +@dataclasses.dataclass +class BasicShapedRewardNet(BasicRewardNet): + """Configuration for a basic shaped reward network.""" + + _target_: str = "imitation_cli.utils.reward_network.BasicShapedRewardNet.make" + discount_factor: float = 0.99 + + @staticmethod + def make(environment: VecEnv, normalize_input_layer: bool, **kwargs) -> RewardNet: + from imitation.rewards import reward_nets + from imitation.util import networks + + reward_net = reward_nets.BasicShapedRewardNet( + environment.observation_space, + environment.action_space, + **kwargs, + ) + if normalize_input_layer: + return reward_nets.NormalizedRewardNet( + reward_net, + networks.RunningNorm, + ) + else: + return reward_net + + +@dataclasses.dataclass +class RewardEnsemble(Config): + """Configuration for a reward ensemble.""" + + 
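+ # `_recursive_` is disabled below so that `make()` receives the raw + # `ensemble_member_config` and can instantiate each ensemble member itself.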
_target_: str = "imitation_cli.utils.reward_network.RewardEnsemble.make" + _recursive_: bool = False + ensemble_size: int = MISSING + ensemble_member_config: BasicRewardNet = MISSING + add_std_alpha: Optional[float] = None + + @staticmethod + def make( + environment: environment_cfg.Config, + ensemble_member_config: BasicRewardNet, + add_std_alpha: Optional[float], + ensemble_size: int, + ) -> RewardNet: + from imitation.rewards import reward_nets + + venv = instantiate(environment) + reward_net = reward_nets.RewardEnsemble( + venv.observation_space, + venv.action_space, + [instantiate(ensemble_member_config) for _ in range(ensemble_size)], + ) + if add_std_alpha is not None: + return reward_nets.AddSTDRewardWrapper( + reward_net, + default_alpha=add_std_alpha, + ) + else: + return reward_net + + +def register_configs( + group: str, + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, +): + default_environment = cast(environment_cfg.Config, default_environment) + cs = ConfigStore.instance() + cs.store( + group=group, + name="basic", + node=BasicRewardNet(environment=default_environment), + ) + cs.store( + group=group, + name="shaped", + node=BasicShapedRewardNet(environment=default_environment), + ) + cs.store( + group=group, + name="small_ensemble", + node=RewardEnsemble( + environment=default_environment, + ensemble_size=5, + ensemble_member_config=BasicRewardNet(environment=default_environment), + ), + ) diff --git a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py new file mode 100644 index 000000000..739f0b9a7 --- /dev/null +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -0,0 +1,130 @@ +"""Configurable RL algorithms.""" +from __future__ import annotations + +import dataclasses +import pathlib +import typing +from typing import Optional, Union, cast + +if typing.TYPE_CHECKING: + import stable_baselines3 as sb3 + from stable_baselines3.common.vec_env import VecEnv + +from hydra.utils import instantiate, to_absolute_path +from omegaconf import MISSING + +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import policy as policy_cfg +from imitation_cli.utils import schedule + + +@dataclasses.dataclass +class Config: + """Base configuration for RL algorithms.""" + + _target_: str = MISSING + environment: environment_cfg.Config = MISSING + + +@dataclasses.dataclass +class PPO(Config): + """Configuration for a stable-baselines3 PPO algorithm.""" + + _target_: str = "imitation_cli.utils.rl_algorithm.PPO.make" + # We disable recursive instantiation, so we can just make the + # arguments of the policy but not the policy itself + _recursive_: bool = False + policy: policy_cfg.ActorCriticPolicy = policy_cfg.ActorCriticPolicy() + learning_rate: schedule.Config = schedule.FixedSchedule(3e-4) + n_steps: int = 2048 + batch_size: int = 64 + n_epochs: int = 10 + gamma: float = 0.99 + gae_lambda: float = 0.95 + clip_range: schedule.Config = schedule.FixedSchedule(0.2) + clip_range_vf: Optional[schedule.Config] = None + normalize_advantage: bool = True + ent_coef: float = 0.0 + vf_coef: float = 0.5 + max_grad_norm: float = 0.5 + use_sde: bool = False + sde_sample_freq: int = -1 + target_kl: Optional[float] = None + tensorboard_log: Optional[str] = None + verbose: int = 0 + seed: int = MISSING + device: str = "auto" + + @staticmethod + def make( + environment: environment_cfg.Config, + policy: policy_cfg.ActorCriticPolicy, + learning_rate: schedule.Config, + clip_range: schedule.Config, + **kwargs, + ) -> 
sb3.PPO: + import stable_baselines3 as sb3 + + policy_kwargs = policy_cfg.ActorCriticPolicy.make_args( + **typing.cast(dict, policy), + ) + del policy_kwargs["use_sde"] + del policy_kwargs["lr_schedule"] + return sb3.PPO( + policy=sb3.common.policies.ActorCriticPolicy, + policy_kwargs=policy_kwargs, + env=instantiate(environment), + learning_rate=instantiate(learning_rate), + clip_range=instantiate(clip_range), + **kwargs, + ) + + +@dataclasses.dataclass +class PPOOnDisk(Config): + """Configuration for a stable-baselines3 PPO algorithm loaded from disk.""" + + _target_: str = "imitation_cli.utils.rl_algorithm.PPOOnDisk.make" + path: pathlib.Path = MISSING + + @staticmethod + def make(environment: VecEnv, path: pathlib.Path) -> sb3.PPO: + import stable_baselines3 as sb3 + + from imitation.policies import serialize + + return serialize.load_stable_baselines_model( + sb3.PPO, + str(to_absolute_path(str(path))), + environment, + ) + + +def register_configs( + group: str = "rl_algorithm", + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, + default_seed: Optional[Union[int, str]] = MISSING, +): + from hydra.core.config_store import ConfigStore + + default_environment = cast(environment_cfg.Config, default_environment) + default_seed = cast(int, default_seed) + + cs = ConfigStore.instance() + cs.store( + name="ppo", + group=group, + node=PPO( + environment=default_environment, + policy=policy_cfg.ActorCriticPolicy(environment=default_environment), + seed=default_seed, + ), + ) + cs.store( + name="ppo_on_disk", + group=group, + node=PPOOnDisk(environment=default_environment), + ) + + schedule.register_configs(group=group + "/learning_rate") + schedule.register_configs(group=group + "/clip_range") diff --git a/src/imitation_cli/utils/schedule.py b/src/imitation_cli/utils/schedule.py new file mode 100644 index 000000000..08091c6a0 --- /dev/null +++ b/src/imitation_cli/utils/schedule.py @@ -0,0 +1,40 @@ +"""Configurations for stable_baselines3 schedules.""" +import dataclasses + +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING + + +@dataclasses.dataclass +class Config: + """Base configuration for schedules.""" + + # Note: we don't define _target_ here so in the subclasses it can be defined last. + # This way we can instantiate a fixed schedule with `FixedSchedule(0.1)`. + # If we defined _target_ here, then we would have to instantiate a fixed schedule + # with `FixedSchedule(val=0.1)`. Otherwise we would set _target_ to 0.1. 
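+ # For example: with `val` declared before `_target_` (as in FixedSchedule below), + # `FixedSchedule(0.1)` binds 0.1 to `val`; if `_target_` were declared here on the + # base class, the positional 0.1 would be bound to `_target_` instead.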
+ pass + + +@dataclasses.dataclass +class FixedSchedule(Config): + """Configuration for a fixed schedule.""" + + val: float = MISSING + _target_: str = "stable_baselines3.common.utils.constant_fn" + + +@dataclasses.dataclass +class LinearSchedule(Config): + """Configuration for a linear schedule.""" + + start: float = MISSING + end: float = MISSING + end_fraction: float = MISSING + _target_: str = "stable_baselines3.common.utils.get_linear_fn" + + +def register_configs(group: str): + cs = ConfigStore.instance() + cs.store(group=group, name="fixed", node=FixedSchedule) + cs.store(group=group, name="linear", node=LinearSchedule) diff --git a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py new file mode 100644 index 000000000..80c9322e5 --- /dev/null +++ b/src/imitation_cli/utils/trajectories.py @@ -0,0 +1,89 @@ +"""Configurable trajectory sources.""" +from __future__ import annotations + +import dataclasses +import pathlib +import typing +from typing import Optional, Sequence, Union, cast + +if typing.TYPE_CHECKING: + from stable_baselines3.common.policies import BasePolicy + from imitation.data.types import Trajectory + import numpy as np + +from hydra.core.config_store import ConfigStore +from hydra.utils import instantiate +from omegaconf import MISSING + +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import policy, randomness + + +@dataclasses.dataclass +class Config: + """Base configuration for trajectory sources.""" + + _target_: str = MISSING + + +@dataclasses.dataclass +class OnDisk(Config): + """Configuration for loading trajectories from disk.""" + + _target_: str = "imitation_cli.utils.trajectories.OnDisk.make" + path: pathlib.Path = MISSING + + @staticmethod + def make(path: pathlib.Path) -> Sequence[Trajectory]: + from imitation.data import serialize + + return serialize.load(path) + + +@dataclasses.dataclass +class Generated(Config): + """Configuration for generating trajectories from an expert policy.""" + + _target_: str = "imitation_cli.utils.trajectories.Generated.make" + # Note: We disable the recursive flag, so we can extract + # the environment from the expert policy + _recursive_: bool = False + total_timesteps: int = MISSING + expert_policy: policy.Config = MISSING + rng: randomness.Config = MISSING + + @staticmethod + def make( + total_timesteps: int, + expert_policy: BasePolicy, + rng: np.random.Generator, + ) -> Sequence[Trajectory]: + from imitation.data import rollout + + expert = instantiate(expert_policy) + env = instantiate(expert_policy.environment) + rng = instantiate(rng) + return rollout.generate_trajectories( + expert, + env, + rollout.make_sample_until(min_timesteps=total_timesteps), + rng, + deterministic_policy=True, + ) + + +def register_configs( + group: str, + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, + default_rng: Optional[Union[randomness.Config, str]] = MISSING, +): + default_environment = cast(environment_cfg.Config, default_environment) + default_rng = cast(randomness.Config, default_rng) + + cs = ConfigStore.instance() + cs.store(group=group, name="on_disk", node=OnDisk) + cs.store(group=group, name="generated", node=Generated(rng=default_rng)) + policy.register_configs( + group=group + "/expert_policy", + default_environment=default_environment, + )
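
Usage sketch (not part of the diff): the commands and compose-API calls below illustrate how the configs registered above are meant to be driven. The override names (`environment=pendulum`, `--config-name airl_optuna`) come from the YAML files and ConfigStore entries in this diff; running compose outside of `hydra.main` and the relative `config_path` are assumptions.

# Command-line entry point added in this diff:
#   python -m imitation_cli.airl environment=pendulum total_timesteps=100000
#   python -m imitation_cli.airl --config-name airl_optuna   # Optuna multirun sweep
from hydra import compose, initialize
from hydra.utils import instantiate

import imitation_cli.airl  # noqa: F401  # importing registers the structured configs

# Assumes this snippet lives next to the config/ directory added in this diff
# (initialize() resolves config_path relative to the caller).
with initialize(version_base=None, config_path="config"):
    cfg = compose(
        config_name="airl_run",
        overrides=["environment=pendulum", "total_timesteps=1000"],
    )

trainer = instantiate(cfg.airl)  # build the AIRL trainer from the structured config
trainer.train(cfg.total_timesteps)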