Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace top-k with first-k thresholding #1449

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions mjx/mujoco/mjx/_src/collision_driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2023 DeepMind Technologies Limited
# Copyright 2024 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -362,6 +363,19 @@ def ncon(m: Union[Model, mujoco.MjModel]) -> int:

return min(max_count, count) if max_count > -1 else count

def _take_k(mask, k):
""" Compute indices of the first k elements for which the mask is set.
If there are fewer than k elements with mask set, the second return value will
be 1 for the indices that must be ignored. If there are k or more indices
with mask set, the second return value will be zeros.
"""
index = jax.lax.associative_scan(jp.add, jp.array(mask, dtype=jp.int32))
last = index[-1]
index = jp.where(mask, index - 1, k + 1)
# This is just a scatter, but jax.lax.scatter is discouraged to use directly.
seg_index = jax.ops.segment_sum(jp.arange(mask.shape[-1]), index, num_segments=k)
zero_mask = 1 - (jp.arange(k) < last)
return seg_index, zero_mask

def collision(m: Model, d: Data) -> Data:
"""Collides geometries."""
Expand All @@ -382,8 +396,16 @@ def collision(m: Model, d: Data) -> Data:

max_contact_points = _max_contact_points(m)
if max_contact_points > -1 and contact.dist.shape[0] > max_contact_points:
# get top-k contacts
_, idx = jax.lax.top_k(-contact.dist, k=max_contact_points)
contact = jax.tree_map(lambda x, idx=idx: jp.take(x, idx, axis=0), contact)
# get first k active contacts
mask = contact.dist < 0.0
mask_index, zero_mask = _take_k(mask, k=max_contact_points)
contact = jax.tree_map(
lambda x: x[mask_index], contact
)
contact = contact.replace(
dist=contact.dist * (1 - zero_mask) + zero_mask,
solimp=contact.solimp * (1 - zero_mask)[:, None] + zero_mask[:, None] * jp.array([0.9, 0.95, 0.001, 0.5, 2.0]),
solref=contact.solref * (1 - zero_mask)[:, None] + zero_mask[:, None] * jp.array([0.002, 1.0]),
)

return d.replace(contact=contact)