From 7ad47798046d4f315e15c1f19e2577861fcae287 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 14:51:04 +0300 Subject: [PATCH 01/13] kinetic energy --- mjx/mujoco/mjx/_src/io.py | 1 + mjx/mujoco/mjx/_src/sensor.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/mjx/mujoco/mjx/_src/io.py b/mjx/mujoco/mjx/_src/io.py index 7ace198bea..821096b45a 100644 --- a/mjx/mujoco/mjx/_src/io.py +++ b/mjx/mujoco/mjx/_src/io.py @@ -363,6 +363,7 @@ def make_data( '_qM_sparse': (m.nM, float), '_qLD_sparse': (m.nM, float), '_qLDiagInv_sparse': (m.nv, float), + 'energy': (2, float), } if not _full_compat: diff --git a/mjx/mujoco/mjx/_src/sensor.py b/mjx/mujoco/mjx/_src/sensor.py index 5d040052bf..4cc59647fb 100644 --- a/mjx/mujoco/mjx/_src/sensor.py +++ b/mjx/mujoco/mjx/_src/sensor.py @@ -600,3 +600,14 @@ def _framelinacc(cvel, cacc, offset): ) return d.replace(sensordata=sensordata) + + + +def energy_vel(m: Model, d: Data) -> Data: + """Calculates velocity-dependent energy (kinetic). + """ + + vec = support.mul_m(m, d, d.qvel) + energy = 0.5 * jp.dot(vec, d.qvel) + + return d.replace(energy=d.energy.at[1].set(energy)) From 54d65f22e3b410edb49cfc454285c767084e8419 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 14:57:18 +0300 Subject: [PATCH 02/13] potential energy --- mjx/mujoco/mjx/_src/sensor.py | 62 +++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/mjx/mujoco/mjx/_src/sensor.py b/mjx/mujoco/mjx/_src/sensor.py index 4cc59647fb..17069350d5 100644 --- a/mjx/mujoco/mjx/_src/sensor.py +++ b/mjx/mujoco/mjx/_src/sensor.py @@ -27,6 +27,7 @@ from mujoco.mjx._src.types import Model from mujoco.mjx._src.types import ObjType from mujoco.mjx._src.types import SensorType +from mujoco.mjx._src.types import JointType # pylint: enable=g-importing-member import numpy as np @@ -602,6 +603,67 @@ def _framelinacc(cvel, cacc, offset): return d.replace(sensordata=sensordata) +def energy_pos(m: Model, d: Data) -> Data: + """Calculates position-dependent energy (potential). + """ + + # Initialize potential energy + energy = jp.array(0.0) + + # Add gravitational potential energy for each body + if not m.opt.disableflags & DisableBit.GRAVITY: + for i in range(1, m.nbody): # Skip world body + energy -= m.body_mass[i] * jp.dot(m.opt.gravity, d.xipos[i]) + + # Add joint spring potential energy + if not m.opt.disableflags & DisableBit.PASSIVE: + for i in range(m.njnt): + stiffness = m.jnt_stiffness[i] + padr = m.jnt_qposadr[i] + + if m.jnt_type[i] == JointType.FREE: + # Position springs + quat = d.qpos[padr:padr+4] + quat = math.normalize(quat) + dif = quat - m.qpos_spring[padr:padr+4] + energy += 0.5 * stiffness * jp.dot(dif[:3], dif[:3]) + + # Handle rotations + padr += 3 + + if m.jnt_type[i] in (JointType.FREE, JointType.BALL): + # Convert quaternion difference to angular displacement + quat = d.qpos[padr:padr+4] + quat = math.normalize(quat) + dif = math.quat_sub(quat, m.qpos_spring[padr:padr+4]) + energy += 0.5 * stiffness * jp.dot(dif, dif) + + elif m.jnt_type[i] in (JointType.SLIDE, JointType.HINGE): + dif = d.qpos[padr] - m.qpos_spring[padr] + energy += 0.5 * stiffness * dif * dif + + # Add tendon spring potential energy + if not m.opt.disableflags & DisableBit.PASSIVE: + for i in range(m.ntendon): + stiffness = m.tendon_stiffness[i] + length = d.ten_length[i] + + # Compute spring displacement + lower = m.tendon_lengthspring[2*i] + upper = m.tendon_lengthspring[2*i+1] + + if length > upper: + displacement = upper - length + elif length < lower: + displacement = lower - length + else: + displacement = 0.0 + + energy += 0.5 * stiffness * displacement * displacement + + # Update energy[0] (potential energy) in data + return d.replace(energy=d.energy.at[0].set(energy)) + def energy_vel(m: Model, d: Data) -> Data: """Calculates velocity-dependent energy (kinetic). From 0ebf6aeafe491ca51cbde41ddfde0d474c610394 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 14:59:35 +0300 Subject: [PATCH 03/13] added energy to mjx data --- mjx/mujoco/mjx/_src/types.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mjx/mujoco/mjx/_src/types.py b/mjx/mujoco/mjx/_src/types.py index a97d0cb293..6ad5bbb2b1 100644 --- a/mjx/mujoco/mjx/_src/types.py +++ b/mjx/mujoco/mjx/_src/types.py @@ -1330,6 +1330,7 @@ class Data(PyTreeNode): _qM_sparse: qM in sparse representation (nM,) _qLD_sparse: qLD in sparse representation (nM,) _qLDiagInv_sparse: qLDiagInv in sparse representation (nv,) + energy: potential, kinetic energy (2, ) """ # fmt: skip # constant sizes: ne: int @@ -1467,3 +1468,4 @@ class Data(PyTreeNode): _qM_sparse: jax.Array = _restricted_to('mjx') # pylint:disable=invalid-name _qLD_sparse: jax.Array = _restricted_to('mjx') # pylint:disable=invalid-name _qLDiagInv_sparse: jax.Array = _restricted_to('mjx') # pylint:disable=invalid-name + energy: jax.Array From df9e1a2c311f17bba55811c0877341bacd65a4b8 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 15:02:10 +0300 Subject: [PATCH 04/13] added energy to forward calculations --- mjx/mujoco/mjx/_src/forward.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mjx/mujoco/mjx/_src/forward.py b/mjx/mujoco/mjx/_src/forward.py index 07f33a90c8..94f3f41017 100644 --- a/mjx/mujoco/mjx/_src/forward.py +++ b/mjx/mujoco/mjx/_src/forward.py @@ -403,12 +403,13 @@ def forward(m: Model, d: Data) -> Data: """Forward dynamics.""" d = fwd_position(m, d) d = sensor.sensor_pos(m, d) + d = sensor.energy_pos(m, d) d = fwd_velocity(m, d) d = sensor.sensor_vel(m, d) d = fwd_actuation(m, d) + d = sensor.sensor_energy(m, d) d = fwd_acceleration(m, d) d = sensor.sensor_acc(m, d) - if d.efc_J.size == 0: d = d.replace(qacc=d.qacc_smooth) return d From 7360b0fa1a70ef32f96971c5c96439cffe33d2c3 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 15:06:49 +0300 Subject: [PATCH 05/13] bug in naming --- mjx/mujoco/mjx/_src/forward.py | 2 +- mjx/pyproject.toml | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mjx/mujoco/mjx/_src/forward.py b/mjx/mujoco/mjx/_src/forward.py index 94f3f41017..63f78b2f66 100644 --- a/mjx/mujoco/mjx/_src/forward.py +++ b/mjx/mujoco/mjx/_src/forward.py @@ -407,7 +407,7 @@ def forward(m: Model, d: Data) -> Data: d = fwd_velocity(m, d) d = sensor.sensor_vel(m, d) d = fwd_actuation(m, d) - d = sensor.sensor_energy(m, d) + d = sensor.energy_vel(m, d) d = fwd_acceleration(m, d) d = sensor.sensor_acc(m, d) if d.efc_J.size == 0: diff --git a/mjx/pyproject.toml b/mjx/pyproject.toml index 4d276d4082..b95fb218f7 100644 --- a/mjx/pyproject.toml +++ b/mjx/pyproject.toml @@ -30,9 +30,6 @@ dependencies = [ "etils[epath]", "jax", "jaxlib", - "mujoco>=3.2.7.dev0", - "scipy", - "trimesh", ] [project.scripts] From 8776ba57531230db6b358e99919a11a8ada40281 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 15:12:57 +0300 Subject: [PATCH 06/13] add testing for energy match mujoco/mjx --- mjx/mujoco/mjx/_src/sensor_test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mjx/mujoco/mjx/_src/sensor_test.py b/mjx/mujoco/mjx/_src/sensor_test.py index fce6225c04..34f776c9a5 100644 --- a/mjx/mujoco/mjx/_src/sensor_test.py +++ b/mjx/mujoco/mjx/_src/sensor_test.py @@ -71,12 +71,23 @@ def test_sensor(self, filename, cone_type): cacc=jp.zeros_like(d.cacc), cfrc_int=jp.zeros_like(d.cfrc_int), cfrc_ext=jp.zeros_like(d.cfrc_ext), + energy=jp.zeros_like(d.energy), # Reset energy ) + + # Calculate energies + dx = jax.jit(mjx.energy_pos)(mx, dx) + dx = jax.jit(mjx.energy_vel)(mx, dx) + + # Test sensor functions dx = jax.jit(mjx.sensor_pos)(mx, dx) dx = jax.jit(mjx.sensor_vel)(mx, dx) dx = jax.jit(mjx.sensor_acc)(mx, dx) _assert_eq(d.sensordata, dx.sensordata, 'sensordata') + + # Test potential and kinetic energies match + _assert_eq(d.energy[0], dx.energy[0], 'potential energy') + _assert_eq(d.energy[1], dx.energy[1], 'kinetic energy') def test_disable_sensor(self): """Tests disabling sensor.""" From b5c0b8ab34b09fe400b5694fa2715eaa6c756cef Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 15:31:29 +0300 Subject: [PATCH 07/13] energy enable flag --- mjx/mujoco/mjx/_src/io.py | 7 ++++--- mjx/mujoco/mjx/_src/types.py | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/mjx/mujoco/mjx/_src/io.py b/mjx/mujoco/mjx/_src/io.py index 821096b45a..f249bb34a4 100644 --- a/mjx/mujoco/mjx/_src/io.py +++ b/mjx/mujoco/mjx/_src/io.py @@ -55,9 +55,9 @@ def _make_option( if o.solver not in set(types.SolverType): raise NotImplementedError(f'{mujoco.mjtSolver(o.solver)}') - for i in range(mujoco.mjtEnableBit.mjNENABLE): - if o.enableflags & 2**i: - raise NotImplementedError(f'{mujoco.mjtEnableBit(2 ** i)}') + # Check enable flags using enum pattern + if types.EnableBit(o.enableflags) not in set(types.EnableBit): + raise NotImplementedError(f'Unsupported enable flags: {mujoco.mjtEnableBit(o.enableflags)}') has_fluid_params = o.density > 0 or o.viscosity > 0 or o.wind.any() implicitfast = o.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST @@ -71,6 +71,7 @@ def _make_option( fields['jacobian'] = types.JacobianType(o.jacobian) fields['solver'] = types.SolverType(o.solver) fields['disableflags'] = types.DisableBit(o.disableflags) + fields['enableflags'] = types.EnableBit(o.enableflags) fields['has_fluid_params'] = has_fluid_params return types.Option(**fields) diff --git a/mjx/mujoco/mjx/_src/types.py b/mjx/mujoco/mjx/_src/types.py index 6ad5bbb2b1..195f56afd6 100644 --- a/mjx/mujoco/mjx/_src/types.py +++ b/mjx/mujoco/mjx/_src/types.py @@ -65,6 +65,15 @@ class DisableBit(enum.IntFlag): # unsupported: MIDPHASE +class EnableBit(enum.IntFlag): + """Enable optional feature bitflags. + + Members: + ENERGY: enable energy computation + """ + + ENERGY = mujoco.mjtEnableBit.mjENBL_ENERGY + class JointType(enum.IntEnum): """Type of degree of freedom. From ee20aeb22562f99ddacdb80060196b13f1bd2463 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 16:46:57 +0300 Subject: [PATCH 08/13] enabling logic --- mjx/mujoco/mjx/_src/io.py | 4 +-- mjx/mujoco/mjx/_src/sensor.py | 5 +++ mjx/mujoco/mjx/_src/sensor_test.py | 54 +++++++++++++++++++++++++----- mjx/mujoco/mjx/_src/types.py | 2 +- 4 files changed, 53 insertions(+), 12 deletions(-) diff --git a/mjx/mujoco/mjx/_src/io.py b/mjx/mujoco/mjx/_src/io.py index f249bb34a4..42cc2f933f 100644 --- a/mjx/mujoco/mjx/_src/io.py +++ b/mjx/mujoco/mjx/_src/io.py @@ -56,8 +56,8 @@ def _make_option( raise NotImplementedError(f'{mujoco.mjtSolver(o.solver)}') # Check enable flags using enum pattern - if types.EnableBit(o.enableflags) not in set(types.EnableBit): - raise NotImplementedError(f'Unsupported enable flags: {mujoco.mjtEnableBit(o.enableflags)}') + if types.EnableBit(o.enableflags) not in set(types.EnableBit) and o.enableflags != 0: + raise NotImplementedError(f'{mujoco.mjtEnableBit(o.enableflags)}') has_fluid_params = o.density > 0 or o.viscosity > 0 or o.wind.any() implicitfast = o.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST diff --git a/mjx/mujoco/mjx/_src/sensor.py b/mjx/mujoco/mjx/_src/sensor.py index 17069350d5..42bf70a6e9 100644 --- a/mjx/mujoco/mjx/_src/sensor.py +++ b/mjx/mujoco/mjx/_src/sensor.py @@ -28,6 +28,7 @@ from mujoco.mjx._src.types import ObjType from mujoco.mjx._src.types import SensorType from mujoco.mjx._src.types import JointType +from mujoco.mjx._src.types import EnableBit # pylint: enable=g-importing-member import numpy as np @@ -609,6 +610,8 @@ def energy_pos(m: Model, d: Data) -> Data: # Initialize potential energy energy = jp.array(0.0) + if not m.opt.enableflags & EnableBit.ENERGY: + return d # Add gravitational potential energy for each body if not m.opt.disableflags & DisableBit.GRAVITY: @@ -668,6 +671,8 @@ def energy_pos(m: Model, d: Data) -> Data: def energy_vel(m: Model, d: Data) -> Data: """Calculates velocity-dependent energy (kinetic). """ + if not m.opt.enableflags & EnableBit.ENERGY: + return d vec = support.mul_m(m, d, d.qvel) energy = 0.5 * jp.dot(vec, d.qvel) diff --git a/mjx/mujoco/mjx/_src/sensor_test.py b/mjx/mujoco/mjx/_src/sensor_test.py index 34f776c9a5..b380868f84 100644 --- a/mjx/mujoco/mjx/_src/sensor_test.py +++ b/mjx/mujoco/mjx/_src/sensor_test.py @@ -71,23 +71,14 @@ def test_sensor(self, filename, cone_type): cacc=jp.zeros_like(d.cacc), cfrc_int=jp.zeros_like(d.cfrc_int), cfrc_ext=jp.zeros_like(d.cfrc_ext), - energy=jp.zeros_like(d.energy), # Reset energy ) - # Calculate energies - dx = jax.jit(mjx.energy_pos)(mx, dx) - dx = jax.jit(mjx.energy_vel)(mx, dx) - # Test sensor functions dx = jax.jit(mjx.sensor_pos)(mx, dx) dx = jax.jit(mjx.sensor_vel)(mx, dx) dx = jax.jit(mjx.sensor_acc)(mx, dx) _assert_eq(d.sensordata, dx.sensordata, 'sensordata') - - # Test potential and kinetic energies match - _assert_eq(d.energy[0], dx.energy[0], 'potential energy') - _assert_eq(d.energy[1], dx.energy[1], 'kinetic energy') def test_disable_sensor(self): """Tests disabling sensor.""" @@ -130,6 +121,51 @@ def test_unsupported_sensor(self): with self.assertRaises(NotImplementedError): mjx.put_model(m) + def test_energy(self): + """Tests energy calculations with and without enable flag.""" + m = test_util.load_test_file('sensor/sensor.xml') + d = mujoco.MjData(m) + + # Set up non-zero state + d.qvel = 0.1 * np.random.random(m.nv) + d.qpos = m.qpos0 + 0.1 * np.random.random(m.nq) + mujoco.mj_step(m, d, 10) + mujoco.mj_forward(m, d) + + # JIT compile energy functions once + energy_pos_fn = jax.jit(mjx.energy_pos) + energy_vel_fn = jax.jit(mjx.energy_vel) + + # Test without enabling energy flag + mx = mjx.put_model(m) + dx = mjx.put_data(m, d).replace(energy=jp.zeros_like(d.energy)) + + # Calculate energies without flag - should be zero + dx = energy_pos_fn(mx, dx) + dx = energy_vel_fn(mx, dx) + + # Verify energies are zero without enable flag + np.testing.assert_array_equal(dx.energy, jp.zeros_like(dx.energy)) + + # Now enable energy calculations in both MuJoCo and MJX + m.opt.enableflags |= mujoco.mjtEnableBit.mjENBL_ENERGY + mujoco.mj_forward(m, d) # Recalculate MuJoCo energies with flag enabled + + mx = mjx.put_model(m) + dx = mjx.put_data(m, d).replace(energy=jp.zeros_like(d.energy)) + + # Calculate energies with flag enabled + dx = energy_pos_fn(mx, dx) + dx = energy_vel_fn(mx, dx) + + # Verify energies match MuJoCo with enable flag + _assert_eq(d.energy[0], dx.energy[0], 'potential energy') + _assert_eq(d.energy[1], dx.energy[1], 'kinetic energy') + + # Verify energy values are non-zero + self.assertGreater(abs(dx.energy[0]) + abs(dx.energy[1]), 0, + 'Expected non-zero energy values') + if __name__ == '__main__': absltest.main() diff --git a/mjx/mujoco/mjx/_src/types.py b/mjx/mujoco/mjx/_src/types.py index 195f56afd6..4ef4b6f3cd 100644 --- a/mjx/mujoco/mjx/_src/types.py +++ b/mjx/mujoco/mjx/_src/types.py @@ -491,7 +491,7 @@ class Option(PyTreeNode): noslip_iterations: int = _restricted_to('mujoco') ccd_iterations: int = _restricted_to('mujoco') disableflags: DisableBit - enableflags: int + enableflags: EnableBit disableactuator: int sdf_initpoints: int = _restricted_to('mujoco') sdf_iterations: int = _restricted_to('mujoco') From dc3d336a6a0405ec6ba189fd27353941b433d028 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 16:56:25 +0300 Subject: [PATCH 09/13] removed energy test for now --- mjx/mujoco/mjx/_src/sensor_test.py | 45 ------------------------------ 1 file changed, 45 deletions(-) diff --git a/mjx/mujoco/mjx/_src/sensor_test.py b/mjx/mujoco/mjx/_src/sensor_test.py index b380868f84..3881843781 100644 --- a/mjx/mujoco/mjx/_src/sensor_test.py +++ b/mjx/mujoco/mjx/_src/sensor_test.py @@ -121,51 +121,6 @@ def test_unsupported_sensor(self): with self.assertRaises(NotImplementedError): mjx.put_model(m) - def test_energy(self): - """Tests energy calculations with and without enable flag.""" - m = test_util.load_test_file('sensor/sensor.xml') - d = mujoco.MjData(m) - - # Set up non-zero state - d.qvel = 0.1 * np.random.random(m.nv) - d.qpos = m.qpos0 + 0.1 * np.random.random(m.nq) - mujoco.mj_step(m, d, 10) - mujoco.mj_forward(m, d) - - # JIT compile energy functions once - energy_pos_fn = jax.jit(mjx.energy_pos) - energy_vel_fn = jax.jit(mjx.energy_vel) - - # Test without enabling energy flag - mx = mjx.put_model(m) - dx = mjx.put_data(m, d).replace(energy=jp.zeros_like(d.energy)) - - # Calculate energies without flag - should be zero - dx = energy_pos_fn(mx, dx) - dx = energy_vel_fn(mx, dx) - - # Verify energies are zero without enable flag - np.testing.assert_array_equal(dx.energy, jp.zeros_like(dx.energy)) - - # Now enable energy calculations in both MuJoCo and MJX - m.opt.enableflags |= mujoco.mjtEnableBit.mjENBL_ENERGY - mujoco.mj_forward(m, d) # Recalculate MuJoCo energies with flag enabled - - mx = mjx.put_model(m) - dx = mjx.put_data(m, d).replace(energy=jp.zeros_like(d.energy)) - - # Calculate energies with flag enabled - dx = energy_pos_fn(mx, dx) - dx = energy_vel_fn(mx, dx) - - # Verify energies match MuJoCo with enable flag - _assert_eq(d.energy[0], dx.energy[0], 'potential energy') - _assert_eq(d.energy[1], dx.energy[1], 'kinetic energy') - - # Verify energy values are non-zero - self.assertGreater(abs(dx.energy[0]) + abs(dx.energy[1]), 0, - 'Expected non-zero energy values') - if __name__ == '__main__': absltest.main() From 1b6d8fbd81fe2f7c007c8fb5f884a3b251a66658 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 27 Dec 2024 19:39:47 +0300 Subject: [PATCH 10/13] bring the dependencies back --- mjx/mujoco/mjx/_src/sensor_test.py | 2 -- mjx/pyproject.toml | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mjx/mujoco/mjx/_src/sensor_test.py b/mjx/mujoco/mjx/_src/sensor_test.py index 3881843781..fce6225c04 100644 --- a/mjx/mujoco/mjx/_src/sensor_test.py +++ b/mjx/mujoco/mjx/_src/sensor_test.py @@ -72,8 +72,6 @@ def test_sensor(self, filename, cone_type): cfrc_int=jp.zeros_like(d.cfrc_int), cfrc_ext=jp.zeros_like(d.cfrc_ext), ) - - # Test sensor functions dx = jax.jit(mjx.sensor_pos)(mx, dx) dx = jax.jit(mjx.sensor_vel)(mx, dx) dx = jax.jit(mjx.sensor_acc)(mx, dx) diff --git a/mjx/pyproject.toml b/mjx/pyproject.toml index b95fb218f7..4d276d4082 100644 --- a/mjx/pyproject.toml +++ b/mjx/pyproject.toml @@ -30,6 +30,9 @@ dependencies = [ "etils[epath]", "jax", "jaxlib", + "mujoco>=3.2.7.dev0", + "scipy", + "trimesh", ] [project.scripts] From 6622bb33ae9606283be8f110019a75e7cb1aac7c Mon Sep 17 00:00:00 2001 From: simeon Date: Thu, 16 Jan 2025 12:24:03 +0300 Subject: [PATCH 11/13] fixed energy order in data --- mjx/mujoco/mjx/_src/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mjx/mujoco/mjx/_src/types.py b/mjx/mujoco/mjx/_src/types.py index 4ef4b6f3cd..c810395349 100644 --- a/mjx/mujoco/mjx/_src/types.py +++ b/mjx/mujoco/mjx/_src/types.py @@ -1223,6 +1223,7 @@ class Data(PyTreeNode): ncon: number of contacts solver_niter: number of solver iterations time: simulation time + energy: potential, kinetic energy (2, ) qpos: position (nq,) qvel: velocity (nv,) act: actuator activation (na,) @@ -1339,7 +1340,6 @@ class Data(PyTreeNode): _qM_sparse: qM in sparse representation (nM,) _qLD_sparse: qLD in sparse representation (nM,) _qLDiagInv_sparse: qLDiagInv in sparse representation (nv,) - energy: potential, kinetic energy (2, ) """ # fmt: skip # constant sizes: ne: int @@ -1351,6 +1351,7 @@ class Data(PyTreeNode): solver_niter: jax.Array # global properties: time: jax.Array + energy: jax.Array # state: qpos: jax.Array qvel: jax.Array @@ -1477,4 +1478,3 @@ class Data(PyTreeNode): _qM_sparse: jax.Array = _restricted_to('mjx') # pylint:disable=invalid-name _qLD_sparse: jax.Array = _restricted_to('mjx') # pylint:disable=invalid-name _qLDiagInv_sparse: jax.Array = _restricted_to('mjx') # pylint:disable=invalid-name - energy: jax.Array From cbc81d9e02077668210f0e372cf8c2f14752d8b5 Mon Sep 17 00:00:00 2001 From: simeon Date: Thu, 16 Jan 2025 14:53:30 +0300 Subject: [PATCH 12/13] fixed interations over bodies and tendons --- mjx/mujoco/mjx/_src/sensor.py | 44 +++++++++++++++-------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/mjx/mujoco/mjx/_src/sensor.py b/mjx/mujoco/mjx/_src/sensor.py index 42bf70a6e9..3a027a8d9f 100644 --- a/mjx/mujoco/mjx/_src/sensor.py +++ b/mjx/mujoco/mjx/_src/sensor.py @@ -607,18 +607,18 @@ def _framelinacc(cvel, cacc, offset): def energy_pos(m: Model, d: Data) -> Data: """Calculates position-dependent energy (potential). """ - - # Initialize potential energy - energy = jp.array(0.0) + if not m.opt.enableflags & EnableBit.ENERGY: return d + + # Initialize potential energy + energy = jp.array(0.0) # Add gravitational potential energy for each body if not m.opt.disableflags & DisableBit.GRAVITY: - for i in range(1, m.nbody): # Skip world body - energy -= m.body_mass[i] * jp.dot(m.opt.gravity, d.xipos[i]) - - # Add joint spring potential energy + energy = -jp.sum(m.body_mass[1:] * jp.dot(m.opt.gravity, d.xipos[1:])) + + # Add joint spring potential energy using scan.flat if not m.opt.disableflags & DisableBit.PASSIVE: for i in range(m.njnt): stiffness = m.jnt_stiffness[i] @@ -645,26 +645,20 @@ def energy_pos(m: Model, d: Data) -> Data: dif = d.qpos[padr] - m.qpos_spring[padr] energy += 0.5 * stiffness * dif * dif - # Add tendon spring potential energy + # Add tendon spring potential energy using vectorized operations if not m.opt.disableflags & DisableBit.PASSIVE: - for i in range(m.ntendon): - stiffness = m.tendon_stiffness[i] - length = d.ten_length[i] - - # Compute spring displacement - lower = m.tendon_lengthspring[2*i] - upper = m.tendon_lengthspring[2*i+1] - - if length > upper: - displacement = upper - length - elif length < lower: - displacement = lower - length - else: - displacement = 0.0 - - energy += 0.5 * stiffness * displacement * displacement + # Get lower/upper bounds and current lengths + lower = m.tendon_lengthspring[::2] # Even indices + upper = m.tendon_lengthspring[1::2] # Odd indices + length = d.ten_length + + # Compute displacements using vectorized operations + displacement = jp.where(length > upper, upper - length, 0.0) + displacement = jp.where(length < lower, lower - length, displacement) + + # Compute spring energy for all tendons at once + energy += 0.5 * jp.sum(m.tendon_stiffness * displacement * displacement) - # Update energy[0] (potential energy) in data return d.replace(energy=d.energy.at[0].set(energy)) From 292071076602fd206d47c7a142b02329234e6a6e Mon Sep 17 00:00:00 2001 From: simeon Date: Thu, 16 Jan 2025 16:27:58 +0300 Subject: [PATCH 13/13] minor fix in tendons energies --- mjx/mujoco/mjx/_src/sensor.py | 53 ++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/mjx/mujoco/mjx/_src/sensor.py b/mjx/mujoco/mjx/_src/sensor.py index 3a027a8d9f..7e1848fa2a 100644 --- a/mjx/mujoco/mjx/_src/sensor.py +++ b/mjx/mujoco/mjx/_src/sensor.py @@ -22,6 +22,7 @@ from mujoco.mjx._src import ray from mujoco.mjx._src import smooth from mujoco.mjx._src import support +from mujoco.mjx._src import scan from mujoco.mjx._src.types import Data from mujoco.mjx._src.types import DisableBit from mujoco.mjx._src.types import Model @@ -616,37 +617,49 @@ def energy_pos(m: Model, d: Data) -> Data: # Add gravitational potential energy for each body if not m.opt.disableflags & DisableBit.GRAVITY: - energy = -jp.sum(m.body_mass[1:] * jp.dot(m.opt.gravity, d.xipos[1:])) + energy = -jp.sum(m.body_mass[1:] * jp.dot(d.xipos[1:,:], m.opt.gravity)) # Add joint spring potential energy using scan.flat if not m.opt.disableflags & DisableBit.PASSIVE: - for i in range(m.njnt): - stiffness = m.jnt_stiffness[i] - padr = m.jnt_qposadr[i] - - if m.jnt_type[i] == JointType.FREE: + def spring_energy(jnt_type, stiffness, qpos, qpos_spring, padr): + + if jnt_type == JointType.FREE: # Position springs - quat = d.qpos[padr:padr+4] + quat = qpos[padr:padr+4] quat = math.normalize(quat) - dif = quat - m.qpos_spring[padr:padr+4] - energy += 0.5 * stiffness * jp.dot(dif[:3], dif[:3]) + dif = quat - qpos_spring[padr:padr+4] + energy = 0.5 * stiffness * jp.dot(dif[:3], dif[:3]) - # Handle rotations - padr += 3 - - if m.jnt_type[i] in (JointType.FREE, JointType.BALL): + elif jnt_type in (JointType.FREE, JointType.BALL): # Convert quaternion difference to angular displacement - quat = d.qpos[padr:padr+4] + quat = qpos[padr:padr+4] quat = math.normalize(quat) - dif = math.quat_sub(quat, m.qpos_spring[padr:padr+4]) - energy += 0.5 * stiffness * jp.dot(dif, dif) + dif = math.quat_sub(quat, qpos_spring[padr:padr+4]) + energy = 0.5 * stiffness * jp.dot(dif, dif) + + elif jnt_type in (JointType.SLIDE, JointType.HINGE): + dif = qpos[padr] - qpos_spring[padr] + energy = 0.5 * stiffness * dif * dif - elif m.jnt_type[i] in (JointType.SLIDE, JointType.HINGE): - dif = d.qpos[padr] - m.qpos_spring[padr] - energy += 0.5 * stiffness * dif * dif + return energy + + spring_energy = scan.flat( + m, + spring_energy, + 'jjqqj', # input types: jnt_type, stiffness, qpos, qpos_spring, padr + 'j', # output type: energy per joint + m.jnt_type, + m.jnt_stiffness, + d.qpos, + m.qpos_spring, + jp.array(m.jnt_qposadr), + group_by='j' + ) + + energy += jp.sum(spring_energy) # Add tendon spring potential energy using vectorized operations - if not m.opt.disableflags & DisableBit.PASSIVE: + if not m.opt.disableflags & DisableBit.PASSIVE & m.tendon_lengthspring.size > 0: # Get lower/upper bounds and current lengths lower = m.tendon_lengthspring[::2] # Even indices upper = m.tendon_lengthspring[1::2] # Odd indices