Internal change
PiperOrigin-RevId: 721892762
Change-Id: I51595f8e285d9f4614f6d7251597941441cbcc5d
Brax Team authored and btaba committed Jan 31, 2025
1 parent 02de34b commit 8526f9a
Showing 6 changed files with 300 additions and 36 deletions.
14 changes: 7 additions & 7 deletions brax/envs/inverted_double_pendulum.py
@@ -28,7 +28,7 @@ class InvertedDoublePendulum(PipelineEnv):


# pyformat: disable
"""### Description
r"""### Description
This environment originates from control theory and builds on the cartpole
environment based on the work done by Barto, Sutton, and Anderson in
@@ -117,10 +117,10 @@ class InvertedDoublePendulum(PipelineEnv):
### Episode Termination
The episode terminates when the y_coordinate of the tip of the second
pole $\leq 1$.
Note: The maximum standing height of the system is 1.2 m when all the parts
are perpendicularly vertical on top of each other.
"""
# pyformat: enable
@@ -144,7 +144,7 @@ def __init__(self, backend='generalized', **kwargs):

def reset(self, rng: jax.Array) -> State:
"""Resets the environment to an initial state."""
- rng, rng1, rng2 = jax.random.split(rng, 3)
+ _, rng1, rng2 = jax.random.split(rng, 3)

q = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=-0.01, maxval=0.01
@@ -163,11 +163,11 @@ def step(self, state: State, action: jax.Array) -> State:
pipeline_state = self.pipeline_step(state.pipeline_state, action)

tip = pipeline_state.x.take(2).do(
base.Transform.create(pos=jp.array([0.0, 0.0, 0.6]))
)
x, _, y = tip.pos
v1, v2 = pipeline_state.qd[1:]

dist_penalty = 0.01 * x**2 + (y - 2) ** 2
vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2
alive_bonus = 10
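The tail of step() is collapsed in this diff, but as a rough sketch (an assumed continuation, not part of the change) the quantities in this hunk typically combine into the reward and the termination rule described in the docstring above:

# Assumed continuation of step(); names past the visible hunk are guesses.
reward = alive_bonus - dist_penalty - vel_penalty
# Per the docstring: terminate once the tip's y-coordinate is <= 1.
done = jp.where(y <= 1.0, jp.float32(1), jp.float32(0))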
148 changes: 148 additions & 0 deletions brax/training/agents/ppo/checkpoint.py
@@ -0,0 +1,148 @@
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Checkpointing for PPO."""

import inspect
import json
import logging
from typing import Any, Dict, Tuple, Union

from brax.training import types
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from etils import epath
from flax import linen
from flax.training import orbax_utils
from ml_collections import config_dict
from orbax import checkpoint as ocp

_CONFIG_FNAME = 'config.json'


def _get_default_kwargs(func: Any) -> Dict[str, Any]:
"""Returns the default kwargs of a function."""
return {
p.name: p.default
for p in inspect.signature(func).parameters.values()
if p.default is not inspect.Parameter.empty
}


def ppo_config(
observation_size: types.ObservationSize,
action_size: int,
normalize_observations: bool,
network_factory: types.NetworkFactory[ppo_networks.PPONetworks],
) -> config_dict.ConfigDict:
"""Returns a config dict for re-creating PPO params from a checkpoint."""
config = config_dict.ConfigDict()
kwargs = _get_default_kwargs(network_factory)

if (
kwargs.get('preprocess_observations_fn')
!= types.identity_observation_preprocessor
):
raise ValueError(
'preprocess_observations_fn must be identity_observation_preprocessor'
)
del kwargs['preprocess_observations_fn']
if kwargs.get('activation') != linen.swish:
raise ValueError('activation must be swish')
del kwargs['activation']

config.network_factory_kwargs = kwargs
config.normalize_observations = normalize_observations
config.observation_size = observation_size
config.action_size = action_size
return config


def save(
path: Union[str, epath.Path],
step: int,
params: Tuple[Any, ...],
config: config_dict.ConfigDict,
):
"""Saves a checkpoint."""
ckpt_path = epath.Path(path) / f'{step:012d}'
logging.info('saving checkpoint to %s', ckpt_path.as_posix())

if not ckpt_path.exists():
ckpt_path.mkdir(parents=True)

config_path = epath.Path(path) / _CONFIG_FNAME
if not config_path.exists():
config_path.write_text(config.to_json())

orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(params)
orbax_checkpointer.save(ckpt_path, params, force=True, save_args=save_args)


def load(
path: Union[str, epath.Path],
):
"""Loads PPO checkpoint."""
path = epath.Path(path)
if not path.exists():
raise ValueError(f'PPO checkpoint path does not exist: {path.as_posix()}')

logging.info('restoring from checkpoint %s', path.as_posix())

orbax_checkpointer = ocp.PyTreeCheckpointer()
target = orbax_checkpointer.restore(path, item=None)
target[0] = running_statistics.RunningStatisticsState(**target[0])

return target


def _get_network(
config: config_dict.ConfigDict,
network_factory: types.NetworkFactory[ppo_networks.PPONetworks],
) -> ppo_networks.PPONetworks:
"""Generates a PPO network given config."""
normalize = lambda x, y: x
if config.normalize_observations:
normalize = running_statistics.normalize
ppo_network = network_factory(
config.to_dict()['observation_size'],
config.action_size,
preprocess_observations_fn=normalize,
**config.network_factory_kwargs,
)
return ppo_network


def load_policy(
path: Union[str, epath.Path],
network_factory: types.NetworkFactory[
ppo_networks.PPONetworks
] = ppo_networks.make_ppo_networks,
deterministic: bool = True,
):
"""Loads policy inference function from PPO checkpoint."""
path = epath.Path(path)

config_path = path.parent / _CONFIG_FNAME
if not config_path.exists():
raise ValueError(f'PPO config file not found at {config_path.as_posix()}')

config = config_dict.create(**json.loads(config_path.read_text()))

params = load(path)
ppo_network = _get_network(config, network_factory)
make_inference_fn = ppo_networks.make_inference_fn(ppo_network)

return make_inference_fn(params, deterministic=deterministic)
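A hedged usage sketch of the new module; obs_size, act_size, params (the (normalizer, policy, value) tuple produced by training), obs, and the /tmp path are placeholders rather than anything defined in this change:

import jax
from brax.training.agents.ppo import checkpoint
from brax.training.agents.ppo import networks as ppo_networks

# Build the serializable config once, using the same factory as training.
config = checkpoint.ppo_config(
    observation_size=obs_size,
    action_size=act_size,
    normalize_observations=True,
    network_factory=ppo_networks.make_ppo_networks,
)
# Writes config.json under /tmp/ppo_ckpt and the params under /tmp/ppo_ckpt/000000001000.
checkpoint.save('/tmp/ppo_ckpt', step=1000, params=params, config=config)

# Later: rebuild the inference function directly from the step directory.
policy_fn = checkpoint.load_policy('/tmp/ppo_ckpt/000000001000')
action, _ = policy_fn(obs, jax.random.PRNGKey(0))

The checkpoint_test.py below exercises the same save/load_policy round trip with dummy params.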
98 changes: 98 additions & 0 deletions brax/training/agents/ppo/checkpoint_test.py
@@ -0,0 +1,98 @@
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test PPO checkpointing."""

import functools

from absl.testing import absltest
from brax.training.acme import running_statistics
from brax.training.agents.ppo import checkpoint
from brax.training.agents.ppo import losses as ppo_losses
from brax.training.agents.ppo import networks as ppo_networks
from etils import epath
import jax
from jax import numpy as jp


class CheckpointTest(absltest.TestCase):

def test_ppo_params_config(self):
network_factory = functools.partial(
ppo_networks.make_ppo_networks,
policy_hidden_layer_sizes=(16, 21, 13),
)
config = checkpoint.ppo_config(
action_size=3,
observation_size=1,
normalize_observations=True,
network_factory=network_factory,
)
self.assertEqual(
config.network_factory_kwargs.to_dict()["policy_hidden_layer_sizes"],
(16, 21, 13),
)
self.assertEqual(config.action_size, 3)
self.assertEqual(config.observation_size, 1)

def test_save_and_load_checkpoint(self):
path = self.create_tempdir("test")
network_factory = functools.partial(
ppo_networks.make_ppo_networks,
policy_hidden_layer_sizes=(16, 21, 13),
)
config = checkpoint.ppo_config(
observation_size=1,
action_size=3,
normalize_observations=True,
network_factory=network_factory,
)

# Generate network params for saving a dummy checkpoint.
normalize = lambda x, y: x
if config.normalize_observations:
normalize = running_statistics.normalize
ppo_network = network_factory(
config.observation_size,
config.action_size,
preprocess_observations_fn=normalize,
**config.network_factory_kwargs,
)
dummy_key = jax.random.PRNGKey(0)
network_params = ppo_losses.PPONetworkParams(
policy=ppo_network.policy_network.init(dummy_key),
value=ppo_network.value_network.init(dummy_key),
)
normalizer_params = running_statistics.init_state(
jax.tree_util.tree_map(jp.zeros, config.observation_size)
)
params = (normalizer_params, network_params.policy, network_params.value)

# Save and load a checkpoint.
checkpoint.save(
path.full_path,
step=1,
params=params,
config=config,
)

policy_fn = checkpoint.load_policy(
epath.Path(path.full_path) / "000000000001",
)
out = policy_fn(jp.zeros(1), jax.random.PRNGKey(0))
self.assertEqual(out[0].shape, (3,))


if __name__ == "__main__":
absltest.main()
2 changes: 1 addition & 1 deletion brax/training/agents/ppo/networks.py
@@ -65,7 +65,7 @@ def policy(


def make_ppo_networks(
- observation_size: int,
+ observation_size: types.ObservationSize,
action_size: int,
preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
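The widened annotation is what lets the checkpoint config hand observation_size back to the factory unchanged. Assuming types.ObservationSize admits either a plain int or a mapping of named observation shapes (how it appears to be used elsewhere in brax.training, not something this diff states), both of the following would be acceptable factory arguments:

# Assumed examples only; sizes and key names are illustrative.
obs_size_flat = 11
obs_size_dict = {'state': 11, 'privileged_state': 23}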