diff --git a/brax/envs/base.py b/brax/envs/base.py
index 9097ac2f..b9436b2c 100644
--- a/brax/envs/base.py
+++ b/brax/envs/base.py
@@ -16,7 +16,7 @@
 """A brax environment for training and inference."""

 import abc
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

 from brax import base
 from brax.generalized import pipeline as g_pipeline
@@ -28,13 +28,14 @@
 import jax
 import numpy as np

+ObservationSize = Union[int, Mapping[str, Union[Tuple[int, ...], int]]]

 @struct.dataclass
 class State(base.Base):
   """Environment state for training and inference."""

   pipeline_state: Optional[base.State]
-  obs: jax.Array
+  obs: Union[jax.Array, Mapping[str, jax.Array]]
   reward: jax.Array
   done: jax.Array
   metrics: Dict[str, jax.Array] = struct.field(default_factory=dict)
@@ -54,7 +55,7 @@ def step(self, state: State, action: jax.Array) -> State:

   @property
   @abc.abstractmethod
-  def observation_size(self) -> int:
+  def observation_size(self) -> ObservationSize:
     """The size of the observation vector returned in step and reset."""

   @property
@@ -139,10 +140,13 @@ def dt(self) -> jax.Array:
     return self.sys.opt.timestep * self._n_frames  # pytype: disable=attribute-error

   @property
-  def observation_size(self) -> int:
+  def observation_size(self) -> ObservationSize:
     rng = jax.random.PRNGKey(0)
     reset_state = self.unwrapped.reset(rng)
-    return reset_state.obs.shape[-1]
+    obs = reset_state.obs
+    if isinstance(obs, jax.Array):
+      return obs.shape[-1]
+    return jax.tree_util.tree_map(lambda x: x.shape, obs)

   @property
   def action_size(self) -> int:
@@ -176,7 +180,7 @@ def step(self, state: State, action: jax.Array) -> State:
     return self.env.step(state, action)

   @property
-  def observation_size(self) -> int:
+  def observation_size(self) -> ObservationSize:
     return self.env.observation_size

   @property
diff --git a/brax/envs/fast.py b/brax/envs/fast.py
index f7351b71..f7b312d8 100644
--- a/brax/envs/fast.py
+++ b/brax/envs/fast.py
@@ -28,6 +28,7 @@ def __init__(self, **kwargs):
     self._dt = 0.02
     self._reset_count = 0
     self._step_count = 0
+    self._use_dict_obs = kwargs.get('use_dict_obs', False)

   def reset(self, rng: jax.Array) -> State:
     self._reset_count += 1
@@ -39,6 +40,7 @@ def reset(self, rng: jax.Array) -> State:
         contact=None
     )
     obs = jp.zeros(2)
+    obs = {'state': obs} if self._use_dict_obs else obs
     reward, done = jp.array(0.0), jp.array(0.0)
     return State(pipeline_state, obs, reward, done)

@@ -53,6 +55,7 @@ def step(self, state: State, action: jax.Array) -> State:
         xd=state.pipeline_state.xd.replace(vel=vel),
     )
     obs = jp.array([pos[0], vel[0]])
+    obs = {'state': obs} if self._use_dict_obs else obs
     reward = pos[0]
     return state.replace(pipeline_state=qp, obs=obs, reward=reward)

diff --git a/brax/envs/wrappers/training.py b/brax/envs/wrappers/training.py
index a9b91b70..d4a364d2 100644
--- a/brax/envs/wrappers/training.py
+++ b/brax/envs/wrappers/training.py
@@ -130,7 +130,7 @@ def where_done(x, y):
     pipeline_state = jax.tree.map(
         where_done, state.info['first_pipeline_state'], state.pipeline_state
     )
-    obs = where_done(state.info['first_obs'], state.obs)
+    obs = jax.tree.map(where_done, state.info['first_obs'], state.obs)
     return state.replace(pipeline_state=pipeline_state, obs=obs)

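
For reference (not part of the patch): a minimal sketch of what the environment-side
changes above enable, assuming the 'fast' debug env and the new use_dict_obs flag
introduced in this diff.

    import jax
    from brax import envs

    env = envs.get_environment('fast', use_dict_obs=True)
    state = env.reset(jax.random.PRNGKey(0))
    # state.obs is now a dict of arrays, e.g. {'state': Array([0., 0.], dtype=float32)},
    # and observation_size mirrors that structure, e.g. {'state': (2,)} rather than an int.
    print(state.obs, env.observation_size)
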
diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py
index 351a7166..4482b63a 100644
--- a/brax/training/agents/apg/train.py
+++ b/brax/training/agents/apg/train.py
@@ -131,11 +131,15 @@ def train(
   reset_fn = jax.jit(jax.vmap(env.reset))
   step_fn = jax.jit(jax.vmap(env.step))

+  obs_size = env.observation_size
+  if isinstance(obs_size, Dict):
+    raise NotImplementedError("Dictionary observations not implemented in APG")
+
   normalize = lambda x, y: x
   if normalize_observations:
     normalize = running_statistics.normalize
   apg_network = network_factory(
-      env.observation_size,
+      obs_size,
       env.action_size,
       preprocess_observations_fn=normalize)
   make_policy = apg_networks.make_inference_fn(apg_network)
diff --git a/brax/training/agents/ars/train.py b/brax/training/agents/ars/train.py
index 70d8a003..77395e25 100644
--- a/brax/training/agents/ars/train.py
+++ b/brax/training/agents/ars/train.py
@@ -119,6 +119,8 @@ def train(
   )

   obs_size = env.observation_size
+  if isinstance(obs_size, Dict):
+    raise NotImplementedError("Dictionary observations not implemented in ARS")

   normalize_fn = lambda x, y: x
   if normalize_observations:
diff --git a/brax/training/agents/es/train.py b/brax/training/agents/es/train.py
index 7351f3dc..c3b20ba5 100644
--- a/brax/training/agents/es/train.py
+++ b/brax/training/agents/es/train.py
@@ -146,7 +146,9 @@ def train(
   )

   obs_size = env.observation_size
-
+  if isinstance(obs_size, Dict):
+    raise NotImplementedError("Dictionary observations not implemented in ES")
+
   normalize_fn = lambda x, y: x
   if normalize_observations:
     normalize_fn = running_statistics.normalize
diff --git a/brax/training/agents/ppo/losses.py b/brax/training/agents/ppo/losses.py
index c477739c..202aa796 100644
--- a/brax/training/agents/ppo/losses.py
+++ b/brax/training/agents/ppo/losses.py
@@ -135,9 +135,9 @@ def compute_ppo_loss(
       data.observation)
   baseline = value_apply(normalizer_params, params.value, data.observation)
-
+  terminal_obs = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation)
   bootstrap_value = value_apply(normalizer_params, params.value,
-                                data.next_observation[-1])
+                                terminal_obs)

   rewards = data.reward * reward_scaling
   truncation = data.extras['state_extras']['truncation']
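
For reference (not part of the patch): the losses.py change above swaps an array index
for a jax.tree_util.tree_map so the bootstrap value is taken at the last unroll step
whether next_observation is a plain array or a dict pytree. A small sketch with
illustrative shapes:

    import jax
    import jax.numpy as jnp

    # Rollout-shaped observations: the leading axis is the unroll (time) dimension.
    next_obs_array = jnp.ones((5, 8))
    next_obs_dict = {'state': jnp.ones((5, 8))}

    last = lambda obs: jax.tree_util.tree_map(lambda x: x[-1], obs)
    assert last(next_obs_array).shape == (8,)
    assert last(next_obs_dict)['state'].shape == (8,)
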
diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
index a7932fc6..137db298 100644
--- a/brax/training/agents/ppo/train.py
+++ b/brax/training/agents/ppo/train.py
@@ -233,12 +233,16 @@ def train(
   key_envs = jnp.reshape(key_envs,
                          (local_devices_to_use, -1) + key_envs.shape[1:])
   env_state = reset_fn(key_envs)
+  ndarray_obs = isinstance(env_state.obs, jnp.ndarray)  # Check whether observations are in dictionary form.
+
+  # Discard the batch axes over devices and envs.
+  obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs)

   normalize = lambda x, y: x
   if normalize_observations:
     normalize = running_statistics.normalize
   ppo_network = network_factory(
-      env_state.obs.shape[-1],
+      obs_shape,
       env.action_size,
       preprocess_observations_fn=normalize)
   make_policy = ppo_networks.make_inference_fn(ppo_network)
@@ -332,7 +336,7 @@ def f(carry, unused_t):
     # Update normalization params and normalize observations.
     normalizer_params = running_statistics.update(
         training_state.normalizer_params,
-        data.observation,
+        data.observation if ndarray_obs else data.observation['state'],
         pmap_axis_name=_PMAP_AXIS_NAME)

     (optimizer_state, params, _), metrics = jax.lax.scan(
@@ -389,11 +393,12 @@ def training_epoch_with_timing(
       value=ppo_network.value_network.init(key_value),
   )

+  obs_shape = env_state.obs.shape if ndarray_obs else env_state.obs['state'].shape
   training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
       optimizer_state=optimizer.init(init_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
       params=init_params,
       normalizer_params=running_statistics.init_state(
-          specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32'))),
+          specs.Array(obs_shape[-1:], jnp.dtype('float32'))),
       env_steps=0)

   if (
diff --git a/brax/training/agents/ppo/train_test.py b/brax/training/agents/ppo/train_test.py
index 408bd4f4..bd54af01 100644
--- a/brax/training/agents/ppo/train_test.py
+++ b/brax/training/agents/ppo/train_test.py
@@ -27,10 +27,10 @@ class PPOTest(parameterized.TestCase):
   """Tests for PPO module."""
-
-  def testTrain(self):
+  @parameterized.parameters(True, False)
+  def testTrain(self, use_dict_obs):
     """Test PPO with a simple env."""
-    fast = envs.get_environment('fast')
+    fast = envs.get_environment('fast', use_dict_obs=use_dict_obs)
     _, _, metrics = ppo.train(
         fast,
         num_timesteps=2**15,
diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py
index 64053a81..08da8bbf 100644
--- a/brax/training/agents/sac/train.py
+++ b/brax/training/agents/sac/train.py
@@ -191,6 +191,9 @@ def train(
   )

   obs_size = env.observation_size
+  if isinstance(obs_size, Dict):
+    raise NotImplementedError("Dictionary observations not implemented in SAC")
+
   action_size = env.action_size

   normalize_fn = lambda x, y: x
diff --git a/brax/training/networks.py b/brax/training/networks.py
index 23e041aa..10a429c0 100644
--- a/brax/training/networks.py
+++ b/brax/training/networks.py
@@ -15,7 +15,7 @@
 """Network definitions."""

 import dataclasses
-from typing import Any, Callable, Sequence, Tuple
+from typing import Any, Callable, Mapping, Sequence, Tuple
 import warnings

 from brax.training import types
@@ -82,10 +82,13 @@ def __call__(self, data: jnp.ndarray):
         hidden = self.activation(hidden)
     return hidden

+def get_obs_state_size(obs_size: types.ObservationSize) -> int:
+  obs_size = obs_size['state'] if isinstance(obs_size, Mapping) else obs_size
+  return jax.tree_util.tree_flatten(obs_size)[0][-1]  # Size can be tuple or int.

 def make_policy_network(
     param_size: int,
-    obs_size: int,
+    obs_size: types.ObservationSize,
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
@@ -100,30 +103,36 @@ def make_policy_network(
       layer_norm=layer_norm)

   def apply(processor_params, policy_params, obs):
+    obs = (obs if isinstance(obs, jnp.ndarray)
+           else obs['state'])  # state-only in the case of dict obs.
     obs = preprocess_observations_fn(obs, processor_params)
     return policy_module.apply(policy_params, obs)

+  obs_size = get_obs_state_size(obs_size)
   dummy_obs = jnp.zeros((1, obs_size))
   return FeedForwardNetwork(
       init=lambda key: policy_module.init(key, dummy_obs), apply=apply)


 def make_value_network(
-    obs_size: int,
+    obs_size: types.ObservationSize,
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
     activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
-  """Creates a policy network."""
+  """Creates a value network."""
   value_module = MLP(
       layer_sizes=list(hidden_layer_sizes) + [1],
       activation=activation,
       kernel_init=jax.nn.initializers.lecun_uniform())

-  def apply(processor_params, policy_params, obs):
+  def apply(processor_params, value_params, obs):
+    obs = (obs if isinstance(obs, jnp.ndarray)
+           else obs['state'])  # state-only in the case of dict obs.
     obs = preprocess_observations_fn(obs, processor_params)
-    return jnp.squeeze(value_module.apply(policy_params, obs), axis=-1)
+    return jnp.squeeze(value_module.apply(value_params, obs), axis=-1)

+  obs_size = get_obs_state_size(obs_size)
   dummy_obs = jnp.zeros((1, obs_size))
   return FeedForwardNetwork(
       init=lambda key: value_module.init(key, dummy_obs), apply=apply)
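
For reference (not part of the patch): a quick sanity sketch of the get_obs_state_size
helper added above. Only the 'state' entry of a dict observation spec feeds the MLP
input width; the extra 'pixels' key below is purely illustrative.

    from brax.training import networks

    assert networks.get_obs_state_size(8) == 8
    assert networks.get_obs_state_size((8,)) == 8
    assert networks.get_obs_state_size({'state': (8,), 'pixels': (64, 64, 3)}) == 8
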
diff --git a/brax/training/types.py b/brax/training/types.py
index 63fbf129..61839cf6 100644
--- a/brax/training/types.py
+++ b/brax/training/types.py
@@ -14,7 +14,7 @@
 """Brax training types."""

-from typing import Any, Mapping, NamedTuple, Tuple, TypeVar
+from typing import Any, Mapping, NamedTuple, Tuple, TypeVar, Union

 from brax.training.acme.types import NestedArray
 import jax.numpy as jnp
@@ -30,7 +30,8 @@
 Params = Any
 PRNGKey = jnp.ndarray
 Metrics = Mapping[str, jnp.ndarray]
-Observation = jnp.ndarray
+Observation = Union[jnp.ndarray, Mapping[str, jnp.ndarray]]
+ObservationSize = Union[int, Mapping[str, Union[Tuple[int, ...], int]]]
 Action = jnp.ndarray
 Extra = Mapping[str, Any]
 PolicyParams = Any
@@ -79,7 +80,7 @@ class NetworkFactory(Protocol[NetworkType]):

   def __call__(
       self,
-      observation_size: int,
+      observation_size: ObservationSize,
       action_size: int,
       preprocess_observations_fn: PreprocessObservationFn = identity_observation_preprocessor
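
For reference (not part of the patch): an end-to-end sketch in the spirit of the updated
PPO test, training on the dict-observation 'fast' env. The hyperparameter values below
are illustrative only, not the test's settings.

    from brax import envs
    from brax.training.agents.ppo import train as ppo

    env = envs.get_environment('fast', use_dict_obs=True)
    make_policy, params, metrics = ppo.train(
        env,
        num_timesteps=2**15,
        episode_length=128,
        num_envs=64,
        learning_rate=3e-4,
        num_evals=3,
        batch_size=64,
        num_minibatches=8,
        unroll_length=5,
        normalize_observations=True,
    )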