KL Adaptive LR and SimBa policy for PPO #67

Draft · wants to merge 16 commits into base: master
2 changes: 1 addition & 1 deletion sbx/common/off_policy_algorithm.py
@@ -89,7 +89,7 @@ def _maybe_reset_params(self) -> None:
):
# Note: we are not resetting the entropy coeff
assert isinstance(self.qf_learning_rate, float)
self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) # type: ignore[operator]
self.reset_idx += 1

def _get_torch_save_params(self):
30 changes: 24 additions & 6 deletions sbx/common/on_policy_algorithm.py
@@ -3,6 +3,7 @@
import gymnasium as gym
import jax
import numpy as np
import optax
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import RolloutBuffer
@@ -75,6 +76,27 @@ def _excluded_save_params(self) -> list[str]:
excluded.remove("policy")
return excluded

def _update_learning_rate( # type: ignore[override]
self,
optimizers: Union[list[optax.OptState], optax.OptState],
learning_rate: float,
) -> None:
"""
Update the optimizers' learning rate to the given value, typically computed
from the current learning rate schedule (progress remaining from 1 to 0).

:param optimizers:
An optimizer or a list of optimizers.
"""
# Log the current learning rate
self.logger.record("train/learning_rate", learning_rate)

if not isinstance(optimizers, list):
optimizers = [optimizers]
for optimizer in optimizers:
# Note: the optimizer must have been defined with inject_hyperparams
optimizer.hyperparams["learning_rate"] = learning_rate
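
For context, the pattern above relies on optax.inject_hyperparams, which moves the learning rate into the optimizer state so it can be overwritten between updates. A minimal standalone sketch (toy parameters, not the PR's training code):

import jax.numpy as jnp
import optax

# Wrap the optimizer factory so "learning_rate" becomes a mutable entry of the
# optimizer state instead of a constant baked into the update rule.
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=3e-4)

params = {"w": jnp.zeros(3)}
opt_state = optimizer.init(params)

# The injected hyperparameters live in a plain dict on the state, so a scheduler
# can overwrite the value before the next update.
opt_state.hyperparams["learning_rate"] = 1e-4

grads = {"w": jnp.ones(3)}
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)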

def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override]
super().set_random_seed(seed)
if seed is None:
@@ -167,12 +189,8 @@ def collect_rollouts(

# Handle timeout by bootstrapping with value function
# see GitHub issue #633
for idx, done in enumerate(dones):
if (
done
and infos[idx].get("terminal_observation") is not None
and infos[idx].get("TimeLimit.truncated", False)
):
for idx in dones.nonzero()[0]:
if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False):
terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0]
terminal_value = np.array(
self.vf.apply( # type: ignore[union-attr]
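For reference, the nonzero() refactor above iterates only over the environments that finished at this step; a small numpy illustration:

import numpy as np

dones = np.array([False, True, False, True])
print(dones.nonzero()[0])  # [1 3]: only the indices of finished envs are visited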
16 changes: 14 additions & 2 deletions sbx/common/policies.py
@@ -242,6 +242,8 @@ class SquashedGaussianActor(nn.Module):
log_std_min: float = -20
log_std_max: float = 2
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
ortho_init: bool = False
log_std_init: float = -1.2 # log(0.3)

def get_std(self):
# Make it work with gSDE
@@ -253,8 +255,18 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
x = self.activation_fn(x)
mean = nn.Dense(self.action_dim)(x)
log_std = nn.Dense(self.action_dim)(x)

if self.ortho_init:
orthogonal_init = nn.initializers.orthogonal(scale=0.01)
# orthogonal_init = nn.initializers.uniform(scale=0.01)
# orthogonal_init = nn.initializers.normal(stddev=0.01)
bias_init = nn.initializers.zeros
mean = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x)
log_std = self.param("log_std", nn.initializers.constant(self.log_std_init), (self.action_dim,))
# log_std = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x)
else:
mean = nn.Dense(self.action_dim)(x)
log_std = nn.Dense(self.action_dim)(x)
log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
dist = TanhTransformedDistribution(
tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
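With ortho_init enabled above, the mean head starts close to zero and the standard deviation is a state-independent parameter initialized at exp(log_std_init) = exp(-1.2) ≈ 0.30. A small, hedged check of those defaults (standalone, using flax directly):

import jax
import jax.numpy as jnp
import flax.linen as nn

# Last layer with the small-scale orthogonal initializer used in the diff.
layer = nn.Dense(
    4,
    kernel_init=nn.initializers.orthogonal(scale=0.01),
    bias_init=nn.initializers.zeros,
)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 64)))
print(jnp.abs(params["params"]["kernel"]).max())  # small, since the kernel is scaled by 0.01
print(jnp.exp(-1.2))  # ~0.301, the initial action std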
26 changes: 26 additions & 0 deletions sbx/common/utils.py
@@ -0,0 +1,26 @@
from dataclasses import dataclass

import numpy as np


@dataclass
class KLAdaptiveLR:
"""Adaptive lr schedule, see https://arxiv.org/abs/1707.02286"""

# If set, will trigger the adaptive lr
target_kl: float
current_adaptive_lr: float
# Values taken from https://github.com/leggedrobotics/rsl_rl
min_learning_rate: float = 1e-5
max_learning_rate: float = 1e-2
kl_margin: float = 2.0
# Divide or multiply the lr by this factor
adaptive_lr_factor: float = 1.5

def update(self, kl_div: float) -> None:
if kl_div > self.target_kl * self.kl_margin:
self.current_adaptive_lr /= self.adaptive_lr_factor
elif kl_div < self.target_kl / self.kl_margin:
self.current_adaptive_lr *= self.adaptive_lr_factor

self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate)
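
A short usage sketch of the scheduler above (the target and starting values are illustrative; the import assumes this PR's new sbx/common/utils.py module):

from sbx.common.utils import KLAdaptiveLR

scheduler = KLAdaptiveLR(target_kl=0.01, current_adaptive_lr=3e-4)
scheduler.update(kl_div=0.05)   # KL above target_kl * kl_margin -> lr divided by 1.5 (2e-4)
scheduler.update(kl_div=0.001)  # KL below target_kl / kl_margin -> lr multiplied by 1.5 (back to 3e-4)
scheduler.update(kl_div=0.01)   # KL inside the margin band -> lr unchanged
print(scheduler.current_adaptive_lr)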
140 changes: 129 additions & 11 deletions sbx/ppo/policies.py
@@ -14,6 +14,7 @@
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.jax_layers import SimbaResidualBlock
from sbx.common.policies import BaseJaxPolicy, Flatten

tfd = tfp.distributions
@@ -34,6 +35,23 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
return x


class SimbaCritic(nn.Module):
net_arch: Sequence[int]
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = nn.Dense(self.net_arch[0])(x)
for n_units in self.net_arch:
x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x)

x = nn.LayerNorm()(x)
x = nn.Dense(1)(x)
return x
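
SimbaResidualBlock is imported from sbx.common.jax_layers and is not part of this diff. For orientation only, a SimBa-style block is typically a pre-LayerNorm MLP with an expansion factor and a skip connection; the sketch below is an assumption about its shape, not the actual sbx implementation:

from typing import Callable

import flax.linen as nn
import jax.numpy as jnp


class ResidualBlockSketch(nn.Module):
    # Hypothetical stand-in for sbx.common.jax_layers.SimbaResidualBlock.
    hidden_dim: int
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    scale_factor: int = 4

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # The input must already have hidden_dim features, which is why the
        # networks above apply nn.Dense(net_arch[0]) before the first block.
        residual = x
        x = nn.LayerNorm()(x)
        x = nn.Dense(self.hidden_dim * self.scale_factor)(x)
        x = self.activation_fn(x)
        x = nn.Dense(self.hidden_dim)(x)
        return residual + x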


class Actor(nn.Module):
action_dim: int
net_arch: Sequence[int]
@@ -44,6 +62,8 @@ class Actor(nn.Module):
# For MultiDiscrete
max_num_choices: int = 0
split_indices: np.ndarray = field(default_factory=lambda: np.array([]))
# Last layer with small scale
ortho_init: bool = False

def get_std(self) -> jnp.ndarray:
# Make it work with gSDE
@@ -65,7 +85,15 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def
x = nn.Dense(n_units)(x)
x = self.activation_fn(x)

action_logits = nn.Dense(self.action_dim)(x)
if self.ortho_init:
orthogonal_init = nn.initializers.orthogonal(scale=0.01)
bias_init = nn.initializers.zeros
action_logits = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x)

else:
action_logits = nn.Dense(self.action_dim)(x)

log_std = jnp.zeros(1)
if self.num_discrete_choices is None:
# Continuous actions
log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,))
@@ -97,6 +125,47 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def
return dist


class SimbaActor(nn.Module):
action_dim: int
net_arch: Sequence[int]
log_std_init: float = 0.0
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
# For Discrete, MultiDiscrete and MultiBinary actions
num_discrete_choices: Optional[Union[int, Sequence[int]]] = None
# For MultiDiscrete
max_num_choices: int = 0
# Last layer with small scale
ortho_init: bool = False
scale_factor: int = 4

def get_std(self) -> jnp.ndarray:
# Make it work with gSDE
return jnp.array(0.0)

@nn.compact
def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined]
x = Flatten()(x)

x = nn.Dense(self.net_arch[0])(x)
for n_units in self.net_arch:
x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x)
x = nn.LayerNorm()(x)

if self.ortho_init:
orthogonal_init = nn.initializers.orthogonal(scale=0.01)
bias_init = nn.initializers.zeros
mean_action = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x)

else:
mean_action = nn.Dense(self.action_dim)(x)

# Continuous actions
log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,))
dist = tfd.MultivariateNormalDiag(loc=mean_action, scale_diag=jnp.exp(log_std))

return dist


class PPOPolicy(BaseJaxPolicy):
def __init__(
self,
@@ -118,6 +187,8 @@ def __init__(
optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
share_features_extractor: bool = False,
actor_class: type[nn.Module] = Actor,
critic_class: type[nn.Module] = Critic,
):
if optimizer_kwargs is None:
# Small values to avoid NaN in Adam optimizer
@@ -146,6 +217,9 @@ def __init__(
else:
self.net_arch_pi = self.net_arch_vf = [64, 64]
self.use_sde = use_sde
self.ortho_init = ortho_init
self.actor_class = actor_class
self.critic_class = critic_class

self.key = self.noise_key = jax.random.PRNGKey(0)

@@ -189,38 +263,38 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
else:
raise NotImplementedError(f"{self.action_space}")

self.actor = Actor(
self.actor = self.actor_class(
net_arch=self.net_arch_pi,
log_std_init=self.log_std_init,
activation_fn=self.activation_fn,
ortho_init=self.ortho_init,
**actor_kwargs, # type: ignore[arg-type]
)
# Hack to make gSDE work without modifying internal SB3 code
self.actor.reset_noise = self.reset_noise

# Inject hyperparameters to be able to modify the learning rate later
# See https://stackoverflow.com/questions/78527164
# Note: eps=1e-5 for Adam
optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs)

self.actor_state = TrainState.create(
apply_fn=self.actor.apply,
params=self.actor.init(actor_key, obs),
tx=optax.chain(
optax.clip_by_global_norm(max_grad_norm),
self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs, # , eps=1e-5
),
optimizer_class,
),
)

self.vf = Critic(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)
self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)

self.vf_state = TrainState.create(
apply_fn=self.vf.apply,
params=self.vf.init({"params": vf_key}, obs),
tx=optax.chain(
optax.clip_by_global_norm(max_grad_norm),
self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs, # , eps=1e-5
),
optimizer_class,
),
)
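
Because the injected optimizer is the second element of the optax.chain above, its mutable hyperparameters end up at index 1 of the chained optimizer state. A standalone, hedged sketch of how a caller could reach the learning rate (the exact wiring into _update_learning_rate is not shown in this diff):

import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

tx = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.inject_hyperparams(optax.adam)(learning_rate=3e-4),
)
state = TrainState.create(apply_fn=lambda p, x: x, params={"w": jnp.zeros(3)}, tx=tx)
# The chain state is a tuple ordered like the chain itself, so the injected
# hyperparameters sit at index 1.
state.opt_state[1].hyperparams["learning_rate"] = 1e-4  # used by the next update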

@@ -257,3 +331,47 @@ def _predict_all(actor_state, vf_state, observations, key):
log_probs = dist.log_prob(actions)
values = vf_state.apply_fn(vf_state.params, observations).flatten()
return actions, log_probs, values


class SimbaPPOPolicy(PPOPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
ortho_init: bool = False,
log_std_init: float = 0,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh,
use_sde: bool = False,
use_expln: bool = False,
clip_mean: float = 2,
features_extractor_class=None,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
share_features_extractor: bool = False,
actor_class: type[nn.Module] = SimbaActor,
critic_class: type[nn.Module] = SimbaCritic,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
ortho_init,
log_std_init,
activation_fn,
use_sde,
use_expln,
clip_mean,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
share_features_extractor,
actor_class,
critic_class,
)
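
Finally, a hypothetical end-to-end usage of the new policy. Whether a string alias is registered for it is not shown in this diff, so the class is passed directly, and the policy_kwargs values are illustrative:

from sbx import PPO
from sbx.ppo.policies import SimbaPPOPolicy

model = PPO(
    SimbaPPOPolicy,
    "Pendulum-v1",
    policy_kwargs=dict(
        net_arch=dict(pi=[128, 128], vf=[256, 256]),
        ortho_init=True,
    ),
    verbose=1,
)
model.learn(total_timesteps=20_000)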