
Make necessary changes to support asymmetric actor critic. #562

Merged (9 commits, Dec 2, 2024)
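This PR lets PPO train an asymmetric actor critic: the environment may return a dict observation whose 'state' entry feeds the policy and whose 'privileged_state' entry feeds the value network, selected through the new policy_obs_key / value_obs_key arguments. A minimal end-to-end sketch, mirroring the test added in this PR (hyperparameters are illustrative, not prescriptive):

import functools

from brax import envs
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo

# The 'fast' debug env gains asymmetric_obs/use_dict_obs flags in this PR.
env = envs.get_environment('fast', asymmetric_obs=True, use_dict_obs=True)

# Actor reads obs['state']; critic reads obs['privileged_state'].
network_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_obs_key='state',
    value_obs_key='privileged_state',
)

make_policy, params, metrics = ppo.train(
    env,
    num_timesteps=2**15,
    episode_length=1000,
    num_envs=64,
    unroll_length=5,
    batch_size=64,
    num_minibatches=8,
    network_factory=network_factory,
    seed=0,
)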
27 changes: 22 additions & 5 deletions brax/envs/fast.py
@@ -29,8 +29,21 @@ def __init__(self, **kwargs):
self._reset_count = 0
self._step_count = 0
self._use_dict_obs = kwargs.get('use_dict_obs', False)
self._asymmetric_obs = kwargs.get('asymmetric_obs', False)
if self._asymmetric_obs and not self._use_dict_obs:
raise ValueError('asymmetric_obs requires use_dict_obs=True')

def _get_obs(self):
if not self._use_dict_obs:
return jp.zeros(2)

obs = {'state': jp.zeros(2)}
if self._asymmetric_obs:
obs['privileged_state'] = jp.zeros(4)
return obs

def reset(self, rng: jax.Array) -> State:
del rng # Unused.
self._reset_count += 1
pipeline_state = base.State(
q=jp.zeros(1),
@@ -39,8 +52,7 @@ def reset(self, rng: jax.Array) -> State:
xd=base.Motion.create(vel=jp.zeros(3)),
contact=None
)
obs = jp.zeros(2)
obs = {'state': obs} if self._use_dict_obs else obs
obs = self._get_obs()
reward, done = jp.array(0.0), jp.array(0.0)
return State(pipeline_state, obs, reward, done)

@@ -54,8 +66,7 @@ def step(self, state: State, action: jax.Array) -> State:
x=state.pipeline_state.x.replace(pos=pos),
xd=state.pipeline_state.xd.replace(vel=vel),
)
obs = jp.array([pos[0], vel[0]])
obs = {'state': obs} if self._use_dict_obs else obs
obs = self._get_obs()
reward = pos[0]

return state.replace(pipeline_state=qp, obs=obs, reward=reward)
@@ -70,7 +81,13 @@ def step_count(self):

@property
def observation_size(self):
return 2
if not self._use_dict_obs:
return 2

obs = {'state': 2}
if self._asymmetric_obs:
obs['privileged_state'] = 4
return obs

@property
def action_size(self):
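A quick sketch of what the modified fast env reports with both flags set, using the shapes from the diff above:

import jax
from brax import envs

env = envs.get_environment('fast', use_dict_obs=True, asymmetric_obs=True)
print(env.observation_size)  # {'state': 2, 'privileged_state': 4}

state = env.reset(jax.random.PRNGKey(0))
print(state.obs['state'].shape)             # (2,)
print(state.obs['privileged_state'].shape)  # (4,)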
11 changes: 7 additions & 4 deletions brax/training/agents/ppo/networks.py
@@ -68,7 +68,9 @@ def make_ppo_networks(
.identity_observation_preprocessor,
policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
activation: networks.ActivationFn = linen.swish) -> PPONetworks:
activation: networks.ActivationFn = linen.swish,
policy_obs_key: str = 'state',
value_obs_key: str = 'state') -> PPONetworks:
"""Make PPO networks with preprocessor."""
parametric_action_distribution = distribution.NormalTanhDistribution(
event_size=action_size)
@@ -77,13 +79,14 @@ def make_ppo_networks(
observation_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=policy_hidden_layer_sizes,
activation=activation)
activation=activation,
obs_key=policy_obs_key)
value_network = networks.make_value_network(
observation_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=value_hidden_layer_sizes,
activation=activation)

activation=activation,
obs_key=value_obs_key)
return PPONetworks(
policy_network=policy_network,
value_network=value_network,
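A hedged sketch of calling the factory directly with a dict observation size; the key names match the new defaults and the test added below:

from brax.training.agents.ppo import networks as ppo_networks

ppo_network = ppo_networks.make_ppo_networks(
    observation_size={'state': 2, 'privileged_state': 4},
    action_size=1,
    policy_obs_key='state',            # actor input: 2 features
    value_obs_key='privileged_state',  # critic input: 4 features
)
# ppo_network.policy_network and ppo_network.value_network are built with
# input widths 2 and 4 respectively.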
10 changes: 5 additions & 5 deletions brax/training/agents/ppo/train.py
@@ -233,7 +233,6 @@ def train(
key_envs = jnp.reshape(key_envs,
(local_devices_to_use, -1) + key_envs.shape[1:])
env_state = reset_fn(key_envs)
ndarray_obs = isinstance(env_state.obs, jnp.ndarray) # Check whether observations are in dictionary form.

# Discard the batch axes over devices and envs.
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs)
@@ -336,7 +335,7 @@ def f(carry, unused_t):
# Update normalization params and normalize observations.
normalizer_params = running_statistics.update(
training_state.normalizer_params,
data.observation if ndarray_obs else data.observation['state'],
data.observation,
pmap_axis_name=_PMAP_AXIS_NAME)

(optimizer_state, params, _), metrics = jax.lax.scan(
@@ -393,12 +392,13 @@ def training_epoch_with_timing(
value=ppo_network.value_network.init(key_value),
)

obs_shape = env_state.obs.shape if ndarray_obs else env_state.obs['state'].shape
obs_shape = jax.tree_util.tree_map(
lambda x: specs.Array(x.shape[-1:], jnp.dtype('float32')), env_state.obs
)
training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray
optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars
params=init_params,
normalizer_params=running_statistics.init_state(
specs.Array(obs_shape[-1:], jnp.dtype('float32'))),
normalizer_params=running_statistics.init_state(obs_shape),
env_steps=0)

if (
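The train.py change above builds the normalizer spec from the whole observation pytree rather than from a single array. A standalone sketch of the same pattern, assuming running_statistics handles nested structures as the diff implies:

import jax
import jax.numpy as jnp
from brax.training.acme import running_statistics
from brax.training.acme import specs

obs = {'state': jnp.zeros((64, 2)), 'privileged_state': jnp.zeros((64, 4))}

# One float32 spec per leaf, dropping the leading batch axis, as in the diff.
obs_spec = jax.tree_util.tree_map(
    lambda x: specs.Array(x.shape[-1:], jnp.dtype('float32')), obs)

norm_params = running_statistics.init_state(obs_spec)
norm_params = running_statistics.update(norm_params, obs)
normalized_obs = running_statistics.normalize(obs, norm_params)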
42 changes: 42 additions & 0 deletions brax/training/agents/ppo/train_test.py
@@ -15,13 +15,15 @@
"""PPO tests."""
import pickle

import functools
from absl.testing import absltest
from absl.testing import parameterized
from brax import envs
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
import jax
import jax.numpy as jp


class PPOTest(parameterized.TestCase):
@@ -131,6 +133,46 @@ def get_offset(rng):
randomization_fn=rand_fn,
)

def testTrainAsymmetricActorCritic(self):
"""Test PPO with asymmetric actor critic."""
env = envs.get_environment('fast', asymmetric_obs=True, use_dict_obs=True)

network_factory = functools.partial(
ppo_networks.make_ppo_networks,
policy_hidden_layer_sizes=(32,),
value_hidden_layer_sizes=(32,),
policy_obs_key='state',
value_obs_key='privileged_state'
)

_, (_, policy_params, value_params), _ = ppo.train(
env,
num_timesteps=2**15,
episode_length=1000,
num_envs=64,
learning_rate=3e-4,
entropy_cost=1e-2,
discounting=0.95,
unroll_length=5,
batch_size=64,
num_minibatches=8,
num_updates_per_batch=4,
normalize_observations=False,
seed=2,
reward_scaling=10,
normalize_advantage=False,
network_factory=network_factory,
)

self.assertEqual(
policy_params['params']['hidden_0']['kernel'].shape,
(env.observation_size['state'], 32),
)
self.assertEqual(
value_params['params']['hidden_0']['kernel'].shape,
(env.observation_size['privileged_state'], 32),
)


if __name__ == '__main__':
absltest.main()
20 changes: 10 additions & 10 deletions brax/training/networks.py
@@ -82,8 +82,8 @@ def __call__(self, data: jnp.ndarray):
hidden = self.activation(hidden)
return hidden

def get_obs_state_size(obs_size: types.ObservationSize) -> int:
obs_size = obs_size['state'] if isinstance(obs_size, Mapping) else obs_size
def get_obs_state_size(obs_size: types.ObservationSize, obs_key: str) -> int:
obs_size = obs_size[obs_key] if isinstance(obs_size, Mapping) else obs_size
return jax.tree_util.tree_flatten(obs_size)[0][-1] # Size can be tuple or int.

def make_policy_network(
@@ -94,7 +94,8 @@ def make_policy_network(
hidden_layer_sizes: Sequence[int] = (256, 256),
activation: ActivationFn = linen.relu,
kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
layer_norm: bool = False) -> FeedForwardNetwork:
layer_norm: bool = False,
obs_key: str = 'state') -> FeedForwardNetwork:
"""Creates a policy network."""
policy_module = MLP(
layer_sizes=list(hidden_layer_sizes) + [param_size],
Expand All @@ -103,12 +104,11 @@ def make_policy_network(
layer_norm=layer_norm)

def apply(processor_params, policy_params, obs):
obs = (obs if isinstance(obs, jnp.ndarray)
else obs['state']) # state-only in the case of dict obs.
obs = preprocess_observations_fn(obs, processor_params)
obs = obs if isinstance(obs, jnp.ndarray) else obs[obs_key]
return policy_module.apply(policy_params, obs)

obs_size = get_obs_state_size(obs_size)
obs_size = get_obs_state_size(obs_size, obs_key)
dummy_obs = jnp.zeros((1, obs_size))
return FeedForwardNetwork(
init=lambda key: policy_module.init(key, dummy_obs), apply=apply)
@@ -119,20 +119,20 @@ def make_value_network(
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
activation: ActivationFn = linen.relu,
obs_key: str = 'state') -> FeedForwardNetwork:
"""Creates a value network."""
value_module = MLP(
layer_sizes=list(hidden_layer_sizes) + [1],
activation=activation,
kernel_init=jax.nn.initializers.lecun_uniform())

def apply(processor_params, value_params, obs):
obs = (obs if isinstance(obs, jnp.ndarray)
else obs['state']) # state-only in the case of dict obs.
obs = preprocess_observations_fn(obs, processor_params)
obs = obs if isinstance(obs, jnp.ndarray) else obs[obs_key]
return jnp.squeeze(value_module.apply(value_params, obs), axis=-1)

obs_size = get_obs_state_size(obs_size)
obs_size = get_obs_state_size(obs_size, obs_key)
dummy_obs = jnp.zeros((1, obs_size))
return FeedForwardNetwork(
init=lambda key: value_module.init(key, dummy_obs), apply=apply)
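A small sketch of the generalized helpers, using the names from the diff (the dict values are per-key feature sizes; treat the exact call forms as an assumption):

import jax
from brax.training import networks

# get_obs_state_size now takes the key to look up when the size is a dict.
assert networks.get_obs_state_size({'state': 2, 'privileged_state': 4}, 'privileged_state') == 4
assert networks.get_obs_state_size(2, 'state') == 2  # plain int sizes pass through

# A value network whose first layer is sized for the privileged observation.
value_net = networks.make_value_network(
    {'state': 2, 'privileged_state': 4},
    obs_key='privileged_state',
)
value_params = value_net.init(jax.random.PRNGKey(0))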