Skip to content

Commit

Permalink
Throw EpisodeInitializationError instead of raw RuntimeError when pro…
Browse files Browse the repository at this point in the history
…p initializer fails.

The more specific error allows callers to catch it if need be. Existing code should continue to work because EpisodeInitializationError is a subclass of RuntimeError.

PiperOrigin-RevId: 718344531
Change-Id: I344dd518dab878b36e9a99cd0f4ea3ca4fffcf2a
  • Loading branch information
nimrod-gileadi authored and copybara-github committed Jan 22, 2025
1 parent 3634882 commit 316e0f1
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 18 deletions.
18 changes: 11 additions & 7 deletions dm_control/composer/initializers/prop_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self,
velocities are less than this threshold.
max_attempts_per_prop: The maximum number of rejection sampling attempts
per prop. If a non-colliding pose cannot be found before this limit is
reached, a `RuntimeError` will be raised.
reached, an `EpisodeInitializationError` will be raised.
settle_physics: (optional) If True, the physics simulation will be
advanced for a few steps to allow the prop positions to settle.
min_settle_physics_time: (optional) When `settle_physics` is True, lower
Expand Down Expand Up @@ -170,8 +170,9 @@ def __call__(self, physics, random_state, ignore_contacts_with_entities=None):
subsequently).
Raises:
RuntimeError: If `ignore_collisions == False` and a non-colliding prop
pose could not be found within `max_attempts_per_prop`.
EpisodeInitializationError: If `ignore_collisions == False` and a
non-colliding prop pose could not be found within
`max_attempts_per_prop`.
"""
if ignore_contacts_with_entities is None:
ignore_contacts_with_entities = []
Expand Down Expand Up @@ -222,9 +223,12 @@ def place_props():
break

if not success:
raise RuntimeError(_REJECTION_SAMPLING_FAILED.format(
model_name=prop.mjcf_model.model,
max_attempts=self._max_attempts_per_prop))
raise composer.EpisodeInitializationError(
_REJECTION_SAMPLING_FAILED.format(
model_name=prop.mjcf_model.model,
max_attempts=self._max_attempts_per_prop,
)
)

for prop in ignore_contacts_with_entities:
self._restore_contact_parameters(physics, prop, cached_contact_params)
Expand Down Expand Up @@ -256,7 +260,7 @@ def place_and_settle():
physics.data.time = original_time

if self._raise_exception_on_settle_failure:
raise RuntimeError(
raise composer.EpisodeInitializationError(
_SETTLING_PHYSICS_FAILED.format(
max_attempts=self._max_settle_physics_attempts,
max_time=self._max_settle_physics_time,
Expand Down
8 changes: 6 additions & 2 deletions dm_control/composer/initializers/prop_initializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def test_rejection_sampling_failure(self):
expected_message = prop_initializer._REJECTION_SAMPLING_FAILED.format(
model_name=spheres[1].mjcf_model.model, # Props are placed in order.
max_attempts=max_attempts_per_prop)
with self.assertRaisesWithLiteralMatch(RuntimeError, expected_message):
with self.assertRaisesWithLiteralMatch(
composer.EpisodeInitializationError, expected_message
):
prop_placer(physics, random_state=np.random.RandomState(0))

def test_ignore_contacts_with_entities(self):
Expand Down Expand Up @@ -141,7 +143,9 @@ def test_ignore_contacts_with_entities(self):
prop_placer_init(physics, random_state=np.random.RandomState(0))
expected_message = prop_initializer._REJECTION_SAMPLING_FAILED.format(
model_name=spheres[0].mjcf_model.model, max_attempts=1)
with self.assertRaisesWithLiteralMatch(RuntimeError, expected_message):
with self.assertRaisesWithLiteralMatch(
composer.EpisodeInitializationError, expected_message
):
prop_placer_seq[0](physics, random_state=np.random.RandomState(0))

# Placing the first sphere should succeed if we ignore contacts involving
Expand Down
13 changes: 8 additions & 5 deletions dm_control/composer/initializers/tcp_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def __call__(self, physics, random_state):
random_state: An `np.random.RandomState` instance.
Raises:
RuntimeError: If a collision-free pose could not be found within
`max_ik_attempts`.
composer.EpisodeInitializationError: If a collision-free pose could not be
found within `max_ik_attempts`.
"""
if self._hand is not None:
target_site = self._hand.tool_center_point
Expand Down Expand Up @@ -162,6 +162,9 @@ def __call__(self, physics, random_state):
# positions and try again with a new target.
physics.bind(self._arm.joints).qpos = initial_qpos

raise RuntimeError(_REJECTION_SAMPLING_FAILED.format(
max_rejection_samples=self._max_rejection_samples,
max_ik_attempts=self._max_ik_attempts))
raise composer.EpisodeInitializationError(
_REJECTION_SAMPLING_FAILED.format(
max_rejection_samples=self._max_rejection_samples,
max_ik_attempts=self._max_ik_attempts,
)
)
6 changes: 2 additions & 4 deletions dm_control/composer/initializers/tcp_initializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# limitations under the License.
# ============================================================================

"""Tests for tcp_initializer."""

import functools

from absl.testing import absltest
Expand Down Expand Up @@ -105,7 +103,7 @@ def test_exception_if_hand_colliding_with_fixed_body(self):

initializer = make_initializer()
with self.assertRaisesWithLiteralMatch(
RuntimeError,
composer.EpisodeInitializationError,
tcp_initializer._REJECTION_SAMPLING_FAILED.format(
max_rejection_samples=max_rejection_samples,
max_ik_attempts=max_ik_attempts)):
Expand Down Expand Up @@ -145,7 +143,7 @@ def test_exception_if_self_collision(self, with_hand):

initializer = make_initializer()
with self.assertRaisesWithLiteralMatch(
RuntimeError,
composer.EpisodeInitializationError,
tcp_initializer._REJECTION_SAMPLING_FAILED.format(
max_rejection_samples=max_rejection_samples,
max_ik_attempts=max_ik_attempts)):
Expand Down

0 comments on commit 316e0f1

Please sign in to comment.