From f0cf1092201f66724bf4d38f3b222beaa5e208f3 Mon Sep 17 00:00:00 2001 From: Brax Team Date: Tue, 12 Dec 2023 12:00:59 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 590287478 Change-Id: Iab96c438d2dda8fcb575e6ef03091bdcdb00c559 --- README.md | 13 ++- brax/base.py | 18 +++- brax/envs/ant.py | 13 ++- brax/envs/assets/humanoid.xml | 2 +- brax/envs/base.py | 3 +- brax/envs/humanoid.py | 13 ++- brax/io/json.py | 4 +- brax/io/mjcf.py | 1 + brax/mjx/__init__.py | 14 +++ brax/mjx/base.py | 30 +++++++ brax/mjx/perf_test.py | 43 +++++++++ brax/mjx/pipeline.py | 89 +++++++++++++++++++ brax/mjx/pipeline_test.py | 48 ++++++++++ brax/training/agents/ppo/train.py | 54 ++++++++++- .../composer/components/common.py | 6 +- .../experimental/composer/reward_functions.py | 8 +- 16 files changed, 330 insertions(+), 29 deletions(-) create mode 100644 brax/mjx/__init__.py create mode 100644 brax/mjx/base.py create mode 100644 brax/mjx/perf_test.py create mode 100644 brax/mjx/pipeline.py create mode 100644 brax/mjx/pipeline_test.py diff --git a/README.md b/README.md index 567f6e58..686b91a5 100644 --- a/README.md +++ b/README.md @@ -21,16 +21,15 @@ to minutes: [evolutionary strategies](https://github.com/google/brax/blob/main/brax/training/agents/es). * Learning algorithms that leverage the differentiability of the simulator, such as [analytic policy gradients](https://github.com/google/brax/blob/main/brax/training/agents/apg). -## One API, Three Pipelines +## One API, Four Pipelines -Brax offers three distinct physics pipelines that are easy to swap: +Brax offers four distinct physics pipelines that are easy to swap: +* [MuJoCo XLA - MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) - a JAX +reimplementation of the MuJoCo physics engine. 
* [Generalized](https://github.com/google/brax/blob/main/brax/v2/generalized/) -calculates motion in [generalized coordinates](https://en.wikipedia.org/wiki/Generalized_coordinates) using the same accurate robot -dynamics algorithms as [MuJoCo](https://mujoco.org/) and [TDS](https://github.com/erwincoumans/tiny-differentiable-simulator). - - NOTE: We plan to import [MuJoCo XLA - MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) as another physics backend, eventually replacing `generalized`. Check out our recent [announcement](https://github.com/google/brax/discussions/409). - +calculates motion in [generalized coordinates](https://en.wikipedia.org/wiki/Generalized_coordinates) +using dynamics algorithms similar to [MuJoCo](https://mujoco.org/) and [TDS](https://github.com/erwincoumans/tiny-differentiable-simulator). * [Positional](https://github.com/google/brax/blob/main/brax/v2/positional/) uses [Position Based Dynamics](https://matthias-research.github.io/pages/publications/posBasedDyn.pdf), a fast but stable method of resolving joint and collision constraints. 
diff --git a/brax/base.py b/brax/base.py index b113a327..6cdcba6c 100644 --- a/brax/base.py +++ b/brax/base.py @@ -17,7 +17,7 @@ import copy import functools -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from brax import math from flax import struct @@ -25,7 +25,8 @@ from jax import numpy as jp from jax import vmap from jax.tree_util import tree_map - +import mujoco +from mujoco import mjx # f: free, 1: 1-dof, 2: 2-dof, 3: 3-dof Q_WIDTHS = {'f': 7, '1': 1, '2': 2, '3': 3} @@ -617,6 +618,7 @@ class System(Base): matrix_inv_iterations: int = struct.field(pytree_node=False) solver_iterations: int = struct.field(pytree_node=False) solver_maxls: int = struct.field(pytree_node=False) + _model: mujoco.MjModel = struct.field(pytree_node=False, default=None) def num_links(self) -> int: """Returns the number of links in the system.""" @@ -678,6 +680,18 @@ def act_size(self) -> int: """Returns the act dimension for the system.""" return self.actuator.q_id.shape[0] + def set_model(self, model: mujoco.MjModel): + """Sets the source MuJoCo model of this System.""" + object.__setattr__(self, '_model', model) + + def get_model(self) -> mujoco.MjModel: + """Returns the source MuJoCo model of this System.""" + return self._model + + def get_mjx_model(self) -> mjx.Model: + """Returns an MJX model of this System.""" + return mjx.put_model(getattr(self, '_model')) + # below are some operation dispatch derivations diff --git a/brax/envs/ant.py b/brax/envs/ant.py index 49fb3e3e..18c779b6 100644 --- a/brax/envs/ant.py +++ b/brax/envs/ant.py @@ -22,6 +22,7 @@ from etils import epath import jax from jax import numpy as jp +import mujoco class Ant(PipelineEnv): @@ -166,6 +167,12 @@ def __init__( sys = sys.replace(dt=0.005) n_frames = 10 + if backend == 'mjx': + sys._model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON + sys._model.opt.disableflags = mujoco.mjtDisableBit.mjDSBL_EULERDAMP + 
sys._model.opt.iterations = 1 + sys._model.opt.ls_iterations = 4 + if backend == 'positional': # TODO: does the same actuator strength work as in spring sys = sys.replace( @@ -230,10 +237,8 @@ def step(self, state: State, action: jax.Array) -> State: forward_reward = velocity[0] min_z, max_z = self._healthy_z_range - is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, x=0.0, y=1.0) - is_healthy = jp.where( - pipeline_state.x.pos[0, 2] > max_z, x=0.0, y=is_healthy - ) + is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0) + is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy) if self._terminate_when_unhealthy: healthy_reward = self._healthy_reward else: diff --git a/brax/envs/assets/humanoid.xml b/brax/envs/assets/humanoid.xml index a536674d..83ba08e1 100644 --- a/brax/envs/assets/humanoid.xml +++ b/brax/envs/assets/humanoid.xml @@ -2,7 +2,7 @@ - + diff --git a/brax/envs/base.py b/brax/envs/base.py index 6351fea4..ba9ba806 100644 --- a/brax/envs/base.py +++ b/brax/envs/base.py @@ -20,11 +20,11 @@ from brax import base from brax.generalized import pipeline as g_pipeline +from brax.mjx import pipeline as m_pipeline from brax.positional import pipeline as p_pipeline from brax.spring import pipeline as s_pipeline from flax import struct import jax -from jax import numpy as jp @struct.dataclass @@ -98,6 +98,7 @@ def __init__( pipeline = { 'generalized': g_pipeline, + 'mjx': m_pipeline, 'spring': s_pipeline, 'positional': p_pipeline, } diff --git a/brax/envs/humanoid.py b/brax/envs/humanoid.py index 3e718418..fdc9520c 100644 --- a/brax/envs/humanoid.py +++ b/brax/envs/humanoid.py @@ -22,6 +22,7 @@ from etils import epath import jax from jax import numpy as jp +import mujoco class Humanoid(PipelineEnv): @@ -201,6 +202,12 @@ def __init__( 350.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0]) # pyformat: disable sys = sys.replace(actuator=sys.actuator.replace(gear=gear)) + if backend == 'mjx': + sys._model.opt.solver = 
mujoco.mjtSolver.mjSOL_NEWTON + sys._model.opt.disableflags = mujoco.mjtDisableBit.mjDSBL_EULERDAMP + sys._model.opt.iterations = 1 + sys._model.opt.ls_iterations = 4 + kwargs['n_frames'] = kwargs.get('n_frames', n_frames) super().__init__(sys=sys, backend=backend, **kwargs) @@ -255,10 +262,8 @@ def step(self, state: State, action: jax.Array) -> State: forward_reward = self._forward_reward_weight * velocity[0] min_z, max_z = self._healthy_z_range - is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, x=0.0, y=1.0) - is_healthy = jp.where( - pipeline_state.x.pos[0, 2] > max_z, x=0.0, y=is_healthy - ) + is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0) + is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy) if self._terminate_when_unhealthy: healthy_reward = self._healthy_reward else: diff --git a/brax/io/json.py b/brax/io/json.py index 015eee58..26fba362 100644 --- a/brax/io/json.py +++ b/brax/io/json.py @@ -27,13 +27,15 @@ # State attributes needed for the visualizer. _STATE_ATTR = ['x', 'contact'] +_IGNORE_FIELDS = ['_model', '_mjx_model'] + def _to_dict(obj): """Converts python object to a json encodeable object.""" if isinstance(obj, list) or isinstance(obj, tuple): return [_to_dict(s) for s in obj] if isinstance(obj, dict): - return {k: _to_dict(v) for k, v in obj.items()} + return {k: _to_dict(v) for k, v in obj.items() if k not in _IGNORE_FIELDS} if isinstance(obj, jax.Array): return _to_dict(obj.tolist()) if hasattr(obj, '__dict__'): diff --git a/brax/io/mjcf.py b/brax/io/mjcf.py index 46903c0a..7da5bd85 100644 --- a/brax/io/mjcf.py +++ b/brax/io/mjcf.py @@ -497,6 +497,7 @@ def load_model(mj: mujoco.MjModel) -> System: ) sys = jax.tree_map(jp.array, sys) + sys.set_model(mj) return sys diff --git a/brax/mjx/__init__.py b/brax/mjx/__init__.py new file mode 100644 index 00000000..5920485e --- /dev/null +++ b/brax/mjx/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 The Brax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/brax/mjx/base.py b/brax/mjx/base.py new file mode 100644 index 00000000..3bb5aa76 --- /dev/null +++ b/brax/mjx/base.py @@ -0,0 +1,30 @@ +# Copyright 2023 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint:disable=g-multiple-import +"""Brax adapter for MJX physics engine.""" + +from brax import base +from flax import struct +from mujoco import mjx + + +@struct.dataclass +class State(base.State): + """Dynamic state that changes after every step. + + Attributes: + data: mjx.Data + """ + data: mjx.Data diff --git a/brax/mjx/perf_test.py b/brax/mjx/perf_test.py new file mode 100644 index 00000000..dc25fc42 --- /dev/null +++ b/brax/mjx/perf_test.py @@ -0,0 +1,43 @@ +# Copyright 2023 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint:disable=g-multiple-import +"""MJX perf tests.""" + +from absl.testing import absltest +from brax import test_utils +from brax.mjx import pipeline +import jax +from jax import numpy as jp + + +class PerfTest(absltest.TestCase): + + def test_pipeline_ant(self): + sys = test_utils.load_fixture('ant.xml') + + def init_fn(rng): + rng1, rng2 = jax.random.split(rng, 2) + q = jax.random.uniform(rng1, (sys.q_size(),), minval=-0.1, maxval=0.1) + qd = 0.1 * jax.random.normal(rng2, (sys.qd_size(),)) + return pipeline.init(sys, q, qd) + + def step_fn(state): + return pipeline.step(sys, state, jp.zeros(sys.act_size())) + + test_utils.benchmark('mjx pipeline ant', init_fn, step_fn) + + +if __name__ == '__main__': + absltest.main() diff --git a/brax/mjx/pipeline.py b/brax/mjx/pipeline.py new file mode 100644 index 00000000..5090b39e --- /dev/null +++ b/brax/mjx/pipeline.py @@ -0,0 +1,89 @@ +# Copyright 2023 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Physics pipeline for fully articulated dynamics and collision.""" +# pylint:disable=g-multiple-import +# pylint:disable=g-importing-member +from brax.base import Motion, System, Transform +from brax.mjx.base import State +import jax +from jax import numpy as jp +from mujoco import mjx + + +def init( + sys: System, q: jax.Array, qd: jax.Array, debug: bool = False +) -> State: + """Initializes physics state. + + Args: + sys: a brax system + q: (q_size,) joint angle vector + qd: (qd_size,) joint velocity vector + debug: if True, adds contact to the state for debugging + + Returns: + state: initial physics state + """ + del debug # ignored in mjx pipeline + + model = sys.get_mjx_model() + data = mjx.make_data(model) + data = data.replace(qpos=q, qvel=qd) + data = mjx.forward(model, data) + + q, qd = data.qpos, data.qvel + x = Transform(pos=data.xpos[1:], rot=data.xquat[1:]) + cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3]) + offset = data.xpos[1:, :] - data.subtree_com[model.body_rootid[1:]] + offset = Transform.create(pos=offset) + xd = offset.vmap().do(cvel) + contact = None + + return State(q, qd, x, xd, contact, data) + + +def step( + sys: System, state: State, act: jax.Array, debug: bool = False +) -> State: + """Performs a single physics step using position-based dynamics. + + Resolves actuator forces, joints, and forces at acceleration level, and + resolves collisions at velocity level with baumgarte stabilization.
+ + Args: + sys: system defining the kinematic tree and other properties + state: physics state prior to step + act: (act_size,) actuator input vector + debug: if True, adds contact to the state for debugging + + Returns: + x: updated link transform in world frame + xd: updated link motion in world frame + """ + del debug # ignored in mjx pipeline + + model = sys.get_mjx_model() + data = state.data.replace(ctrl=act) + data = mjx.step(model, data) + + q, qd = data.qpos, data.qvel + x = Transform(pos=data.xpos[1:], rot=data.xquat[1:]) + cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3]) + offset = data.xpos[1:, :] - data.subtree_com[model.body_rootid[1:]] + offset = Transform.create(pos=offset) + xd = offset.vmap().do(cvel) + contact = None + + return State(q, qd, x, xd, contact, data) diff --git a/brax/mjx/pipeline_test.py b/brax/mjx/pipeline_test.py new file mode 100644 index 00000000..3c28424e --- /dev/null +++ b/brax/mjx/pipeline_test.py @@ -0,0 +1,48 @@ +# Copyright 2023 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint:disable=g-multiple-import +"""Tests for MJX physics pipeline.""" + +from absl.testing import absltest +from brax import test_utils +from brax.mjx import pipeline +import jax +from jax import numpy as jp +import mujoco +import numpy as np + + +class PipelineTest(absltest.TestCase): + + def test_pendulum(self): + sys = test_utils.load_fixture('double_pendulum.xml') + + state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size())) + step_fn = jax.jit(pipeline.step) + for _ in range(20): + state = step_fn(sys, state, jp.zeros(sys.act_size())) + + # compare against mujoco + model = test_utils.load_fixture_mujoco('double_pendulum.xml') + data = mujoco.MjData(model) + mujoco.mj_step(model, data, 20) + + np.testing.assert_almost_equal(data.qpos, state.q, decimal=4) + np.testing.assert_almost_equal(data.qvel, state.qd, decimal=3) + np.testing.assert_almost_equal(data.xpos[1:], state.x.pos, decimal=4) + + +if __name__ == '__main__': + absltest.main() diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index a0e16736..52ce7e86 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -19,7 +19,7 @@ import functools import time -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union from absl import logging from brax import base @@ -104,7 +104,57 @@ def train( Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]] ] = None, ): - """PPO training.""" + """PPO training.
+ + Args: + environment: the environment to train + num_timesteps: the total number of environment steps to use during training + episode_length: the length of an environment episode + action_repeat: the number of timesteps to repeat an action + num_envs: the number of parallel environments to use for rollouts + NOTE: `num_envs` must be divisible by the total number of chips since each + chip gets `num_envs // total_number_of_chips` environments to roll out + NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since + data generated by `num_envs` parallel envs gets used for gradient + updates over `num_minibatches` of data, where each minibatch has a + leading dimension of `batch_size` + max_devices_per_host: maximum number of chips to use per host process + num_eval_envs: the number of envs to use for evaluation. Each env will run 1 + episode, and all envs run in parallel during eval. + learning_rate: learning rate for ppo loss + entropy_cost: entropy reward for ppo loss, higher values increase entropy + of the policy + discounting: discounting rate + seed: random seed + unroll_length: the number of timesteps to unroll in each environment. The + PPO loss is computed over `unroll_length` timesteps + batch_size: the batch size for each minibatch SGD step + num_minibatches: the number of times to run the SGD step, each with a + different minibatch with leading dimension of `batch_size` + num_updates_per_batch: the number of times to run the gradient update over + all minibatches before doing a new environment rollout + num_evals: the number of evals to run during the entire training run. + Increasing the number of evals increases total training time + num_resets_per_eval: the number of environment resets to run between each + eval.
The environment resets occur on the host + normalize_observations: whether to normalize observations + reward_scaling: float scaling for reward + clipping_epsilon: clipping epsilon for PPO loss + gae_lambda: Generalized advantage estimation lambda + deterministic_eval: whether to run the eval with a deterministic policy + network_factory: function that generates networks for policy and value + functions + progress_fn: a user-defined callback function for reporting/plotting metrics + normalize_advantage: whether to normalize advantage estimate + eval_env: an optional environment for eval only, defaults to `environment` + policy_params_fn: a user-defined callback function that can be used for + saving policy checkpoints + randomization_fn: a user-defined callback function that generates randomized + environments + + Returns: + Tuple of (make_policy function, network params, metrics) + """ assert batch_size * num_minibatches % num_envs == 0 xt = time.time() diff --git a/brax/v1/experimental/composer/components/common.py b/brax/v1/experimental/composer/components/common.py index 9a7e1933..785b1e2f 100644 --- a/brax/v1/experimental/composer/components/common.py +++ b/brax/v1/experimental/composer/components/common.py @@ -28,7 +28,7 @@ def upright_term_fn(done, sys, qp: brax.QP, info: brax.Info, component): up = jnp.array([0., 0., 1.]) torso_up = math.rotate(up, rot) torso_is_up = jnp.dot(torso_up, up) - done = jnp.where(torso_is_up < 0.0, x=1.0, y=done) + done = jnp.where(torso_is_up < 0.0, 1.0, done) return done @@ -45,6 +45,6 @@ def height_term_fn(done, z_offset = component.get('term_params', {}).get('z_offset', 0.0) index = sim_utils.names2indices(sys.config, component['root'], 'body')[0][0] z = qp.pos[index][2] - done = jnp.where(z < min_height + z_offset, x=1.0, y=done) - done = jnp.where(z > max_height + z_offset, x=1.0, y=done) + done = jnp.where(z < min_height + z_offset, 1.0, done) + done = jnp.where(z > max_height + z_offset, 1.0, done) return done diff --git
a/brax/v1/experimental/composer/reward_functions.py b/brax/v1/experimental/composer/reward_functions.py index 7d0c115c..b72b8444 100644 --- a/brax/v1/experimental/composer/reward_functions.py +++ b/brax/v1/experimental/composer/reward_functions.py @@ -61,8 +61,8 @@ def fn(*args, **kwargs): score = reward reward = (reward + offset) * scale score *= jnp.sign(scale) - reward = jnp.where(done, x=reward + done_bonus, y=reward) - score = jnp.where(done, x=score + done_bonus, y=score) + reward = jnp.where(done, reward + done_bonus, reward) + score = jnp.where(done, score + done_bonus, score) return reward, score, done return fn @@ -168,8 +168,8 @@ def distance_reward(action: jnp.ndarray, # instead of clipping, terminate # dist = jnp.clip(dist, a_min=min_dist, a_max=max_dist) done = jnp.zeros_like(dist) - done = jnp.where(dist < min_dist, x=jnp.ones_like(done), y=done) - done = jnp.where(dist > max_dist, x=jnp.ones_like(done), y=done) + done = jnp.where(dist < min_dist, jnp.ones_like(done), done) + done = jnp.where(dist > max_dist, jnp.ones_like(done), done) return -dist, done