Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590287478
Change-Id: Iab96c438d2dda8fcb575e6ef03091bdcdb00c559
  • Loading branch information
Brax Team authored and erikfrey committed Dec 12, 2023
1 parent 1630403 commit f0cf109
Show file tree
Hide file tree
Showing 16 changed files with 330 additions and 29 deletions.
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@ to minutes:
[evolutionary strategies](https://github.com/google/brax/blob/main/brax/training/agents/es).
* Learning algorithms that leverage the differentiability of the simulator, such as [analytic policy gradients](https://github.com/google/brax/blob/main/brax/training/agents/apg).

## One API, Three Pipelines
## One API, Four Pipelines

Brax offers three distinct physics pipelines that are easy to swap:
Brax offers four distinct physics pipelines that are easy to swap:

* [MuJoCo XLA - MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) - a JAX
reimplementation of the MuJoCo physics engine.
* [Generalized](https://github.com/google/brax/blob/main/brax/v2/generalized/)
calculates motion in [generalized coordinates](https://en.wikipedia.org/wiki/Generalized_coordinates) using the same accurate robot
dynamics algorithms as [MuJoCo](https://mujoco.org/) and [TDS](https://github.com/erwincoumans/tiny-differentiable-simulator).

NOTE: We plan to import [MuJoCo XLA - MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) as another physics backend, eventually replacing `generalized`. Check out our recent [announcement](https://github.com/google/brax/discussions/409).

calculates motion in [generalized coordinates](https://en.wikipedia.org/wiki/Generalized_coordinates)
using dynamics algorithms similar to [MuJoCo](https://mujoco.org/) and [TDS](https://github.com/erwincoumans/tiny-differentiable-simulator).
* [Positional](https://github.com/google/brax/blob/main/brax/v2/positional/)
uses [Position Based Dynamics](https://matthias-research.github.io/pages/publications/posBasedDyn.pdf),
a fast but stable method of resolving joint and collision constraints.
Expand Down
18 changes: 16 additions & 2 deletions brax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@

import copy
import functools
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from brax import math
from flax import struct
import jax
from jax import numpy as jp
from jax import vmap
from jax.tree_util import tree_map

import mujoco
from mujoco import mjx

# f: free, 1: 1-dof, 2: 2-dof, 3: 3-dof
Q_WIDTHS = {'f': 7, '1': 1, '2': 2, '3': 3}
Expand Down Expand Up @@ -617,6 +618,7 @@ class System(Base):
matrix_inv_iterations: int = struct.field(pytree_node=False)
solver_iterations: int = struct.field(pytree_node=False)
solver_maxls: int = struct.field(pytree_node=False)
_model: mujoco.MjModel = struct.field(pytree_node=False, default=None)

def num_links(self) -> int:
"""Returns the number of links in the system."""
Expand Down Expand Up @@ -678,6 +680,18 @@ def act_size(self) -> int:
"""Returns the act dimension for the system."""
return self.actuator.q_id.shape[0]

def set_model(self, model: mujoco.MjModel):
"""Sets the source MuJoCo model of this System."""
object.__setattr__(self, '_model', model)

def get_model(self) -> mujoco.MjModel:
"""Returns the source MuJoCo model of this System."""
return self._model

def get_mjx_model(self) -> mjx.Model:
"""Returns an MJX model of this System."""
return mjx.put_model(getattr(self, '_model'))


# below are some operation dispatch derivations

Expand Down
13 changes: 9 additions & 4 deletions brax/envs/ant.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from etils import epath
import jax
from jax import numpy as jp
import mujoco


class Ant(PipelineEnv):
Expand Down Expand Up @@ -166,6 +167,12 @@ def __init__(
sys = sys.replace(dt=0.005)
n_frames = 10

if backend == 'mjx':
sys._model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
sys._model.opt.disableflags = mujoco.mjtDisableBit.mjDSBL_EULERDAMP
sys._model.opt.iterations = 1
sys._model.opt.ls_iterations = 4

if backend == 'positional':
# TODO: does the same actuator strength work as in spring
sys = sys.replace(
Expand Down Expand Up @@ -230,10 +237,8 @@ def step(self, state: State, action: jax.Array) -> State:
forward_reward = velocity[0]

min_z, max_z = self._healthy_z_range
is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, x=0.0, y=1.0)
is_healthy = jp.where(
pipeline_state.x.pos[0, 2] > max_z, x=0.0, y=is_healthy
)
is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0)
is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy)
if self._terminate_when_unhealthy:
healthy_reward = self._healthy_reward
else:
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/assets/humanoid.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<compiler angle="degree" inertiafromgeom="true"/>
<default>
<joint armature="1" damping="1" limited="true"/>
<geom conaffinity="0" condim="1" contype="0" material="geom"/>
<geom conaffinity="0" condim="3" contype="0" material="geom"/>
<motor ctrllimited="true" ctrlrange="-.4 .4"/>
</default>
<!-- Removed RK4 integrator for brax. -->
Expand Down
3 changes: 2 additions & 1 deletion brax/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

from brax import base
from brax.generalized import pipeline as g_pipeline
from brax.mjx import pipeline as m_pipeline
from brax.positional import pipeline as p_pipeline
from brax.spring import pipeline as s_pipeline
from flax import struct
import jax
from jax import numpy as jp


@struct.dataclass
Expand Down Expand Up @@ -98,6 +98,7 @@ def __init__(

pipeline = {
'generalized': g_pipeline,
'mjx': m_pipeline,
'spring': s_pipeline,
'positional': p_pipeline,
}
Expand Down
13 changes: 9 additions & 4 deletions brax/envs/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from etils import epath
import jax
from jax import numpy as jp
import mujoco


class Humanoid(PipelineEnv):
Expand Down Expand Up @@ -201,6 +202,12 @@ def __init__(
350.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0]) # pyformat: disable
sys = sys.replace(actuator=sys.actuator.replace(gear=gear))

if backend == 'mjx':
sys._model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
sys._model.opt.disableflags = mujoco.mjtDisableBit.mjDSBL_EULERDAMP
sys._model.opt.iterations = 1
sys._model.opt.ls_iterations = 4

kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

super().__init__(sys=sys, backend=backend, **kwargs)
Expand Down Expand Up @@ -255,10 +262,8 @@ def step(self, state: State, action: jax.Array) -> State:
forward_reward = self._forward_reward_weight * velocity[0]

min_z, max_z = self._healthy_z_range
is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, x=0.0, y=1.0)
is_healthy = jp.where(
pipeline_state.x.pos[0, 2] > max_z, x=0.0, y=is_healthy
)
is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0)
is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy)
if self._terminate_when_unhealthy:
healthy_reward = self._healthy_reward
else:
Expand Down
4 changes: 3 additions & 1 deletion brax/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@
# State attributes needed for the visualizer.
_STATE_ATTR = ['x', 'contact']

_IGNORE_FIELDS = ['_model', '_mjx_model']


def _to_dict(obj):
"""Converts python object to a json encodeable object."""
if isinstance(obj, list) or isinstance(obj, tuple):
return [_to_dict(s) for s in obj]
if isinstance(obj, dict):
return {k: _to_dict(v) for k, v in obj.items()}
return {k: _to_dict(v) for k, v in obj.items() if k not in _IGNORE_FIELDS}
if isinstance(obj, jax.Array):
return _to_dict(obj.tolist())
if hasattr(obj, '__dict__'):
Expand Down
1 change: 1 addition & 0 deletions brax/io/mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def load_model(mj: mujoco.MjModel) -> System:
)

sys = jax.tree_map(jp.array, sys)
sys.set_model(mj)

return sys

Expand Down
14 changes: 14 additions & 0 deletions brax/mjx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2023 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

30 changes: 30 additions & 0 deletions brax/mjx/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2023 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=g-multiple-import
"""Brax adapter for MJX physics engine."""

from brax import base
from flax import struct
from mujoco import mjx


@struct.dataclass
class State(base.State):
"""Dynamic state that changes after every step.
Attributes:
data: mjx.Data
"""
data: mjx.Data
43 changes: 43 additions & 0 deletions brax/mjx/perf_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2023 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=g-multiple-import
"""PBD perf tests."""

from absl.testing import absltest
from brax import test_utils
from brax.mjx import pipeline
import jax
from jax import numpy as jp


class PerfTest(absltest.TestCase):

def test_pipeline_ant(self):
sys = test_utils.load_fixture('ant.xml')

def init_fn(rng):
rng1, rng2 = jax.random.split(rng, 2)
q = jax.random.uniform(rng1, (sys.q_size(),), minval=-0.1, maxval=0.1)
qd = 0.1 * jax.random.normal(rng2, (sys.qd_size(),))
return pipeline.init(sys, q, qd)

def step_fn(state):
return pipeline.step(sys, state, jp.zeros(sys.act_size()))

test_utils.benchmark('mjx pipeline ant', init_fn, step_fn)


if __name__ == '__main__':
absltest.main()
89 changes: 89 additions & 0 deletions brax/mjx/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2023 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Physics pipeline for fully articulated dynamics and collisiion."""
# pylint:disable=g-multiple-import
# pylint:disable=g-importing-member
from brax.base import Motion, System, Transform
from brax.mjx.base import State
import jax
from jax import numpy as jp
from mujoco import mjx


def init(
sys: System, q: jax.Array, qd: jax.Array, debug: bool = False
) -> State:
"""Initializes physics state.
Args:
sys: a brax system
q: (q_size,) joint angle vector
qd: (qd_size,) joint velocity vector
debug: if True, adds contact to the state for debugging
Returns:
state: initial physics state
"""
del debug # ignored in mjx pipeline

model = sys.get_mjx_model()
data = mjx.make_data(model)
data = data.replace(qpos=q, qvel=qd)
data = mjx.forward(model, data)

q, qd = data.qpos, data.qvel
x = Transform(pos=data.xpos[1:], rot=data.xquat[1:])
cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3])
offset = data.xpos[1:, :] - data.subtree_com[model.body_rootid[1:]]
offset = Transform.create(pos=offset)
xd = offset.vmap().do(cvel)
contact = None

return State(q, qd, x, xd, contact, data)


def step(
sys: System, state: State, act: jax.Array, debug: bool = False
) -> State:
"""Performs a single physics step using position-based dynamics.
Resolves actuator forces, joints, and forces at acceleration level, and
resolves collisions at velocity level with baumgarte stabilization.
Args:
sys: system defining the kinematic tree and other properties
state: physics state prior to step
act: (act_size,) actuator input vector
debug: if True, adds contact to the state for debugging
Returns:
x: updated link transform in world frame
xd: updated link motion in world frame
"""
del debug # ignored in mjx pipeline

model = sys.get_mjx_model()
data = state.data.replace(ctrl=act)
data = mjx.step(model, data)

q, qd = data.qpos, data.qvel
x = Transform(pos=data.xpos[1:], rot=data.xquat[1:])
cvel = Motion(vel=data.cvel[1:, 3:], ang=data.cvel[1:, :3])
offset = data.xpos[1:, :] - data.subtree_com[model.body_rootid[1:]]
offset = Transform.create(pos=offset)
xd = offset.vmap().do(cvel)
contact = None

return State(q, qd, x, xd, contact, data)
Loading

0 comments on commit f0cf109

Please sign in to comment.