Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ title: "RSL-RL: A Learning Library for Robotics Research"
message: "If you use this work, please cite the following paper."
repository-code: "https://github.com/leggedrobotics/rsl_rl"
license: BSD-3-Clause
version: 3.2.0
version: 3.3.0
type: software
authors:
- family-names: Schwarke
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "rsl-rl-lib"
version = "3.2.0"
version = "3.3.0"
keywords = ["reinforcement-learning", "robotics"]
maintainers = [
{ name="Clemens Schwarke", email="[email protected]" },
Expand Down
2 changes: 2 additions & 0 deletions rsl_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from tensordict import TensorDict
Expand Down
7 changes: 3 additions & 4 deletions rsl_rl/algorithms/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rsl_rl.modules import ActorCritic, ActorCriticCNN, ActorCriticRecurrent
from rsl_rl.modules.rnd import RandomNetworkDistillation
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import string_to_callable
from rsl_rl.utils import resolve_callable


class PPO:
Expand Down Expand Up @@ -80,9 +80,8 @@ def __init__(
# Print that we are not using symmetry
if not use_symmetry:
print("Symmetry not used for learning. We will use it for logging instead.")
# If function is a string then resolve it to a function
if isinstance(symmetry_cfg["data_augmentation_func"], str):
symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"])
# Resolve the data augmentation function (supports string names or direct callables)
symmetry_cfg["data_augmentation_func"] = resolve_callable(symmetry_cfg["data_augmentation_func"])
# Check valid configuration
if not callable(symmetry_cfg["data_augmentation_func"]):
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion rsl_rl/networks/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import torch
import torch.nn as nn
from typing import Union

from rsl_rl.utils import unpad_trajectories

HiddenState = torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None
HiddenState = Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor], None] # Using Union due to Python <3.10
"""Type alias for the hidden state of RNNs (GRU/LSTM).

For GRUs, this is a single tensor while for LSTMs, this is a tuple of two tensors (hidden state and cell state).
Expand Down
5 changes: 3 additions & 2 deletions rsl_rl/runners/distillation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
from rsl_rl.runners import OnPolicyRunner
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import resolve_callable


class DistillationRunner(OnPolicyRunner):
Expand All @@ -34,7 +35,7 @@ def _get_default_obs_sets(self) -> list[str]:
def _construct_algorithm(self, obs: TensorDict) -> Distillation:
"""Construct the distillation algorithm."""
# Initialize the policy
student_teacher_class = eval(self.policy_cfg.pop("class_name"))
student_teacher_class = resolve_callable(self.policy_cfg.pop("class_name"))
student_teacher: StudentTeacher | StudentTeacherRecurrent = student_teacher_class(
obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
).to(self.device)
Expand All @@ -45,7 +46,7 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation:
)

# Initialize the algorithm
alg_class = eval(self.alg_cfg.pop("class_name"))
alg_class = resolve_callable(self.alg_cfg.pop("class_name"))
alg: Distillation = alg_class(
student_teacher, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
)
Expand Down
6 changes: 3 additions & 3 deletions rsl_rl/runners/on_policy_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
resolve_symmetry_config,
)
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import resolve_obs_groups
from rsl_rl.utils import resolve_callable, resolve_obs_groups
from rsl_rl.utils.logger import Logger


Expand Down Expand Up @@ -267,7 +267,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO:
self.policy_cfg["critic_obs_normalization"] = self.cfg["empirical_normalization"]

# Initialize the policy
actor_critic_class = eval(self.policy_cfg.pop("class_name"))
actor_critic_class = resolve_callable(self.policy_cfg.pop("class_name"))
actor_critic: ActorCritic | ActorCriticRecurrent | ActorCriticCNN = actor_critic_class(
obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
).to(self.device)
Expand All @@ -278,7 +278,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO:
)

# Initialize the algorithm
alg_class = eval(self.alg_cfg.pop("class_name"))
alg_class = resolve_callable(self.alg_cfg.pop("class_name"))
alg: PPO = alg_class(
actor_critic, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
)
Expand Down
4 changes: 2 additions & 2 deletions rsl_rl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@

from .utils import (
get_param,
resolve_callable,
resolve_nn_activation,
resolve_obs_groups,
resolve_optimizer,
split_and_pad_trajectories,
string_to_callable,
unpad_trajectories,
)

__all__ = [
"get_param",
"resolve_callable",
"resolve_nn_activation",
"resolve_obs_groups",
"resolve_optimizer",
"split_and_pad_trajectories",
"string_to_callable",
"unpad_trajectories",
]
97 changes: 75 additions & 22 deletions rsl_rl/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
from __future__ import annotations

import importlib
import pkgutil
import torch
import warnings
from tensordict import TensorDict
from typing import Any, Callable

import rsl_rl


def get_param(param: Any, idx: int) -> Any:
"""Get a parameter for the given index.
Expand Down Expand Up @@ -122,7 +125,7 @@ def split_and_pad_trajectories(
# Add at least one full length trajectory
trajectories = (*trajectories, torch.zeros(v.shape[0], *v.shape[2:], device=v.device))
# Pad the trajectories to the length of the longest trajectory
padded_trajectories[k] = torch.nn.utils.rnn.pad_sequence(trajectories)
padded_trajectories[k] = torch.nn.utils.rnn.pad_sequence(trajectories) # type: ignore
# Remove the added trajectory
padded_trajectories[k] = padded_trajectories[k][:, :-1]
padded_trajectories = TensorDict(
Expand All @@ -134,7 +137,7 @@ def split_and_pad_trajectories(
# Add at least one full length trajectory
trajectories = (*trajectories, torch.zeros(tensor.shape[0], *tensor.shape[2:], device=tensor.device))
# Pad the trajectories to the length of the longest trajectory
padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories) # type: ignore
# Remove the added trajectory
padded_trajectories = padded_trajectories[:, :-1]
# Create masks for the valid parts of the trajectories
Expand All @@ -152,34 +155,84 @@ def unpad_trajectories(trajectories: torch.Tensor | TensorDict, masks: torch.Ten
)


def string_to_callable(name: str) -> Callable:
"""Resolve the module and function names to return the function.
def resolve_callable(callable_or_name: type | Callable | str) -> Callable:
"""Resolve a callable from a string, type, or return callable directly.

This function enables passing custom classes or functions directly or as strings. The following formats are
supported:
- Direct callable: Pass a type or function directly (e.g., MyClass, my_func)
- Qualified name with colon: "module.path:Attr.Nested" (explicit, recommended)
- Qualified name with dot: "module.path.ClassName" (implicit)
- Simple name: e.g. "PPO", "ActorCritic", ... (looks for callable in rsl_rl)

Args:
name: Function name. The format should be 'module:attribute_name'.
callable_or_name: A callable (type/function) or string name.

Returns:
The function loaded from the module.
The resolved callable.

Raises:
ValueError: When the resolved attribute is not a function.
ValueError: When unable to resolve the attribute.
TypeError: If input is neither a callable nor a string.
ImportError: If the module cannot be imported.
AttributeError: If the attribute cannot be found in the module.
ValueError: If a simple name cannot be found in rsl_rl packages.
"""
try:
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
callable_object = getattr(mod, attr_name)
# Check if attribute is callable
if callable(callable_object):
return callable_object
# Already a callable - return directly
if callable(callable_or_name):
return callable_or_name

# Must be a string at this point
if not isinstance(callable_or_name, str):
raise TypeError(f"Expected callable or string, got {type(callable_or_name)}")

# Handle qualified name with colon separator (e.g., "module.path:Attr.Nested")
if ":" in callable_or_name:
module_path, attr_path = callable_or_name.rsplit(":", 1)
# Try to import the module
module = importlib.import_module(module_path)
# Try to get the attribute
obj = module
for attr in attr_path.split("."):
obj = getattr(obj, attr)
return obj # type: ignore

# Handle qualified name with dot separator (e.g., "module.path.ClassName")
if "." in callable_or_name:
parts = callable_or_name.split(".")
module_found = False
for i in range(len(parts) - 1, 0, -1):
# Try to import the module with the first i parts
module_path = ".".join(parts[:i])
attr_parts = parts[i:]
try:
module = importlib.import_module(module_path)
except ModuleNotFoundError:
continue
module_found = True
# Once a module is found, try to get the attribute
obj = module
try:
for attr in attr_parts:
obj = getattr(obj, attr)
return obj # type: ignore
except AttributeError:
continue
if module_found:
raise AttributeError(f"Could not resolve '{callable_or_name}': attribute not found in module")
else:
raise ValueError(f"The imported object is not callable: '{name}'")
except AttributeError as err:
msg = (
"We could not interpret the entry as a callable object. The format of input should be"
f" 'module:attribute_name'\nWhile processing input '{name}'."
)
raise ValueError(msg) from err
raise ImportError(f"Could not resolve '{callable_or_name}': no valid module.attr split found")

# Simple name - look for it in rsl_rl
for _, module_name, _ in pkgutil.iter_modules(rsl_rl.__path__, "rsl_rl."):
module = importlib.import_module(module_name)
if hasattr(module, callable_or_name):
return getattr(module, callable_or_name)

# Raise error if no approach worked
raise ValueError(
f"Could not resolve '{callable_or_name}'. Use qualified name like 'module.path:ClassName' "
f"or pass the class directly."
)


def resolve_obs_groups(
Expand Down
6 changes: 6 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Tests for rsl_rl."""
6 changes: 6 additions & 0 deletions tests/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Tests for the utils module of rsl_rl."""
Loading