Skip to content
1,136 changes: 618 additions & 518 deletions examples/regression.ipynb

Large diffs are not rendered by default.

91 changes: 83 additions & 8 deletions ptmelt/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from ptmelt.layers import MELTBatchNorm
from ptmelt.layers import MELTBatchNorm, MELTBayesianDenseFlipOut
from ptmelt.nn_utils import get_activation, get_initializer


Expand Down Expand Up @@ -38,6 +40,7 @@ def __init__(
batch_norm_type: Optional[str] = "ema",
use_batch_renorm: Optional[bool] = False,
initializer: Optional[str] = "glorot_uniform",
seed: Optional[int] = None,
**kwargs: Any,
):
super(MELTBlock, self).__init__(**kwargs)
Expand All @@ -50,6 +53,7 @@ def __init__(
self.batch_norm_type = batch_norm_type
self.use_batch_renorm = use_batch_renorm
self.initializer = initializer
self.seed = seed

# Get the initializer function
self.initializer_fn = get_initializer(self.initializer)
Expand Down Expand Up @@ -141,6 +145,7 @@ def __init__(
}
)
# Initialize the weights
torch.manual_seed(self.seed) if self.seed is not None else None
[
self.initializer_fn(self.layer_dict[f"dense_{i}"].weight)
for i in range(self.num_layers)
Expand Down Expand Up @@ -211,6 +216,7 @@ def __init__(
}
)
# Initialize the weights
torch.manual_seed(self.seed) if self.seed is not None else None
[
self.initializer_fn(self.layer_dict[f"dense_{i}"].weight)
for i in range(self.num_layers)
Expand Down Expand Up @@ -251,6 +257,60 @@ def forward(self, inputs: torch.Tensor):
return x


class BayesianBlock(MELTBlock):
    """
    MELT block built from Bayesian Flipout dense layers.

    One ``MELTBayesianDenseFlipOut`` layer is registered per entry of
    ``node_list`` under the key ``bayesian_{i}``. The optional batch-norm,
    activation and dropout layers are looked up in ``layer_dict`` under
    ``batch_norm_{i}``, ``activation_{i}`` and ``dropout_{i}``.

    Args:
        num_points: Accepted for interface compatibility; currently unused.
        perturbation_type: Perturbation mode forwarded to every Bayesian layer.
        seed: Seed forwarded to every Bayesian layer.
        **kwargs: Passed through to ``MELTBlock``.
    """

    def __init__(
        self,
        num_points,
        perturbation_type="multiplicative",
        seed: Optional[int] = None,
        **kwargs: Any,
    ):
        super(BayesianBlock, self).__init__(**kwargs)
        # self.num_points = num_points

        self.perturbation_type = perturbation_type
        self.seed = seed

        # Build the Bayesian layers one at a time, then register them together.
        bayesian_layers = {}
        for idx in range(self.num_layers):
            # The first layer consumes the block input; later layers consume
            # the width of the previous layer.
            if idx == 0:
                fan_in = self.input_features
            else:
                fan_in = self.node_list[idx - 1]
            bayesian_layers[f"bayesian_{idx}"] = MELTBayesianDenseFlipOut(
                in_features=fan_in,
                out_features=self.node_list[idx],
                perturbation_type=self.perturbation_type,
                seed=self.seed,
            )
        self.layer_dict.update(bayesian_layers)

    def forward(self, inputs: torch.Tensor):
        """Run the inputs through every layer group of the Bayesian block."""
        hidden = inputs

        for idx in range(self.num_layers):
            # bayesian -> batch norm -> activation -> dropout
            hidden = self.layer_dict[f"bayesian_{idx}"](hidden)
            if self.batch_norm:
                hidden = self.layer_dict[f"batch_norm_{idx}"](hidden)
            if self.activation:
                hidden = self.layer_dict[f"activation_{idx}"](hidden)
            if self.dropout > 0:
                hidden = self.layer_dict[f"dropout_{idx}"](hidden)

        return hidden

    def kl_loss(self):
        """Return the KL divergence summed over all Bayesian layers."""
        return sum(
            self.layer_dict[f"bayesian_{idx}"]._kl_divergence()
            for idx in range(self.num_layers)
        )


class DefaultOutput(nn.Module):
"""
Default output layer with a single dense layer and optional activation function.
Expand All @@ -269,6 +329,8 @@ def __init__(
output_features: int,
activation: Optional[str] = "linear",
initializer: Optional[str] = "glorot_uniform",
do_bayesian: Optional[bool] = False,
seed: Optional[int] = None,
**kwargs: Any,
):
super(DefaultOutput, self).__init__(**kwargs)
Expand All @@ -277,16 +339,25 @@ def __init__(
self.output_features = output_features
self.activation = activation
self.initializer = initializer
self.seed = seed

# Get the initializer function
self.initializer_fn = get_initializer(self.initializer)

# Initialize output layer
self.output_layer = nn.Linear(
in_features=self.input_features, out_features=self.output_features
)
# Initialize the weights
self.initializer_fn(self.output_layer.weight)
if do_bayesian:
self.output_layer = MELTBayesianDenseFlipOut(
in_features=self.input_features,
out_features=self.output_features,
seed=self.seed,
)
else:
self.output_layer = nn.Linear(
in_features=self.input_features, out_features=self.output_features
)
# Initialize the weights
torch.manual_seed(self.seed) if self.seed is not None else None
self.initializer_fn(self.output_layer.weight)

# Initialize activation layer
self.activation_layer = get_activation(self.activation)
Expand Down Expand Up @@ -321,6 +392,7 @@ def __init__(
num_outputs: int,
activation: Optional[str] = "linear",
initializer: Optional[str] = "glorot_uniform",
seed: Optional[int] = None,
**kwargs: Any,
):
super(MixtureDensityOutput, self).__init__(**kwargs)
Expand All @@ -330,6 +402,7 @@ def __init__(
self.num_outputs = num_outputs
self.activation = activation
self.initializer = initializer
self.seed = seed

# Get the initializer function
self.initializer_fn = get_initializer(self.initializer)
Expand All @@ -348,6 +421,7 @@ def __init__(
)

# Initialize the weights
torch.manual_seed(self.seed) if self.seed is not None else None
self.initializer_fn(self.mix_coeffs_layer.weight)
self.initializer_fn(self.mean_layer.weight)
self.initializer_fn(self.log_var_layer.weight)
Expand All @@ -359,13 +433,14 @@ def __init__(
def forward(self, inputs: torch.Tensor):
    """Perform the forward pass of the multiple mixture output layer.

    Args:
        inputs (torch.Tensor): Features fed to the three output heads.

    Returns:
        torch.Tensor: Mixture coefficients, means and log-variances
        concatenated along the last dimension, in that order.
    """
    # Bound the coefficient logits to [-10, 10] before normalising them.
    mix_coeffs = self.mix_coeffs_layer(inputs)
    mix_coeffs = torch.clamp(mix_coeffs, min=-10, max=10)
    mix_coeffs = self.softmax_layer(mix_coeffs)

    # The mean head is returned without an activation.
    # TODO: Do we ever want to apply an activation function to the mean?
    mean = self.mean_layer(inputs)

    log_var = self.log_var_layer(inputs)
    # NOTE(review): the configured activation is still applied to the
    # log-variance head while the mean head is left raw -- confirm intended.
    log_var = self.activation_layer(log_var)

    # return concatenated output
    return torch.cat([mix_coeffs, mean, log_var], dim=-1)
113 changes: 113 additions & 0 deletions ptmelt/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,119 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


class MELTBayesianDenseFlipOut(nn.Module):
    """
    Custom Bayesian dense layer for PT-MELT using Flipout weight perturbations.

    The layer learns a mean (``*_mu``) and a pre-softplus scale (``*_rho``) for
    its weights and biases, and draws fresh noise on every forward pass.
    """

    # Perturbation modes understood by ``forward``.
    _PERTURBATION_TYPES = ("additive", "multiplicative")

    def __init__(
        self,
        in_features: int,
        out_features: int,
        prior_mean: float = 0.0,
        prior_std: float = 10.0,
        perturbation_type: str = "additive",
        seed: Optional[int] = None,
    ):
        """
        Initialize the Bayesian layer using a Dense Flipout type approach.
        Implementation based on the paper:
        "Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches"
        by Wen et al. (2018).
        https://doi.org/10.48550/arXiv.1803.04386

        Perturbations are of the form: W = W_mu + delta_W where delta_W can take on
        different forms depending on the perturbation type.
        Additive perturbations are formulated like: W = W_mu + W_sigma * epsilon
        Multiplicative perturbations are formulated like: W = W_mu * (1 + W_sigma * epsilon)

        Args:
            in_features (int): Size of each input sample.
            out_features (int): Size of each output sample.
            prior_mean (float): Mean of the Normal prior used in the KL term.
            prior_std (float): Standard deviation of the Normal prior.
            perturbation_type (str): Either "additive" or "multiplicative".
            seed (Optional[int]): If given, ``torch.manual_seed(seed)`` is called
                before the parameters are initialized. This reseeds the global
                torch RNG.

        Raises:
            ValueError: If ``perturbation_type`` is not a supported mode.
        """
        super(MELTBayesianDenseFlipOut, self).__init__()

        # Fail at construction time instead of with an unbound local in forward().
        if perturbation_type not in self._PERTURBATION_TYPES:
            raise ValueError(
                f"Unknown perturbation_type {perturbation_type!r}; "
                f"expected one of {self._PERTURBATION_TYPES}"
            )

        self.in_features = in_features
        self.out_features = out_features
        self.prior_mean = prior_mean
        self.prior_std = prior_std
        self.perturbation_type = perturbation_type
        self.seed = seed

        # Initialize learnable parameters for the posterior (zeros? or random?)
        self.weight_mu = nn.Parameter(torch.zeros(out_features, in_features))
        self.weight_rho = nn.Parameter(torch.zeros(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_rho = nn.Parameter(torch.zeros(out_features))

        # Initialize the parameters
        if self.seed is not None:
            torch.manual_seed(self.seed)
        nn.init.xavier_uniform_(self.weight_mu)
        nn.init.zeros_(self.bias_mu)
        # softplus(-3) ~= 0.049, i.e. a small initial posterior scale.
        nn.init.constant_(self.weight_rho, -3.0)
        nn.init.constant_(self.bias_rho, -3.0)

        # Define prior distributions
        self.prior = Normal(self.prior_mean, self.prior_std)
        self.posterior_weight = None
        self.posterior_bias = None

    def forward(self, input: torch.Tensor):
        """
        Perform the forward pass of the Bayesian Linear Layer.

        Flipout weights are perturbed like: W = W_mu + delta_W
        delta_W has a component shared across the entire mini-batch and a component
        that is unique to each input sample (random row/column sign flips).

        The per-sample weights are never materialised. Because
        ``delta_W[b] = delta_W * outer(row_sign[b], col_sign[b])``, the output is

            input @ W_mu.T + ((input * col_sign) @ delta_W.T) * row_sign + bias

        which needs two ordinary matrix products.

        Args:
            input (torch.Tensor): Batch of shape ``(batch_size, in_features)``.

        Returns:
            torch.Tensor: Output of shape ``(batch_size, out_features)``.

        Raises:
            ValueError: If ``perturbation_type`` was changed to an unsupported mode.
        """
        batch_size = input.size(0)

        # Convert rho to sigma
        weight_sigma = F.softplus(self.weight_rho)
        bias_sigma = F.softplus(self.bias_rho)

        # Random draws; the order is kept fixed so seeded runs stay reproducible.
        weight_epsilon = torch.randn_like(self.weight_mu)
        bias_epsilon = torch.randn(batch_size, self.out_features, device=input.device)
        row_sign = torch.sign(
            torch.randn(batch_size, self.out_features, device=input.device)
        )
        col_sign = torch.sign(
            torch.randn(batch_size, self.in_features, device=input.device)
        )

        # Perturbation shared across the mini-batch (before the sign flips).
        if self.perturbation_type == "additive":
            delta_W = weight_sigma * weight_epsilon
        elif self.perturbation_type == "multiplicative":
            # W_mu * (1 + s * e * flips) == W_mu + (W_mu * s * e) * flips
            delta_W = self.weight_mu * weight_sigma * weight_epsilon
        else:
            raise ValueError(
                f"Unknown perturbation_type {self.perturbation_type!r}; "
                f"expected one of {self._PERTURBATION_TYPES}"
            )

        # Mean contribution plus the per-sample sign-flipped perturbation.
        mean_output = F.linear(input, self.weight_mu, self.bias_mu)
        perturbation = F.linear(input * col_sign, delta_W) * row_sign

        # The bias noise is drawn independently for every sample.
        return mean_output + perturbation + bias_sigma * bias_epsilon

    def _kl_divergence(self):
        """Return KL(posterior || prior) summed over all weights and biases."""
        posterior_weight = Normal(self.weight_mu, F.softplus(self.weight_rho))
        posterior_bias = Normal(self.bias_mu, F.softplus(self.bias_rho))
        prior = Normal(self.prior_mean, self.prior_std)

        # Compute KL divergence for weights and biases
        kl_weights = kl_divergence(posterior_weight, prior).sum()
        kl_biases = kl_divergence(posterior_bias, prior).sum()

        return kl_weights + kl_biases


class MELTBatchNorm(nn.Module):
Expand Down
55 changes: 34 additions & 21 deletions ptmelt/losses.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import torch
import torch.nn.functional as F


def safe_exp(x):
    """Prevents overflow by clipping input range to reasonable values.

    Args:
        x (torch.Tensor): Values to exponentiate.

    Returns:
        torch.Tensor: ``exp(x)`` with ``x`` first clamped to ``[-10, 10]``, so
        every finite or infinite input maps into ``[exp(-10), exp(10)]``.
    """
    # One clamp to the tight bound is all that is needed before exponentiating.
    return torch.exp(torch.clamp(x, min=-10, max=10))


class MixtureDensityLoss(torch.nn.Module):
"""
Custom loss function for a Gaussian mixture model.
Custom loss function for Mixture Density Network (MDN).

Args:
num_mixtures (int): Number of mixture components.
Expand All @@ -22,38 +23,50 @@ def __init__(self, num_mixtures, num_outputs):
self.num_outputs = num_outputs

def forward(self, y_pred, y_true):
    """Compute the mean negative log-likelihood of ``y_true`` under the mixture.

    Args:
        y_pred (torch.Tensor): Predictions laid out along dim 1 as
            ``[coefficients | means | log-variances]`` with widths
            ``num_mixtures``, ``num_mixtures * num_outputs`` and
            ``num_mixtures * num_outputs``.
        y_true (torch.Tensor): Targets, broadcast against the means after
            ``unsqueeze(1)``, i.e. expected as ``(batch_size, num_outputs)``.

    Returns:
        torch.Tensor: Scalar negative log-likelihood averaged over the batch.
    """
    # NOTE: the order of the parameters is reversed compared to Keras and TensorFlow
    # Extract the mixture coefficients, means, and log-variances
    end_mixture = self.num_mixtures
    end_mean = end_mixture + self.num_mixtures * self.num_outputs
    end_log_var = end_mean + self.num_mixtures * self.num_outputs

    # coefficients -> (batch_size, num_mixtures)
    m_coeffs = y_pred[:, :end_mixture]
    # means / log variances -> (batch_size, num_mixtures, num_outputs)
    mean_preds = y_pred[:, end_mixture:end_mean].view(
        -1, self.num_mixtures, self.num_outputs
    )
    log_var_preds = y_pred[:, end_mean:end_log_var].view(
        -1, self.num_mixtures, self.num_outputs
    )

    # Ensure mixture coefficients sum to 1
    # NOTE(review): if y_pred comes from an output layer that already applies a
    # softmax to the coefficients, this normalises them a second time -- confirm.
    m_coeffs = F.softmax(m_coeffs, dim=1)
    # Convert log variance to variance (clamped inside safe_exp)
    var_preds = safe_exp(log_var_preds)

    # Difference term -> (batch_size, num_mixtures, num_outputs)
    diff = y_true.unsqueeze(1) - mean_preds

    # Per-element Gaussian log density. The constant is -0.5 * log(2*pi) per
    # output dimension; the sum over dim=2 below supplies the num_outputs factor,
    # so it must not be pre-multiplied by num_outputs here.
    const_term = -0.5 * torch.log(torch.tensor(2 * torch.pi))
    var_log_term = -0.5 * log_var_preds
    exp_term = -0.5 * torch.square(diff) / var_preds

    # Sum over output dimensions to get log probabilities for each mixture
    # -> (batch_size, num_mixtures)
    log_probs = (const_term + var_log_term + exp_term).sum(dim=2)

    # Compute mixture weighted log probabilities and add eps to prevent log(0)
    weighted_log_probs = log_probs + torch.log(m_coeffs + 1e-8)

    # Log-Sum-Exp trick for numerical stability -> (batch_size,)
    log_sum_exp = torch.logsumexp(weighted_log_probs, dim=1)

    # Compute final negative log-likelihood loss -> scalar
    return -torch.mean(log_sum_exp)
Loading