diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e9f3416b12..937853979c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -701,6 +701,7 @@ New Features: - Added checkpoints for replay buffer and ``VecNormalize`` statistics (@anand-bala) - Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio) - The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys +- Use MacOS Metal "mps" device when available `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -758,6 +759,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Save cloudpickle version `SB3-Contrib`_ diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 7caef05018..a5a5476b97 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -30,8 +30,8 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None: """ Seed the different random generators. - :param seed: - :param using_cuda: + :param seed: Seed + :param using_cuda: Whether CUDA is currently used """ # Seed python RNG random.seed(seed) @@ -141,19 +141,20 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: """ Retrieve PyTorch device. It checks that the requested device is available first. - For now, it supports only cpu and cuda. - By default, it tries to use the gpu. + For now, it supports only CPU and CUDA. + By default, it tries to use the GPU. - :param device: One for 'auto', 'cuda', 'cpu' + :param device: One of "auto", "cuda", "cpu", + or any PyTorch supported device (for instance "mps") :return: Supported Pytorch device """ - # Cuda by default + # MPS/CUDA by default if device == "auto": - device = "cuda" + device = get_available_accelerator() # Force conversion to th.device device = th.device(device) - # Cuda not available + # CUDA not available if device.type == th.device("cuda").type and not th.cuda.is_available(): return th.device("cpu") @@ -518,6 +519,20 @@ def should_collect_more_steps( ) +def get_available_accelerator() -> str: + """ + Return the available accelerator + (currently checking only for CUDA and MPS device) + """ + if hasattr(th, "backends") and th.backends.mps.is_built(): + # MacOS Metal GPU + return "mps" + elif th.cuda.is_available(): + return "cuda" + else: + return "cpu" + + def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]: """ Retrieve system and python env info for the current system. @@ -533,7 +548,7 @@ def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]: "Python": platform.python_version(), "Stable-Baselines3": sb3.__version__, "PyTorch": th.__version__, - "GPU Enabled": str(th.cuda.is_available()), + "Accelerator": get_available_accelerator(), "Numpy": np.__version__, "Cloudpickle": cloudpickle.__version__, "Gymnasium": gym.__version__, diff --git a/tests/test_utils.py b/tests/test_utils.py index bb2ebd0676..1280f57726 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -445,9 +445,10 @@ def test_get_system_info(): assert info["Stable-Baselines3"] == str(sb3.__version__) assert "Python" in info_str assert "PyTorch" in info_str - assert "GPU Enabled" in info_str + assert "Accelerator" in info_str assert "Numpy" in info_str assert "Gym" in info_str + assert "Cloudpickle" in info_str def test_is_vectorized_observation():