diff --git a/mjx/mujoco/mjx/_src/forward.py b/mjx/mujoco/mjx/_src/forward.py index 07f33a90c8..63f78b2f66 100644 --- a/mjx/mujoco/mjx/_src/forward.py +++ b/mjx/mujoco/mjx/_src/forward.py @@ -403,12 +403,13 @@ def forward(m: Model, d: Data) -> Data: """Forward dynamics.""" d = fwd_position(m, d) d = sensor.sensor_pos(m, d) + d = sensor.energy_pos(m, d) d = fwd_velocity(m, d) d = sensor.sensor_vel(m, d) d = fwd_actuation(m, d) + d = sensor.energy_vel(m, d) d = fwd_acceleration(m, d) d = sensor.sensor_acc(m, d) - if d.efc_J.size == 0: d = d.replace(qacc=d.qacc_smooth) return d diff --git a/mjx/mujoco/mjx/_src/io.py b/mjx/mujoco/mjx/_src/io.py index 7ace198bea..42cc2f933f 100644 --- a/mjx/mujoco/mjx/_src/io.py +++ b/mjx/mujoco/mjx/_src/io.py @@ -55,9 +55,9 @@ def _make_option( if o.solver not in set(types.SolverType): raise NotImplementedError(f'{mujoco.mjtSolver(o.solver)}') - for i in range(mujoco.mjtEnableBit.mjNENABLE): - if o.enableflags & 2**i: - raise NotImplementedError(f'{mujoco.mjtEnableBit(2 ** i)}') + # Check enable flags using enum pattern + if types.EnableBit(o.enableflags) not in set(types.EnableBit) and o.enableflags != 0: + raise NotImplementedError(f'{mujoco.mjtEnableBit(o.enableflags)}') has_fluid_params = o.density > 0 or o.viscosity > 0 or o.wind.any() implicitfast = o.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST @@ -71,6 +71,7 @@ def _make_option( fields['jacobian'] = types.JacobianType(o.jacobian) fields['solver'] = types.SolverType(o.solver) fields['disableflags'] = types.DisableBit(o.disableflags) + fields['enableflags'] = types.EnableBit(o.enableflags) fields['has_fluid_params'] = has_fluid_params return types.Option(**fields) @@ -363,6 +364,7 @@ def make_data( '_qM_sparse': (m.nM, float), '_qLD_sparse': (m.nM, float), '_qLDiagInv_sparse': (m.nv, float), + 'energy': (2, float), } if not _full_compat: diff --git a/mjx/mujoco/mjx/_src/sensor.py b/mjx/mujoco/mjx/_src/sensor.py index 5d040052bf..7e1848fa2a 100644 --- a/mjx/mujoco/mjx/_src/sensor.py +++ b/mjx/mujoco/mjx/_src/sensor.py @@ -22,11 +22,14 @@ from mujoco.mjx._src import ray from mujoco.mjx._src import smooth from mujoco.mjx._src import support +from mujoco.mjx._src import scan from mujoco.mjx._src.types import Data from mujoco.mjx._src.types import DisableBit from mujoco.mjx._src.types import Model from mujoco.mjx._src.types import ObjType from mujoco.mjx._src.types import SensorType +from mujoco.mjx._src.types import JointType +from mujoco.mjx._src.types import EnableBit # pylint: enable=g-importing-member import numpy as np @@ -600,3 +603,85 @@ def _framelinacc(cvel, cacc, offset): ) return d.replace(sensordata=sensordata) + + +def energy_pos(m: Model, d: Data) -> Data: + """Calculates position-dependent energy (potential). + """ + + if not m.opt.enableflags & EnableBit.ENERGY: + return d + + # Initialize potential energy + energy = jp.array(0.0) + + # Add gravitational potential energy for each body + if not m.opt.disableflags & DisableBit.GRAVITY: + energy = -jp.sum(m.body_mass[1:] * jp.dot(d.xipos[1:,:], m.opt.gravity)) + + # Add joint spring potential energy using scan.flat + if not m.opt.disableflags & DisableBit.PASSIVE: + def spring_energy(jnt_type, stiffness, qpos, qpos_spring, padr): + + if jnt_type == JointType.FREE: + # Position springs + quat = qpos[padr:padr+4] + quat = math.normalize(quat) + dif = quat - qpos_spring[padr:padr+4] + energy = 0.5 * stiffness * jp.dot(dif[:3], dif[:3]) + + elif jnt_type in (JointType.FREE, JointType.BALL): + # Convert quaternion difference to angular displacement + quat = qpos[padr:padr+4] + quat = math.normalize(quat) + dif = math.quat_sub(quat, qpos_spring[padr:padr+4]) + energy = 0.5 * stiffness * jp.dot(dif, dif) + + elif jnt_type in (JointType.SLIDE, JointType.HINGE): + dif = qpos[padr] - qpos_spring[padr] + energy = 0.5 * stiffness * dif * dif + + return energy + + spring_energy = scan.flat( + m, + spring_energy, + 'jjqqj', # input types: jnt_type, stiffness, qpos, qpos_spring, padr + 'j', # output type: energy per joint + m.jnt_type, + m.jnt_stiffness, + d.qpos, + m.qpos_spring, + jp.array(m.jnt_qposadr), + group_by='j' + ) + + energy += jp.sum(spring_energy) + + # Add tendon spring potential energy using vectorized operations + if not m.opt.disableflags & DisableBit.PASSIVE & m.tendon_lengthspring.size > 0: + # Get lower/upper bounds and current lengths + lower = m.tendon_lengthspring[::2] # Even indices + upper = m.tendon_lengthspring[1::2] # Odd indices + length = d.ten_length + + # Compute displacements using vectorized operations + displacement = jp.where(length > upper, upper - length, 0.0) + displacement = jp.where(length < lower, lower - length, displacement) + + # Compute spring energy for all tendons at once + energy += 0.5 * jp.sum(m.tendon_stiffness * displacement * displacement) + + return d.replace(energy=d.energy.at[0].set(energy)) + + +def energy_vel(m: Model, d: Data) -> Data: + """Calculates velocity-dependent energy (kinetic). + """ + if not m.opt.enableflags & EnableBit.ENERGY: + return d + + vec = support.mul_m(m, d, d.qvel) + energy = 0.5 * jp.dot(vec, d.qvel) + + return d.replace(energy=d.energy.at[1].set(energy)) diff --git a/mjx/mujoco/mjx/_src/types.py b/mjx/mujoco/mjx/_src/types.py index a97d0cb293..c810395349 100644 --- a/mjx/mujoco/mjx/_src/types.py +++ b/mjx/mujoco/mjx/_src/types.py @@ -65,6 +65,15 @@ class DisableBit(enum.IntFlag): # unsupported: MIDPHASE +class EnableBit(enum.IntFlag): + """Enable optional feature bitflags. + + Members: + ENERGY: enable energy computation + """ + + ENERGY = mujoco.mjtEnableBit.mjENBL_ENERGY + class JointType(enum.IntEnum): """Type of degree of freedom. @@ -482,7 +491,7 @@ class Option(PyTreeNode): noslip_iterations: int = _restricted_to('mujoco') ccd_iterations: int = _restricted_to('mujoco') disableflags: DisableBit - enableflags: int + enableflags: EnableBit disableactuator: int sdf_initpoints: int = _restricted_to('mujoco') sdf_iterations: int = _restricted_to('mujoco') @@ -1214,6 +1223,7 @@ class Data(PyTreeNode): ncon: number of contacts solver_niter: number of solver iterations time: simulation time + energy: potential, kinetic energy (2, ) qpos: position (nq,) qvel: velocity (nv,) act: actuator activation (na,) @@ -1341,6 +1351,7 @@ class Data(PyTreeNode): solver_niter: jax.Array # global properties: time: jax.Array + energy: jax.Array # state: qpos: jax.Array qvel: jax.Array