diff --git a/mjx/mujoco/mjx/_src/collision_driver.py b/mjx/mujoco/mjx/_src/collision_driver.py index 721b4345cf..9eacaefb4a 100644 --- a/mjx/mujoco/mjx/_src/collision_driver.py +++ b/mjx/mujoco/mjx/_src/collision_driver.py @@ -1,4 +1,5 @@ # Copyright 2023 DeepMind Technologies Limited +# Copyright 2024 NVIDIA Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -362,6 +363,19 @@ def ncon(m: Union[Model, mujoco.MjModel]) -> int: return min(max_count, count) if max_count > -1 else count +def _take_k(mask, k): + """ Compute indices of the first k elements for which the mask is set. + If there are fewer than k elements with mask set, the second return value will + be 1 for the indices that must be ignored. If there are k or more indices + with mask set, the second return value will be zeros. + """ + index = jax.lax.associative_scan(jp.add, jp.array(mask, dtype=jp.int32)) + last = index[-1] + index = jp.where(mask, index - 1, k + 1) + # This is just a scatter, but jax.lax.scatter is discouraged to use directly. + seg_index = jax.ops.segment_sum(jp.arange(mask.shape[-1]), index, num_segments=k) + zero_mask = 1 - (jp.arange(k) < last) + return seg_index, zero_mask def collision(m: Model, d: Data) -> Data: """Collides geometries.""" @@ -382,8 +396,16 @@ def collision(m: Model, d: Data) -> Data: max_contact_points = _max_contact_points(m) if max_contact_points > -1 and contact.dist.shape[0] > max_contact_points: - # get top-k contacts - _, idx = jax.lax.top_k(-contact.dist, k=max_contact_points) - contact = jax.tree_map(lambda x, idx=idx: jp.take(x, idx, axis=0), contact) + # get first k active contacts + mask = contact.dist < 0.0 + mask_index, zero_mask = _take_k(mask, k=max_contact_points) + contact = jax.tree_map( + lambda x: x[mask_index], contact + ) + contact = contact.replace( + dist=contact.dist * (1 - zero_mask) + zero_mask, + solimp=contact.solimp * (1 - zero_mask)[:, None] + zero_mask[:, None] * jp.array([0.9, 0.95, 0.001, 0.5, 2.0]), + solref=contact.solref * (1 - zero_mask)[:, None] + zero_mask[:, None] * jp.array([0.002, 1.0]), + ) return d.replace(contact=contact)