diff --git a/CITATION.cff b/CITATION.cff index 38eeca1a..a15a4915 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,7 +3,7 @@ title: "RSL-RL: A Learning Library for Robotics Research" message: "If you use this work, please cite the following paper." repository-code: "https://github.com/leggedrobotics/rsl_rl" license: BSD-3-Clause -version: 3.2.0 +version: 3.3.0 type: software authors: - family-names: Schwarke diff --git a/pyproject.toml b/pyproject.toml index 2c555e51..1b53c96d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "rsl-rl-lib" -version = "3.2.0" +version = "3.3.0" keywords = ["reinforcement-learning", "robotics"] maintainers = [ { name="Clemens Schwarke", email="cschwarke@ethz.ch" }, diff --git a/rsl_rl/algorithms/distillation.py b/rsl_rl/algorithms/distillation.py index 60e0f702..45b4d54d 100644 --- a/rsl_rl/algorithms/distillation.py +++ b/rsl_rl/algorithms/distillation.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import torch import torch.nn as nn from tensordict import TensorDict diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py index 1f90456c..dc788d59 100644 --- a/rsl_rl/algorithms/ppo.py +++ b/rsl_rl/algorithms/ppo.py @@ -14,7 +14,7 @@ from rsl_rl.modules import ActorCritic, ActorCriticCNN, ActorCriticRecurrent from rsl_rl.modules.rnd import RandomNetworkDistillation from rsl_rl.storage import RolloutStorage -from rsl_rl.utils import string_to_callable +from rsl_rl.utils import resolve_callable class PPO: @@ -80,9 +80,8 @@ def __init__( # Print that we are not using symmetry if not use_symmetry: print("Symmetry not used for learning. 
We will use it for logging instead.") - # If function is a string then resolve it to a function - if isinstance(symmetry_cfg["data_augmentation_func"], str): - symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"]) + # Resolve the data augmentation function (supports string names or direct callables) + symmetry_cfg["data_augmentation_func"] = resolve_callable(symmetry_cfg["data_augmentation_func"]) # Check valid configuration if not callable(symmetry_cfg["data_augmentation_func"]): raise ValueError( diff --git a/rsl_rl/networks/memory.py b/rsl_rl/networks/memory.py index dc67abed..b0ecc8e6 100644 --- a/rsl_rl/networks/memory.py +++ b/rsl_rl/networks/memory.py @@ -7,10 +7,11 @@ import torch import torch.nn as nn +from typing import Union from rsl_rl.utils import unpad_trajectories -HiddenState = torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None +HiddenState = Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor], None] # Using Union due to Python <3.10 """Type alias for the hidden state of RNNs (GRU/LSTM). For GRUs, this is a single tensor while for LSTMs, this is a tuple of two tensors (hidden state and cell state). 
diff --git a/rsl_rl/runners/distillation_runner.py b/rsl_rl/runners/distillation_runner.py index 441e8f42..c50404b8 100644 --- a/rsl_rl/runners/distillation_runner.py +++ b/rsl_rl/runners/distillation_runner.py @@ -11,6 +11,7 @@ from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent from rsl_rl.runners import OnPolicyRunner from rsl_rl.storage import RolloutStorage +from rsl_rl.utils import resolve_callable class DistillationRunner(OnPolicyRunner): @@ -34,7 +35,7 @@ def _get_default_obs_sets(self) -> list[str]: def _construct_algorithm(self, obs: TensorDict) -> Distillation: """Construct the distillation algorithm.""" # Initialize the policy - student_teacher_class = eval(self.policy_cfg.pop("class_name")) + student_teacher_class = resolve_callable(self.policy_cfg.pop("class_name")) student_teacher: StudentTeacher | StudentTeacherRecurrent = student_teacher_class( obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg ).to(self.device) @@ -45,7 +46,7 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation: ) # Initialize the algorithm - alg_class = eval(self.alg_cfg.pop("class_name")) + alg_class = resolve_callable(self.alg_cfg.pop("class_name")) alg: Distillation = alg_class( student_teacher, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg ) diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 2d9bac6c..eb924733 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -21,7 +21,7 @@ resolve_symmetry_config, ) from rsl_rl.storage import RolloutStorage -from rsl_rl.utils import resolve_obs_groups +from rsl_rl.utils import resolve_callable, resolve_obs_groups from rsl_rl.utils.logger import Logger @@ -267,7 +267,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO: self.policy_cfg["critic_obs_normalization"] = self.cfg["empirical_normalization"] # Initialize the policy - actor_critic_class = 
eval(self.policy_cfg.pop("class_name")) + actor_critic_class = resolve_callable(self.policy_cfg.pop("class_name")) actor_critic: ActorCritic | ActorCriticRecurrent | ActorCriticCNN = actor_critic_class( obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg ).to(self.device) @@ -278,7 +278,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO: ) # Initialize the algorithm - alg_class = eval(self.alg_cfg.pop("class_name")) + alg_class = resolve_callable(self.alg_cfg.pop("class_name")) alg: PPO = alg_class( actor_critic, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg ) diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py index 35aa5d9d..c9cb318f 100644 --- a/rsl_rl/utils/__init__.py +++ b/rsl_rl/utils/__init__.py @@ -7,20 +7,20 @@ from .utils import ( get_param, + resolve_callable, resolve_nn_activation, resolve_obs_groups, resolve_optimizer, split_and_pad_trajectories, - string_to_callable, unpad_trajectories, ) __all__ = [ "get_param", + "resolve_callable", "resolve_nn_activation", "resolve_obs_groups", "resolve_optimizer", "split_and_pad_trajectories", - "string_to_callable", "unpad_trajectories", ] diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py index 43e0a141..b5c50379 100644 --- a/rsl_rl/utils/utils.py +++ b/rsl_rl/utils/utils.py @@ -6,11 +6,14 @@ from __future__ import annotations import importlib +import pkgutil import torch import warnings from tensordict import TensorDict from typing import Any, Callable +import rsl_rl + def get_param(param: Any, idx: int) -> Any: """Get a parameter for the given index. 
def _get_nested_attr(root: Any, attr_names: list[str]) -> Any:
    """Walk a chain of attribute names starting at ``root`` and return the final object.

    Raises:
        AttributeError: If any attribute along the chain does not exist.
    """
    obj = root
    for attr in attr_names:
        obj = getattr(obj, attr)
    return obj


def _ensure_callable(obj: Any, name: str) -> Callable:
    """Return ``obj`` unchanged if it is callable, otherwise raise a ``ValueError`` mentioning ``name``."""
    if not callable(obj):
        raise ValueError(f"The resolved object is not callable: '{name}'")
    return obj


def resolve_callable(callable_or_name: type | Callable | str) -> Callable:
    """Resolve a callable from a string, type, or return callable directly.

    This function enables passing custom classes or functions directly or as strings. The following formats are
    supported:
        - Direct callable: Pass a type or function directly (e.g., MyClass, my_func)
        - Qualified name with colon: "module.path:Attr.Nested" (explicit, recommended)
        - Qualified name with dot: "module.path.ClassName" (implicit)
        - Simple name: e.g. "PPO", "ActorCritic", ... (looks for callable in rsl_rl)

    For the dot format, the longest possible module path is tried first and segments are progressively moved from the
    module path to the attribute path until a split resolves. For simple names, every module reported by
    ``pkgutil.iter_modules`` for the ``rsl_rl`` package is imported and searched in order; the first callable
    attribute with a matching name is returned.

    Args:
        callable_or_name: A callable (type/function) or string name.

    Returns:
        The resolved callable.

    Raises:
        TypeError: If input is neither a callable nor a string.
        ImportError: If the module cannot be imported. This includes a ``ModuleNotFoundError`` raised from inside an
            existing module (e.g. a missing dependency), which is propagated unchanged instead of being masked.
        AttributeError: If the attribute cannot be found in the module.
        ValueError: If a simple name cannot be found in rsl_rl packages, or if the resolved object is not callable.
    """
    # Already a callable - return directly
    if callable(callable_or_name):
        return callable_or_name

    # Must be a string at this point
    if not isinstance(callable_or_name, str):
        raise TypeError(f"Expected callable or string, got {type(callable_or_name)}")

    # Handle qualified name with colon separator (e.g., "module.path:Attr.Nested")
    if ":" in callable_or_name:
        module_path, attr_path = callable_or_name.rsplit(":", 1)
        module = importlib.import_module(module_path)
        obj = _get_nested_attr(module, attr_path.split("."))
        return _ensure_callable(obj, callable_or_name)

    # Handle qualified name with dot separator (e.g., "module.path.ClassName")
    if "." in callable_or_name:
        parts = callable_or_name.split(".")
        module_found = False
        for i in range(len(parts) - 1, 0, -1):
            # Try to import the module with the first i parts
            module_path = ".".join(parts[:i])
            try:
                module = importlib.import_module(module_path)
            except ModuleNotFoundError as err:
                # Only skip this split when the candidate path itself (or one of its parents) does not exist. If an
                # unrelated module is missing, the candidate exists but failed to import (e.g. missing dependency),
                # and hiding that behind a generic resolution error would make the real cause very hard to find.
                missing = err.name
                if missing is not None and not f"{module_path}.".startswith(f"{missing}."):
                    raise
                continue
            module_found = True
            # Once a module is found, try to get the attribute
            try:
                obj = _get_nested_attr(module, parts[i:])
            except AttributeError:
                continue
            return _ensure_callable(obj, callable_or_name)
        if module_found:
            raise AttributeError(f"Could not resolve '{callable_or_name}': attribute not found in module")
        raise ImportError(f"Could not resolve '{callable_or_name}': no valid module.attr split found")

    # Simple name - look for it in rsl_rl
    for _, module_name, _ in pkgutil.iter_modules(rsl_rl.__path__, "rsl_rl."):
        module = importlib.import_module(module_name)
        # Skip non-callable attributes that merely share the name (e.g. re-exported modules or constants)
        candidate = getattr(module, callable_or_name, None)
        if callable(candidate):
            return candidate

    # Raise error if no approach worked
    raise ValueError(
        f"Could not resolve '{callable_or_name}'. Use qualified name like 'module.path:ClassName' "
        f"or pass the class directly."
    )
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Tests for resolve_callable utility function."""

import pytest

from rsl_rl.utils import resolve_callable

# Import path of this very module. Using ``__name__`` instead of a hard-coded dotted path guarantees that
# ``resolve_callable`` gets back the module object pytest already loaded (so the identity assertions below hold),
# regardless of the directory or import mode the test suite is run with.
_THIS_MODULE = __name__


# Test fixtures - nested class for testing nested attribute resolution
class OuterClass:
    """Outer class for testing nested attribute resolution."""

    class InnerClass:
        """Inner nested class."""

        pass

    @staticmethod
    def static_method() -> str:
        """Static method used as a nested callable target."""
        return "static"


def sample_function() -> str:
    """Sample function for testing."""
    return "sample"


class TestResolveCallableDirect:
    """Tests for direct callable passing."""

    def test_direct_class(self) -> None:
        """Passing a class directly should return it unchanged."""
        from rsl_rl.algorithms import PPO

        result = resolve_callable(PPO)
        assert result is PPO

    def test_direct_function(self) -> None:
        """Passing a function directly should return it unchanged."""
        result = resolve_callable(sample_function)
        assert result is sample_function

    def test_direct_builtin(self) -> None:
        """Passing a builtin should return it unchanged."""
        result = resolve_callable(len)
        assert result is len


class TestResolveCallableColonFormat:
    """Tests for colon-separated format 'module:attr'."""

    def test_colon_format_class(self) -> None:
        """Should resolve 'module:Class' format."""
        result = resolve_callable("rsl_rl.algorithms:PPO")
        from rsl_rl.algorithms import PPO

        assert result is PPO

    def test_colon_format_nested(self) -> None:
        """Should resolve 'module:Outer.Inner' nested format."""
        result = resolve_callable(f"{_THIS_MODULE}:OuterClass.InnerClass")
        assert result is OuterClass.InnerClass

    def test_colon_format_static_method(self) -> None:
        """Should resolve nested static methods."""
        result = resolve_callable(f"{_THIS_MODULE}:OuterClass.static_method")
        assert result is OuterClass.static_method

    def test_colon_format_invalid_module(self) -> None:
        """Should raise ImportError for invalid module."""
        with pytest.raises(ImportError):
            resolve_callable("nonexistent_module:SomeClass")

    def test_colon_format_invalid_attr(self) -> None:
        """Should raise AttributeError for invalid attribute."""
        with pytest.raises(AttributeError):
            resolve_callable("rsl_rl.algorithms:NonexistentClass")


class TestResolveCallableDotFormat:
    """Tests for dot-separated format 'module.attr'."""

    def test_dot_format_class(self) -> None:
        """Should resolve 'module.Class' format."""
        result = resolve_callable("rsl_rl.algorithms.PPO")
        from rsl_rl.algorithms import PPO

        assert result is PPO

    def test_dot_format_nested(self) -> None:
        """Should resolve 'module.Outer.Inner' nested format."""
        # This tests the progressive module path splitting
        result = resolve_callable(f"{_THIS_MODULE}.OuterClass.InnerClass")
        assert result is OuterClass.InnerClass

    def test_dot_format_static_method(self) -> None:
        """Should resolve nested static methods."""
        result = resolve_callable(f"{_THIS_MODULE}.OuterClass.static_method")
        assert result is OuterClass.static_method

    def test_dot_format_invalid_module(self) -> None:
        """Should raise ImportError for invalid module."""
        with pytest.raises(ImportError):
            resolve_callable("nonexistent_module.SomeClass")

    def test_dot_format_invalid_attr(self) -> None:
        """Should raise AttributeError for invalid attribute."""
        with pytest.raises(AttributeError):
            resolve_callable("rsl_rl.algorithms.NonexistentClass")


class TestResolveCallableSimpleName:
    """Tests for simple name resolution via rsl_rl packages."""

    def test_simple_name(self) -> None:
        """Should resolve 'PPO' from rsl_rl.algorithms."""
        result = resolve_callable("PPO")
        from rsl_rl.algorithms import PPO

        assert result is PPO

    def test_simple_name_unknown(self) -> None:
        """Should raise ValueError for unknown simple names."""
        with pytest.raises(ValueError, match="Could not resolve"):
            resolve_callable("NonexistentClassName")


class TestResolveCallableErrors:
    """Tests for error handling."""

    def test_type_error_none(self) -> None:
        """Should raise TypeError for None input."""
        with pytest.raises(TypeError, match="Expected callable or string"):
            resolve_callable(None)

    def test_type_error_int(self) -> None:
        """Should raise TypeError for int input."""
        with pytest.raises(TypeError, match="Expected callable or string"):
            resolve_callable(42)

    def test_type_error_list(self) -> None:
        """Should raise TypeError for list input."""
        with pytest.raises(TypeError, match="Expected callable or string"):
            resolve_callable(["PPO"])