Add dict obs support for PPO (#559)
* add dict obs support for PPO and raise notImplementedError for other training agents

* handle the case where observations are not normalised and the dict-valued obs has no 'state' entry

* clean-up, update observation_size typing, avoid observation_size call

* support the basic dict obs case in networks.py

* add test for dict obs

* fix nits

* fix nits

* assume ndarray obs are vectors
Andrew-Luo1 authored Nov 27, 2024
1 parent 9ede872 commit e615f42
Showing 12 changed files with 59 additions and 26 deletions.
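With this change, PPO accepts environments whose observations are a dict (the policy and value networks consume the 'state' entry). A minimal usage sketch, not part of the diff, using the fast test env and assumed-but-typical hyperparameters:

from brax import envs
from brax.training.agents.ppo import train as ppo

# Dict-valued observations: the fast env returns {'state': obs} when asked.
env = envs.get_environment('fast', use_dict_obs=True)
make_policy, params, metrics = ppo.train(
    env,
    num_timesteps=2**15,
    episode_length=128,
    num_envs=64,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    num_evals=3,
    batch_size=64,
    num_minibatches=8,
    unroll_length=5,
    normalize_observations=True,
)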
16 changes: 10 additions & 6 deletions brax/envs/base.py
@@ -16,7 +16,7 @@
"""A brax environment for training and inference."""

import abc
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

from brax import base
from brax.generalized import pipeline as g_pipeline
@@ -28,13 +28,14 @@
import jax
import numpy as np

ObservationSize = Union[int, Mapping[str, Union[Tuple[int, ...], int]]]

@struct.dataclass
class State(base.Base):
"""Environment state for training and inference."""

pipeline_state: Optional[base.State]
obs: jax.Array
obs: Union[jax.Array, Mapping[str, jax.Array]]
reward: jax.Array
done: jax.Array
metrics: Dict[str, jax.Array] = struct.field(default_factory=dict)
@@ -54,7 +55,7 @@ def step(self, state: State, action: jax.Array) -> State:

@property
@abc.abstractmethod
def observation_size(self) -> int:
def observation_size(self) -> ObservationSize:
"""The size of the observation vector returned in step and reset."""

@property
@@ -139,10 +140,13 @@ def dt(self) -> jax.Array:
return self.sys.opt.timestep * self._n_frames # pytype: disable=attribute-error

@property
def observation_size(self) -> int:
def observation_size(self) -> ObservationSize:
rng = jax.random.PRNGKey(0)
reset_state = self.unwrapped.reset(rng)
return reset_state.obs.shape[-1]
obs = reset_state.obs
if isinstance(obs, jax.Array):
return obs.shape[-1]
return jax.tree_util.tree_map(lambda x: x.shape, obs)

@property
def action_size(self) -> int:
@@ -176,7 +180,7 @@ def step(self, state: State, action: jax.Array) -> State:
return self.env.step(state, action)

@property
def observation_size(self) -> int:
def observation_size(self) -> ObservationSize:
return self.env.observation_size

@property
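Illustrative behaviour of the updated observation_size property (assuming the fast env's observation is a length-2 vector, as in brax/envs/fast.py below): array observations still report an int, while dict observations report a mapping of per-key shapes.

env = envs.get_environment('fast')
env.observation_size                       # 2
env = envs.get_environment('fast', use_dict_obs=True)
env.observation_size                       # {'state': (2,)}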
3 changes: 3 additions & 0 deletions brax/envs/fast.py
@@ -28,6 +28,7 @@ def __init__(self, **kwargs):
self._dt = 0.02
self._reset_count = 0
self._step_count = 0
self._use_dict_obs = kwargs.get('use_dict_obs', False)

def reset(self, rng: jax.Array) -> State:
self._reset_count += 1
@@ -39,6 +40,7 @@ def reset(self, rng: jax.Array) -> State:
contact=None
)
obs = jp.zeros(2)
obs = {'state': obs} if self._use_dict_obs else obs
reward, done = jp.array(0.0), jp.array(0.0)
return State(pipeline_state, obs, reward, done)

@@ -53,6 +55,7 @@ def step(self, state: State, action: jax.Array) -> State:
xd=state.pipeline_state.xd.replace(vel=vel),
)
obs = jp.array([pos[0], vel[0]])
obs = {'state': obs} if self._use_dict_obs else obs
reward = pos[0]

return state.replace(pipeline_state=qp, obs=obs, reward=reward)
2 changes: 1 addition & 1 deletion brax/envs/wrappers/training.py
@@ -130,7 +130,7 @@ def where_done(x, y):
pipeline_state = jax.tree.map(
where_done, state.info['first_pipeline_state'], state.pipeline_state
)
obs = where_done(state.info['first_obs'], state.obs)
obs = jax.tree.map(where_done, state.info['first_obs'], state.obs)
return state.replace(pipeline_state=pipeline_state, obs=obs)


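A note on the wrapper change above: where_done operates leaf-wise, so mapping it over the observation pytree covers both layouts. A tiny sketch with hypothetical leaves:

# Array obs: a single leaf, identical to the old behaviour.
obs = jax.tree.map(where_done, first_obs, current_obs)
# Dict obs: applied independently to each entry.
obs = jax.tree.map(where_done, {'state': first_state}, {'state': current_state})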
6 changes: 5 additions & 1 deletion brax/training/agents/apg/train.py
@@ -131,11 +131,15 @@ def train(
reset_fn = jax.jit(jax.vmap(env.reset))
step_fn = jax.jit(jax.vmap(env.step))

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in APG")

normalize = lambda x, y: x
if normalize_observations:
normalize = running_statistics.normalize
apg_network = network_factory(
env.observation_size,
obs_size,
env.action_size,
preprocess_observations_fn=normalize)
make_policy = apg_networks.make_inference_fn(apg_network)
2 changes: 2 additions & 0 deletions brax/training/agents/ars/train.py
@@ -119,6 +119,8 @@ def train(
)

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in ARS")

normalize_fn = lambda x, y: x
if normalize_observations:
4 changes: 3 additions & 1 deletion brax/training/agents/es/train.py
@@ -146,7 +146,9 @@ def train(
)

obs_size = env.observation_size

if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in ES")

normalize_fn = lambda x, y: x
if normalize_observations:
normalize_fn = running_statistics.normalize
4 changes: 2 additions & 2 deletions brax/training/agents/ppo/losses.py
@@ -135,9 +135,9 @@ def compute_ppo_loss(
data.observation)

baseline = value_apply(normalizer_params, params.value, data.observation)

terminal_obs = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation)
bootstrap_value = value_apply(normalizer_params, params.value,
data.next_observation[-1])
terminal_obs)

rewards = data.reward * reward_scaling
truncation = data.extras['state_extras']['truncation']
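Why the bootstrap value needs tree_map: with dict observations, data.next_observation[-1] would try to index the dict with -1 rather than slice the time axis, so the terminal observation is taken leaf-wise instead. Assuming leaves shaped (unroll_length, batch_size, ...):

terminal_obs = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation)
# array obs: terminal_obs.shape == (batch_size, obs_dim)
# dict obs:  terminal_obs['state'].shape == (batch_size, obs_dim)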
11 changes: 8 additions & 3 deletions brax/training/agents/ppo/train.py
@@ -233,12 +233,16 @@ def train(
key_envs = jnp.reshape(key_envs,
(local_devices_to_use, -1) + key_envs.shape[1:])
env_state = reset_fn(key_envs)
ndarray_obs = isinstance(env_state.obs, jnp.ndarray)  # True for array obs, False for dict-valued obs.

# Discard the batch axes over devices and envs.
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs)

normalize = lambda x, y: x
if normalize_observations:
normalize = running_statistics.normalize
ppo_network = network_factory(
env_state.obs.shape[-1],
obs_shape,
env.action_size,
preprocess_observations_fn=normalize)
make_policy = ppo_networks.make_inference_fn(ppo_network)
@@ -332,7 +336,7 @@ def f(carry, unused_t):
# Update normalization params and normalize observations.
normalizer_params = running_statistics.update(
training_state.normalizer_params,
data.observation,
data.observation if ndarray_obs else data.observation['state'],
pmap_axis_name=_PMAP_AXIS_NAME)

(optimizer_state, params, _), metrics = jax.lax.scan(
@@ -389,11 +393,12 @@ def training_epoch_with_timing(
value=ppo_network.value_network.init(key_value),
)

obs_shape = env_state.obs.shape if ndarray_obs else env_state.obs['state'].shape
training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray
optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars
params=init_params,
normalizer_params=running_statistics.init_state(
specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32'))),
specs.Array(obs_shape[-1:], jnp.dtype('float32'))),
env_steps=0)

if (
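A condensed sketch of the two observation layouts the PPO training loop now handles; with dict observations, only the 'state' entry feeds the running-statistics normalizer:

obs = env_state.obs                                   # jnp.ndarray or {'state': ..., ...}
ndarray_obs = isinstance(obs, jnp.ndarray)
norm_obs = obs if ndarray_obs else obs['state']       # input to running_statistics
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], obs)  # drop device/env axes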
6 changes: 3 additions & 3 deletions brax/training/agents/ppo/train_test.py
@@ -27,10 +27,10 @@
class PPOTest(parameterized.TestCase):
"""Tests for PPO module."""


def testTrain(self):
@parameterized.parameters(True, False)
def testTrain(self, use_dict_obs):
"""Test PPO with a simple env."""
fast = envs.get_environment('fast')
fast = envs.get_environment('fast', use_dict_obs=use_dict_obs)
_, _, metrics = ppo.train(
fast,
num_timesteps=2**15,
3 changes: 3 additions & 0 deletions brax/training/agents/sac/train.py
@@ -191,6 +191,9 @@ def train(
)

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in SAC")

action_size = env.action_size

normalize_fn = lambda x, y: x
21 changes: 15 additions & 6 deletions brax/training/networks.py
@@ -15,7 +15,7 @@
"""Network definitions."""

import dataclasses
from typing import Any, Callable, Sequence, Tuple
from typing import Any, Callable, Mapping, Sequence, Tuple
import warnings

from brax.training import types
@@ -82,10 +82,13 @@ def __call__(self, data: jnp.ndarray):
hidden = self.activation(hidden)
return hidden

def get_obs_state_size(obs_size: types.ObservationSize) -> int:
obs_size = obs_size['state'] if isinstance(obs_size, Mapping) else obs_size
return jax.tree_util.tree_flatten(obs_size)[0][-1] # Size can be tuple or int.

def make_policy_network(
param_size: int,
obs_size: int,
obs_size: types.ObservationSize,
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
@@ -100,30 +103,36 @@ def make_policy_network(
layer_norm=layer_norm)

def apply(processor_params, policy_params, obs):
obs = (obs if isinstance(obs, jnp.ndarray)
else obs['state']) # state-only in the case of dict obs.
obs = preprocess_observations_fn(obs, processor_params)
return policy_module.apply(policy_params, obs)

obs_size = get_obs_state_size(obs_size)
dummy_obs = jnp.zeros((1, obs_size))
return FeedForwardNetwork(
init=lambda key: policy_module.init(key, dummy_obs), apply=apply)


def make_value_network(
obs_size: int,
obs_size: types.ObservationSize,
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
"""Creates a policy network."""
"""Creates a value network."""
value_module = MLP(
layer_sizes=list(hidden_layer_sizes) + [1],
activation=activation,
kernel_init=jax.nn.initializers.lecun_uniform())

def apply(processor_params, policy_params, obs):
def apply(processor_params, value_params, obs):
obs = (obs if isinstance(obs, jnp.ndarray)
else obs['state']) # state-only in the case of dict obs.
obs = preprocess_observations_fn(obs, processor_params)
return jnp.squeeze(value_module.apply(policy_params, obs), axis=-1)
return jnp.squeeze(value_module.apply(value_params, obs), axis=-1)

obs_size = get_obs_state_size(obs_size)
dummy_obs = jnp.zeros((1, obs_size))
return FeedForwardNetwork(
init=lambda key: value_module.init(key, dummy_obs), apply=apply)
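Illustrative behaviour of get_obs_state_size, which reduces either layout to the flat size used for the dummy network input (the dict keys and shapes below are hypothetical):

get_obs_state_size(8)                                     # 8
get_obs_state_size({'state': (8,)})                       # 8
get_obs_state_size({'state': 8, 'pixels': (64, 64, 3)})   # 8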
7 changes: 4 additions & 3 deletions brax/training/types.py
@@ -14,7 +14,7 @@

"""Brax training types."""

from typing import Any, Mapping, NamedTuple, Tuple, TypeVar
from typing import Any, Mapping, NamedTuple, Tuple, TypeVar, Union

from brax.training.acme.types import NestedArray
import jax.numpy as jnp
@@ -30,7 +30,8 @@
Params = Any
PRNGKey = jnp.ndarray
Metrics = Mapping[str, jnp.ndarray]
Observation = jnp.ndarray
Observation = Union[jnp.ndarray, Mapping[str, jnp.ndarray]]
ObservationSize = Union[int, Mapping[str, Union[Tuple[int, ...], int]]]
Action = jnp.ndarray
Extra = Mapping[str, Any]
PolicyParams = Any
@@ -79,7 +80,7 @@ class NetworkFactory(Protocol[NetworkType]):

def __call__(
self,
observation_size: int,
observation_size: ObservationSize,
action_size: int,
preprocess_observations_fn:
PreprocessObservationFn = identity_observation_preprocessor
