diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py
index 718571f0c8..63e37b89bd 100644
--- a/stable_baselines3/a2c/a2c.py
+++ b/stable_baselines3/a2c/a2c.py
@@ -155,8 +155,12 @@ def train(self) -> None:
             if self.normalize_advantage:
                 advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
 
-            # Policy gradient loss
+            # Policy gradient
+            add_loss = None
             policy_loss = -(advantages * log_prob).mean()
+            if self.has_additional_loss:
+                add_loss = self._calculate_additional_loss(rollout_data.observations, log_prob).mean()
+                policy_loss += add_loss
 
             # Value loss using the TD(gae_lambda) target
             value_loss = F.mse_loss(rollout_data.returns, values)
@@ -188,6 +192,8 @@ def train(self) -> None:
         self.logger.record("train/value_loss", value_loss.item())
         if hasattr(self.policy, "log_std"):
             self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
+        if add_loss is not None:
+            self.logger.record(f"train/{self.additional_loss_name}", add_loss.item())
 
     def learn(
         self: SelfA2C,
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index e43955f94c..29e93200c9 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -6,7 +6,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import deque
-from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
 
 import gymnasium as gym
 import numpy as np
@@ -22,7 +22,14 @@
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
 from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
-from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
+from stable_baselines3.common.type_aliases import (
+    GymEnv,
+    MaybeCallback,
+    ReplayBufferSamples,
+    RolloutBufferSamples,
+    Schedule,
+    TensorDict,
+)
 from stable_baselines3.common.utils import (
     check_for_correct_spaces,
     get_device,
@@ -199,6 +206,9 @@ def __init__(
                 np.isfinite(np.array([self.action_space.low, self.action_space.high]))
             ), "Continuous action space must have a finite lower and upper bound"
 
+        # Initialize the additional-loss attributes (disabled by default)
+        self.remove_additional_loss()
+
     @staticmethod
     def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
         """ "
@@ -864,3 +874,20 @@ def save(
         params_to_save = self.get_parameters()
 
         save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
+
+    def set_additional_loss(
+        self,
+        loss_fn: Callable[[th.Tensor, th.Tensor], th.Tensor],
+        name: str,
+    ) -> None:
+        self.has_additional_loss = True
+        self.additional_loss_func = loss_fn
+        self.additional_loss_name = name if name.endswith("_loss") else f"{name}_loss"
+
+    def remove_additional_loss(self) -> None:
+        self.has_additional_loss = False
+        self.additional_loss_func = None
+        self.additional_loss_name = None
+
+    def _calculate_additional_loss(self, observations: th.Tensor, logits: th.Tensor) -> th.Tensor:
+        return self.additional_loss_func(observations, logits) if self.has_additional_loss else th.tensor(0.0)
diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py
index c7841866b4..b6c4b39103 100644
--- a/stable_baselines3/common/callbacks.py
+++ b/stable_baselines3/common/callbacks.py
@@ -5,8 +5,9 @@
 
 import gymnasium as gym
 import numpy as np
-
+import torch as th
 from stable_baselines3.common.logger import Logger
+from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples
 
 try:
     from tqdm import TqdmExperimentalWarning
@@ -125,6 +126,19 @@ def on_rollout_end(self) -> None:
     def _on_rollout_end(self) -> None:
         pass
 
+    def on_update_loss(
+        self,
+        samples: Union[RolloutBufferSamples, ReplayBufferSamples],
+    ) -> th.Tensor:
+        self.is_rollout_buffer = isinstance(samples, RolloutBufferSamples)
+        return self._on_update_loss(samples)
+
+    def _on_update_loss(
+        self,
+        samples: Union[RolloutBufferSamples, ReplayBufferSamples],
+    ) -> th.Tensor:
+        pass
+
     def update_locals(self, locals_: Dict[str, Any]) -> None:
         """
         Update the references to the local variables.
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 262453721d..451f5ac2b4 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -242,6 +242,7 @@ def collect_rollouts(
 
         callback.on_rollout_end()
 
+        return True
 
     def train(self) -> None:
diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py
index 52ee2eb64c..3700857749 100644
--- a/stable_baselines3/ppo/ppo.py
+++ b/stable_baselines3/ppo/ppo.py
@@ -198,6 +198,7 @@ def train(self) -> None:
         entropy_losses = []
         pg_losses, value_losses = [], []
         clip_fractions = []
+        additional_losses = []
 
         continue_training = True
         # train for n_epochs epochs
@@ -252,8 +253,12 @@ def train(self) -> None:
                     entropy_loss = -th.mean(entropy)
 
                 entropy_losses.append(entropy_loss.item())
-
+                add_loss = None
                 loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
+                if self.has_additional_loss:
+                    add_loss = self._calculate_additional_loss(rollout_data.observations, log_prob).mean()
+                    loss += add_loss
+                    additional_losses.append(add_loss.item())
 
                 # Calculate approximate form of reverse KL Divergence for early stopping
                 # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
@@ -299,6 +304,9 @@ def train(self) -> None:
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)
 
+        if len(additional_losses) > 0:
+            self.logger.record(f"train/{self.additional_loss_name}", np.mean(additional_losses))
+
     def learn(
         self: SelfPPO,
         total_timesteps: int,
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index bf0fa50282..4307c29d6b 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -209,6 +209,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
 
         ent_coef_losses, ent_coefs = [], []
         actor_losses, critic_losses = [], []
+        additional_losses = []
 
         for gradient_step in range(gradient_steps):
             # Sample replay buffer
@@ -270,9 +271,14 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
             # Compute actor loss
             # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
             # Min over all critic networks
+            add_loss = None
             q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
             min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
             actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
+            if self.has_additional_loss:
+                add_loss = self._calculate_additional_loss(replay_data.observations, actions_pi).mean()
+                actor_loss += add_loss
+                additional_losses.append(add_loss.item())
             actor_losses.append(actor_loss.item())
 
             # Optimize the actor
@@ -294,6 +300,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         self.logger.record("train/critic_loss", np.mean(critic_losses))
         if len(ent_coef_losses) > 0:
             self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
+        if len(additional_losses) > 0:
+            self.logger.record(f"train/{self.additional_loss_name}", np.mean(additional_losses))
 
     def learn(
         self: SelfSAC,
diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py
index a61d954bc5..8b23b25d0d 100644
--- a/stable_baselines3/td3/td3.py
+++ b/stable_baselines3/td3/td3.py
@@ -159,6 +159,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
         self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])
 
         actor_losses, critic_losses = [], []
+        additional_losses = []
         for _ in range(gradient_steps):
             self._n_updates += 1
             # Sample replay buffer
@@ -191,7 +192,13 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
             # Delayed policy updates
             if self._n_updates % self.policy_delay == 0:
                 # Compute actor loss
-                actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
+                add_loss = None
+                logits = self.actor(replay_data.observations)
+                actor_loss = -self.critic.q1_forward(replay_data.observations, logits).mean()
+                if self.has_additional_loss:
+                    add_loss = self._calculate_additional_loss(replay_data.observations, logits).mean()
+                    actor_loss += add_loss
+                    additional_losses.append(add_loss.item())
                 actor_losses.append(actor_loss.item())
 
                 # Optimize the actor
@@ -209,6 +216,8 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
         if len(actor_losses) > 0:
            self.logger.record("train/actor_loss", np.mean(actor_losses))
         self.logger.record("train/critic_loss", np.mean(critic_losses))
+        if len(additional_losses) > 0:
+            self.logger.record(f"train/{self.additional_loss_name}", np.mean(additional_losses))
 
     def learn(
         self: SelfTD3,
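
A minimal usage sketch of the hook introduced by this patch (the penalty function and its 0.01 coefficient are illustrative, not part of the diff): the callable registered via set_additional_loss receives the batch observations and the policy output each algorithm forwards to _calculate_additional_loss (log_prob for A2C/PPO, the actor output for SAC/TD3), its mean is added to the policy/actor loss, and the value is logged under train/<name>_loss.

import torch as th

from stable_baselines3 import PPO


def log_prob_penalty(observations: th.Tensor, log_prob: th.Tensor) -> th.Tensor:
    # Illustrative regularizer: penalize large-magnitude log-probabilities.
    # The algorithm takes the mean of this tensor before adding it to the loss.
    return 0.01 * log_prob.pow(2)


model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
model.set_additional_loss(log_prob_penalty, name="log_prob_penalty")
model.learn(total_timesteps=10_000)  # extra term is logged as train/log_prob_penalty_loss
model.remove_additional_loss()       # disable the extra term again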