diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71456ed22..0322a5353 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,7 @@ repos: files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$ - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 + rev: v2.4.0 hooks: - id: codespell - args: [ "--skip", "*.hook", "--ignore-words-list", "ans,nd,Bu,astroid,hart" ] + args: [ "--skip", "*.hook", "--ignore-words-list", "ans,nd,Bu,astroid,hart", "--ignore-multiline-regex", "codespell:ignore-begin.*codespell:ignore-end" ] diff --git a/alf/algorithms/agent.py b/alf/algorithms/agent.py index 804a24ea7..d997c06e6 100644 --- a/alf/algorithms/agent.py +++ b/alf/algorithms/agent.py @@ -506,7 +506,7 @@ def preprocess_experience(self, root_inputs, rollout_info, batch_info): def summarize_rollout(self, experience): """First call ``RLAlgorithm.summarize_rollout()`` to summarize basic - rollout statisics. If the rl algorithm has overridden this function, + rollout statistics. If the rl algorithm has overridden this function, then also call its customized version. """ super(Agent, self).summarize_rollout(experience) diff --git a/alf/algorithms/diffusion_algorithm.py b/alf/algorithms/diffusion_algorithm.py new file mode 100644 index 000000000..c708a1cbe --- /dev/null +++ b/alf/algorithms/diffusion_algorithm.py @@ -0,0 +1,338 @@ +# Copyright (c) 2025 Horizon Robotics and ALF Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class SDE:
    r"""Base stochastic differential equation interface.

    .. math::

        dx = -\beta(t) x dt + g(t) dw

    Sub-classes provide closed-form coefficients for specific SDE families.

    """

    def pt0(self, t):
        """Return the closed-form :math:`p(x_t | x_0)` parameters.

        Args:
            t: Tensor of time values at which to evaluate the marginal.

        Returns:
            Tuple ``(alpha, sigma)`` representing the linear coefficient and
            standard deviation of the conditional distribution.
            :math:`p(x_t | x_0) = Normal(alpha * x_0, sigma^2 I)`.
        """
        raise NotImplementedError

    def pt0_dot(self, t):
        """Return the derivatives of pt0 with respect to time.

        Args:
            t: Tensor of time values to differentiate at.

        Returns:
            Tuple ``(alpha_dot, sigma_dot)`` corresponding to the time
            derivatives of the parameters of :math:`p(x_t | x_0)`.
        """
        raise NotImplementedError

    def diffusion_coeff(self, t):
        r"""Return the drift and diffusion coefficients :math:`(\beta, g)`.

        Args:
            t: Tensor of time values at which to evaluate the coefficients.

        Returns:
            Tuple ``(beta, g)`` describing the SDE drift and diffusion terms.
        """
        raise NotImplementedError


class OTSDE(SDE):
    # codespell:ignore-begin
    r"""Optimal transport SDE with closed-form linear coefficients.

    This corresponds to commonly used flow matching model ("FM/OT" in https://arxiv.org/abs/2210.02747).

    .. math::

        dx = -\frac{1}{1-t} x dt + \sqrt{\frac{2t}{1-t}} dw

    """

    # codespell:ignore-end

    def pt0(self, t):
        # alpha = 1 - t, sigma = t, matching the straight-line OT path.
        return 1 - t, t

    def pt0_dot(self, t):
        return -1, 1

    def diffusion_coeff(self, t):
        # beta = -alpha_dot / alpha = 1 / (1 - t);
        # g^2 = d(sigma^2)/dt + 2 * beta * sigma^2 = 2t / (1 - t).
        return 1 / (1 - t), (2 * t / (1 - t)).sqrt()


class RTSDE(SDE):
    r"""SDE with trigonometric drift and diffusion.

    .. math::

        dx = - \frac{\pi}{2}\tan(\frac{\pi}{2}t)xdt + \sqrt{\pi\tan(\frac{\pi}{2}t)} dw

    """

    def pt0(self, t):
        angle = 0.5 * math.pi * t
        # alpha = cos(pi t / 2), sigma = sin(pi t / 2).
        return torch.cos(angle), torch.sin(angle)

    def pt0_dot(self, t):
        angle = 0.5 * math.pi * t
        return -0.5 * math.pi * torch.sin(angle), 0.5 * math.pi * torch.cos(
            angle)

    def diffusion_coeff(self, t):
        # Fix: the drift previously returned ``0.5 * beta``, which is
        # inconsistent with this class's docstring and with the convention
        # used by ``OTSDE``. The marginal alpha = cos(pi t / 2) requires
        # beta = -alpha_dot / alpha = (pi / 2) * tan(pi t / 2), and then
        # g^2 = d(sigma^2)/dt + 2 * beta * sigma^2 = pi * tan(pi t / 2)
        # = 2 * beta, i.e. g = sqrt(2 * beta).
        beta = 0.5 * math.pi * torch.tan(0.5 * math.pi * t)
        return beta, (2 * beta).sqrt()
+ """ + super().__init__() + self._model = model_ctor(input_spec, output_spec, mean_flow=mean_flow) + self._sde = sde + self._output_spec = output_spec + self._steps = steps + if isinstance(output_spec, alf.BoundedTensorSpec): + self._min = torch.tensor(output_spec.minimum) + self._max = torch.tensor(output_spec.maximum) + self._time_sampler = time_sampler + + def calc_loss(self, inputs, samples, f_neg_energy=None, sample_mask=None): + """Compute per-sample losses for training. + + Args: + inputs: Conditional input nest consumed by the model during + training. + samples: Observed ``x_0`` tensors. When ``None`` the implementation + will draw ``x_0`` from the prior distribution. + f_neg_energy: Optional callable ``f(x0, inputs) -> Tensor`` that + returns per-sample negative energies used to importance weight + the loss. + sample_mask: Optional broadcastable tensor used to zero out losses + of masked samples. + """ + raise NotImplementedError + + def sample(self, inputs, steps): + """Generate samples given the conditioning inputs. + + Args: + inputs: Conditional inputs used during generation. + steps: Number of Euler integration steps to use. + + Returns: + Generated samples matching ``output_spec``. + """ + raise NotImplementedError + + @property + def state_spec(self): + return () + + def _apply_sample_mask(self, diff, sample_mask): + """Apply a mask to the loss tensor if provided. + + Args: + diff: Tensor containing per-element loss values. + sample_mask: Optional mask tensor broadcastable to ``diff``. + + Returns: + Masked loss tensor. + """ + if sample_mask is not None: + if sample_mask.ndim == 1: + sample_mask = expand_dims_as(sample_mask, diff) + diff = diff * sample_mask + return diff + + def _calc_weights(self, x0, alpha, sigma, inputs, f_neg_energy): + """Compute importance weights based on the energy function. + + Args: + x0: Tensor of ``x_0`` samples. + alpha: Scaling factor from ``p(x_t | x_0)``. + sigma: Standard deviation from ``p(x_t | x_0)``. 
def expand_dims(x, ndim):
    """Reshape ``x`` to add ``ndim`` singleton dimensions at the end.

    Args:
        x: Tensor to reshape; only its leading (batch) dimension is kept.
        ndim: Number of singleton dimensions to append.

    Returns:
        Reshaped tensor of shape ``(x.shape[0], 1, ..., 1)`` with ``ndim``
        trailing singleton dimensions.
    """
    target_shape = (x.shape[0], ) + (1, ) * ndim
    return x.reshape(target_shape)
@alf.configurable
class FlowMatching(SDEGenerator):
    """Flow matching objective that regresses the true velocity field."""

    def _get_x0_xt_x1(self, samples, batch_size, alpha, sigma):
        """Sample ``(x_0, x_t, x_1)`` triplets for flow matching.

        Args:
            samples: Optional tensor of ground-truth ``x_0`` values.
            batch_size: Number of samples to draw.
            alpha: Scaling factor from the SDE marginal.
            sigma: Standard deviation from the SDE marginal.

        Returns:
            Tuple ``(x0, xt, x1)`` consistent with the SDE marginals, i.e.
            ``xt = alpha * x0 + sigma * x1``.
        """
        if samples is None:
            if isinstance(self._output_spec, alf.BoundedTensorSpec):
                # For bounded data, we use a p1(.) that is uniform within the
                # bounds: draw x_t uniformly, then draw the noise x_1 from a
                # truncated standard normal so that the implied
                # x_0 = (x_t - sigma * x_1) / alpha stays within the bounds.
                xt = self._output_spec.sample((batch_size, ))
                x1_max = (xt - alpha * self._min) / sigma
                x1_min = (xt - alpha * self._max) / sigma
                dist = TruncatedNormal(loc=torch.zeros_like(x1_min),
                                       scale=torch.ones_like(x1_min),
                                       lower_bound=x1_min,
                                       upper_bound=x1_max)
                x1 = dist.sample()
            else:
                xt = self._output_spec.randn((batch_size, ))
                x1 = self._output_spec.randn((batch_size, ))
            x0 = (xt - sigma * x1) / alpha
        else:
            x0 = samples
            x1 = torch.randn((batch_size, ) + self._output_spec.shape,
                             device=samples.device)
            xt = alpha * x0 + sigma * x1

        return x0, xt, x1

    def calc_loss(self, inputs, samples, f_neg_energy=None, sample_mask=None):
        """Compute the squared error between predicted and true velocities.

        Args:
            inputs: Conditional input nest.
            samples: Optional tensor of ground-truth ``x_0`` values.
            f_neg_energy: Optional energy function for importance weighting.
            sample_mask: Optional mask applied to the per-element losses.

        Returns:
            Tensor of per-sample velocity regression losses.
        """
        leaf = alf.nest.extract_any_leaf_from_nest(inputs)
        batch_size = leaf.shape[0]
        device = leaf.device
        t = self._time_sampler(batch_size, device=device)
        # Broadcast t to the output's rank once; pt0 and pt0_dot share it.
        t_expanded = expand_dims(t, self._output_spec.ndim)
        alpha, sigma = self._sde.pt0(t_expanded)
        x0, xt, x1 = self._get_x0_xt_x1(samples, batch_size, alpha, sigma)
        vt = self._model((xt, inputs, t))[0]
        alpha_dot, sigma_dot = self._sde.pt0_dot(t_expanded)
        # Conditional target velocity d(x_t)/dt for fixed (x0, x1).
        cvt = alpha_dot * x0 + sigma_dot * x1
        diff = (vt - cvt)**2
        diff = self._apply_sample_mask(diff, sample_mask)
        loss = diff.reshape(batch_size, -1).sum(-1)

        if f_neg_energy is not None:
            p0 = self._calc_weights(x0, alpha, sigma, inputs, f_neg_energy)
            loss = p0 * loss

        return loss

    def sample(self, inputs, steps=None):
        """Sample by integrating the learned velocity field.

        Args:
            inputs: Conditional inputs used for generation.
            steps: Number of Euler integration steps. When ``None`` (the
                default, keeping existing callers working), the ``steps``
                value given at construction time is used. This also restores
                the parameter declared by ``SDEGenerator.sample``.

        Returns:
            Tensor of generated samples.
        """
        if steps is None:
            steps = self._steps
        batch_size = get_nest_batch_size(inputs)
        if isinstance(self._output_spec, alf.BoundedTensorSpec):
            noise = self._output_spec.sample((batch_size, ))
        else:
            noise = self._output_spec.randn((batch_size, ))
            # Scale the prior noise by sigma at t = 1.
            noise = noise * self._sde.pt0(torch.tensor(1))[1]
        dt = 1 / steps
        with torch.no_grad():
            # Euler integration backwards in time from t = 1 to t = dt.
            for step in range(steps):
                t = torch.full((batch_size, ), 1 - step * dt)
                vt = self._model((noise, inputs, t))[0]
                noise = noise - vt * dt

        return noise
def jointplot(x, y, bins=30, figure_size=(8, 8)):
    """Draw scatter plot with marginal histograms aligned to the axes.

    Args:
        x: Sequence of x coordinates.
        y: Sequence of y coordinates (same length as ``x``).
        bins: Number of histogram bins for both marginals.
        figure_size: Matplotlib ``figsize`` of the whole figure.
    """
    # scatter + marginal histograms
    fig = plt.figure(figsize=figure_size)
    gs = fig.add_gridspec(4, 4, wspace=0.05, hspace=0.05)

    # Main scatter plot
    ax_scatter = fig.add_subplot(gs[1:4, 0:3])
    ax_scatter.plot(x, y, '.', alpha=0.5)

    # X histogram (above scatter, share x-axis)
    ax_histx = fig.add_subplot(gs[0, 0:3], sharex=ax_scatter)
    ax_histx.hist(x, bins=bins, color="gray")
    plt.setp(ax_histx.get_xticklabels(), visible=False)  # hide x labels here

    # Y histogram (to the right of scatter, share y-axis)
    ax_histy = fig.add_subplot(gs[1:4, 3], sharey=ax_scatter)
    ax_histy.hist(y, bins=bins, orientation='horizontal', color="gray")
    plt.setp(ax_histy.get_yticklabels(), visible=False)  # hide y labels here

    ax_histx.set_ylabel("count")
    ax_histy.set_xlabel("count")

    plt.show()
def pdist(A, B):
    """Distance between each pair of the two collections of inputs.

    Args:
        A: (b, n, d)
        B: (b, m, d)
    Returns:
        pairwise distances (b,n,m)
    """
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, computed batch-wise.
    sq_a = A.pow(2).sum(-1, keepdim=True)  # (b, n, 1)
    sq_b = B.pow(2).sum(-1, keepdim=True).transpose(1, 2)  # (b, 1, m)
    cross = A @ B.transpose(1, 2)  # batched matmul: (b, n, m)
    sq_dist = sq_a + sq_b - 2.0 * cross
    # Clamp tiny negatives from floating-point cancellation before sqrt.
    return sq_dist.clamp_min_(0.0).sqrt_()
def train(generator, dist, sample_based, batch_size=1024, ema=0.99):
    """Optimize ``generator`` toward ``dist`` and measure the fit.

    Args:
        generator: The ``SDEGenerator`` to train.
        dist: Target distribution exposing ``sample`` and ``neg_energy``.
        sample_based: If ``True``, train from samples of ``dist``; otherwise
            train from its negative energy function.
        batch_size: Mini-batch size per iteration.
        ema: Exponential moving-average decay for parameter averaging;
            a non-positive value disables averaging.

    Returns:
        Tuple ``(model, e_stat)``: the (possibly EMA-averaged) model and the
        energy statistic between model samples and ``dist`` samples.
    """
    averager = None
    ema_generator = generator
    if ema > 0:
        averager = torch.optim.swa_utils.AveragedModel(
            generator,
            multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(ema))
        ema_generator = averager.module

    optimizer = torch.optim.Adam(generator.parameters(), lr=6e-4)
    # Linear warmup followed by a constant learning rate.
    warmup_iters = 16
    warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.001,
        end_factor=1.0,
        total_iters=warmup_iters,
    )
    main_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer,
                                                         total_iters=1000,
                                                         factor=1.0)
    lr_schedule = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, main_scheduler],
        milestones=[warmup_iters])
    for i in range(1000):
        if sample_based:
            samples = dist.sample(batch_size)
            f_neg_energy = None
        else:
            samples = None
            f_neg_energy = dist.neg_energy
        loss = generator.calc_loss(torch.full((batch_size, 1), 1.0), samples,
                                   f_neg_energy).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_schedule.step()
        if averager is not None:
            averager.update_parameters(generator)

    e_stat = energy_stat(partial(sample_f, generator=ema_generator),
                         dist.sample, 10000)
    return ema_generator, e_stat
class Concat(torch.nn.Module):
    """Module that concatenates a sequence of tensors along the last axis."""

    def forward(self, x):
        # ``x`` is a sequence (list/tuple) of tensors whose shapes agree on
        # every axis except the last.
        return torch.cat(x, dim=-1)
# Timestep embedding used in the DDPM++ and ADM architectures.
class PositionalEmbedding(torch.nn.Module):
    """Positional time embedding using deterministic sinusoidal features."""

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        """Create a positional embedding module.

        Args:
            num_channels: Total number of output channels for the embedding.
            max_positions: Maximum number of positions used to scale the
                frequencies of the sinusoidal embedding.
            endpoint: Whether the highest frequency should reach the endpoint
                ``1 / max_positions`` exactly.
        """
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        half = self.num_channels // 2
        exponents = torch.arange(start=0,
                                 end=half,
                                 dtype=torch.float32,
                                 device=x.device)
        # Geometric frequency ladder from 1 down to ~1/max_positions.
        exponents = exponents / (half - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions)**exponents
        angles = torch.outer(x, freqs.to(x.dtype))
        return torch.cat([angles.cos(), angles.sin()], dim=1)


# Timestep embedding used in the NCSN++ architecture.
class FourierEmbedding(torch.nn.Module):
    """Random Fourier feature based time embedding."""

    def __init__(self, num_channels, scale=16):
        """Create a Fourier embedding module with random frequencies.

        Args:
            num_channels: Total number of output channels for the embedding.
            scale: Standard deviation used when sampling the random base
                frequencies.
        """
        super().__init__()
        # Frequencies are fixed at construction time (buffer, not parameter).
        self.register_buffer('freqs', torch.randn(num_channels // 2) * scale)

    def forward(self, x):
        angles = torch.outer(x, (2 * math.pi * self.freqs).to(x.dtype))
        return torch.cat([angles.cos(), angles.sin()], dim=1)
class MLPNet(alf.networks.Network):
    """Multi-layer perceptron used to predict scores or velocities for SDEs."""

    def __init__(self,
                 input_spec,
                 output_spec,
                 mean_flow=False,
                 hidden_dim=256,
                 time_embedding_type='positional'):
        """Construct the MLP used for score or velocity prediction.

        Args:
            input_spec: Specification of the conditional input tensor.
            output_spec: Specification describing the generated tensor.
            mean_flow: Whether the network receives an additional mean-flow
                horizon input.
            hidden_dim: Feature dimension used throughout the hidden layers and
                embeddings.
            time_embedding_type: Chooses between ``'positional'`` and
                ``'fourier'`` time embeddings.
        """
        specs = (output_spec, input_spec, alf.TensorSpec(()))
        if mean_flow:
            # Extra scalar input: the mean-flow look-ahead horizon.
            specs = specs + (alf.TensorSpec(()), )
        super().__init__(specs)
        self._mean_flow = mean_flow
        # Each embedding contributes ``hidden_dim`` channels to the trunk.
        num_embeddings = 4 if mean_flow else 3
        self._model = torch.nn.Sequential(
            Concat(),
            torch.nn.Linear(num_embeddings * hidden_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, output_spec.numel),
        )
        if time_embedding_type == 'positional':
            self._time_embedding = PositionalEmbedding(
                num_channels=hidden_dim, endpoint=True)
        else:
            self._time_embedding = FourierEmbedding(num_channels=hidden_dim)
        self._cond_embedding = torch.nn.Linear(in_features=input_spec.numel,
                                               out_features=hidden_dim,
                                               bias=False)
        self._x_embedding = torch.nn.Linear(in_features=output_spec.numel,
                                            out_features=hidden_dim,
                                            bias=False)

    def forward(self, inputs, state=()):
        x, cond, t = inputs[:3]
        h = inputs[3] if self._mean_flow else None
        # Embed x and t; optionally the horizon h; conditioning goes last.
        parts = [self._x_embedding(x), self._time_embedding(t)]
        if h is not None:
            parts.append(self._time_embedding(h))
        parts.append(self._cond_embedding(cond))
        return self._model(parts), state
+ """ + super().__init__() + self._norm1 = torch.nn.LayerNorm(d_model, elementwise_affine=False) + self._attn = torch.nn.MultiheadAttention(d_model, + num_heads=num_heads, + batch_first=True) + self._norm2 = torch.nn.LayerNorm(d_model, elementwise_affine=False) + self._fc1 = alf.layers.FC(d_model, + d_ff, + activation=torch.nn.functional.silu) + self._fc2 = alf.layers.FC(d_ff, d_model) + self._cond_mlp = torch.nn.Sequential( + # torch.nn.SiLU(), + alf.layers.FC(cond_dim, 6 * d_model, use_bias=True)) + + def forward(self, inputs): + x, cond = inputs + scale1, shift1, gate1, scale2, shift2, gate2 = self._cond_mlp( + cond).unsqueeze(1).chunk(6, dim=-1) + h = self._norm1(x) + h = torch.addcmul(shift1, h, 1 + scale1) + attn_output, _ = self._attn(h, h, h) + x.addcmul(attn_output, gate1) + h = self._norm2(x) + h = torch.addcmul(shift2, h, 1 + scale2) + h = self._fc1(h) + h = self._fc2(h) + return x.addcmul(h, gate2) + + +class DiT(alf.networks.Network): + """Diffusion Transformer architecture for sequence-shaped outputs.""" + + def __init__(self, + input_spec, + output_spec, + d_model=128, + num_heads=4, + num_blocks=2, + time_embedding_dim=32, + mean_flow=False): + """Create a DiT network tailored for ALF tensor specs. + + Args: + input_spec: Specification of conditioning inputs. + output_spec: Specification of the generated tensor shaped as a + sequence. + d_model: Transformer hidden size. + num_heads: Number of attention heads in each block. + num_blocks: Number of stacked transformer blocks. + time_embedding_dim: Dimensionality of sinusoidal time embeddings. + mean_flow: Whether the network receives the extra mean-flow + horizon input. 
+ """ + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec(())) + if mean_flow: + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec( + ())) + (alf.TensorSpec(()), ) + super().__init__(input_tensor_spec) + + self._blocks = torch.nn.ModuleList() + assert output_spec.ndim == 2, "DiT only supports 2D output" + self._in_proj = alf.layers.FC(output_spec.shape[1], d_model) + length = output_spec.shape[0] + self._pe = torch.nn.Parameter(torch.zeros(1, length, d_model)) + self._out_proj = alf.layers.FC(d_model, output_spec.shape[1]) + cond_dim = sum(spec.numel for spec in alf.nest.flatten(input_spec)) + cond_dim += time_embedding_dim * (2 if mean_flow else 1) + self._time_embedding = PositionalEmbedding( + num_channels=time_embedding_dim, endpoint=True) + self._cond_mlp = torch.nn.Sequential( + alf.layers.FC(cond_dim, + d_model, + activation=torch.nn.functional.silu), + alf.layers.FC(d_model, + d_model, + activation=torch.nn.functional.silu), + ) + for _ in range(num_blocks): + self._blocks.append( + DiTBlock(d_model, 4 * d_model, d_model, num_heads)) + + def forward(self, inputs, state=()): + x, cond, t = inputs[:3] + assert x.ndim == 3, "DiT only supports 3D input" + h = inputs[3] if len(inputs) == 4 else None + embeddings = alf.nest.flatten(cond) + embeddings.append(self._time_embedding(t)) + if h is not None: + embeddings.append(self._time_embedding(h)) + cond = torch.cat(embeddings, dim=-1) + cond = self._cond_mlp(cond) + + x = self._in_proj(x) + self._pe + for block in self._blocks: + x = block((x, cond)) + x = self._out_proj(x) + return x, state + + +class AdaLnBlock(torch.nn.Module): + """Residual block with adaptive layer normalization conditioning.""" + + def __init__(self, in_dim, out_dim, hidden_dim, cond_dim): + """Configure the adaptive layer normalization block. + + Args: + in_dim: Size of the input feature dimension. + out_dim: Size of the output feature dimension. + hidden_dim: Hidden dimension for the internal MLP. 
class AdaLnBlock(torch.nn.Module):
    """Residual block with adaptive layer normalization conditioning."""

    def __init__(self, in_dim, out_dim, hidden_dim, cond_dim):
        """Configure the adaptive layer normalization block.

        Args:
            in_dim: Size of the input feature dimension.
            out_dim: Size of the output feature dimension.
            hidden_dim: Hidden dimension for the internal MLP.
            cond_dim: Dimensionality of the conditioning vector applied to
                AdaLN.
        """
        super().__init__()
        self._norm = torch.nn.LayerNorm(in_dim, elementwise_affine=False)
        self._fc1 = alf.layers.FC(in_dim,
                                  hidden_dim,
                                  activation=torch.nn.functional.silu)
        self._fc2 = alf.layers.FC(hidden_dim, out_dim)
        # Produces (scale, shift, gate) for the AdaLN modulation.
        self._ada = torch.nn.Sequential(
            alf.layers.FC(cond_dim, 3 * in_dim, use_bias=True))

    def forward(self, inputs):
        x, cond = inputs
        scale, shift, gate = self._ada(cond).chunk(3, dim=-1)
        h = self._norm(x)
        h = h * (1 + scale) + shift
        h = self._fc2(self._fc1(h))
        # Gated residual connection.
        return x + h * gate
+ """ + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec(())) + if mean_flow: + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec( + ())) + (alf.TensorSpec(()), ) + super().__init__(input_tensor_spec) + + self._blocks = torch.nn.ModuleList() + self._in_proj = alf.layers.FC(output_spec.numel, d_model) + self._out_proj = alf.layers.FC(d_model, output_spec.numel) + cond_dim = sum(spec.numel for spec in alf.nest.flatten(input_spec)) + cond_dim += time_embedding_dim * (2 if mean_flow else 1) + self._time_embedding = PositionalEmbedding( + num_channels=time_embedding_dim, endpoint=True) + self._cond_mlp = torch.nn.Sequential( + alf.layers.FC(cond_dim, + d_model, + activation=torch.nn.functional.silu), + alf.layers.FC(d_model, + d_model, + activation=torch.nn.functional.silu), + ) + for _ in range(num_blocks): + self._blocks.append( + AdaLnBlock(d_model, d_model, 4 * d_model, d_model)) + + def forward(self, inputs, state=()): + x, cond, t = inputs[:3] + x_shape = x.shape + x = x.reshape(x.shape[0], -1) + h = inputs[3] if len(inputs) == 4 else None + embeddings = alf.nest.flatten(cond) + embeddings.append(self._time_embedding(t)) + if h is not None: + embeddings.append(self._time_embedding(h)) + cond = torch.cat(embeddings, dim=-1) + cond = self._cond_mlp(cond) + + x = self._in_proj(x) + for block in self._blocks: + x = block((x, cond)) + x = self._out_proj(x) + x = x.reshape(*x_shape) + return x, state + + +class SimpleMLPNet(alf.networks.Network): + """Compact MLP baseline for score or velocity prediction.""" + + def __init__(self, input_spec, output_spec, mean_flow=False): + """Construct a simple baseline MLP network. + + Args: + input_spec: Specification of conditioning inputs. + output_spec: Specification of the generated tensor. + mean_flow: Whether to include a mean-flow horizon input. 
+ """ + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec(())) + if mean_flow: + input_tensor_spec = (output_spec, input_spec, alf.TensorSpec( + ())) + (alf.TensorSpec(()), ) + super().__init__(input_tensor_spec) + activation = torch.nn.GELU + + self._model = torch.nn.Sequential( + Concat(), + torch.nn.Linear( + output_spec.numel + input_spec.numel + (2 if mean_flow else 1), + 256), + activation(), + torch.nn.Linear(256, 256), + activation(), + torch.nn.Linear(256, 256), + activation(), + torch.nn.Linear(256, output_spec.numel), + ) + + def forward(self, inputs, state=()): + x, cond, t = inputs[:3] + h = inputs[3] if len(inputs) == 4 else None + embeddings = [x, cond, t.unsqueeze(-1)] + if h is not None: + embeddings.append(h.unsqueeze(-1)) + x = self._model(embeddings) + return x, state diff --git a/alf/algorithms/mcts_algorithm.py b/alf/algorithms/mcts_algorithm.py index 291e7789a..7b5aeccc4 100644 --- a/alf/algorithms/mcts_algorithm.py +++ b/alf/algorithms/mcts_algorithm.py @@ -227,6 +227,7 @@ def _add_node(name: str, properties: dict): @alf.configurable class MCTSAlgorithm(OffPolicyAlgorithm): + # codespell:ignore-begin r"""Monte-Carlo Tree Search algorithm. The code largely follows the pseudocode of @@ -300,6 +301,8 @@ class MCTSAlgorithm(OffPolicyAlgorithm): extend these k' paths are most promising according to the UCB scores. """ + # codespell:ignore-end + def __init__( self, observation_spec, diff --git a/alf/algorithms/muzero_representation_learner.py b/alf/algorithms/muzero_representation_learner.py index f4bb97d45..4d9ce3a50 100644 --- a/alf/algorithms/muzero_representation_learner.py +++ b/alf/algorithms/muzero_representation_learner.py @@ -60,6 +60,7 @@ @alf.configurable class MuzeroRepresentationImpl(OffPolicyAlgorithm): + # codespell:ignore-begin """MuZero-style Representation Learner. 
MuZero is described in the paper: @@ -85,6 +86,8 @@ class MuzeroRepresentationImpl(OffPolicyAlgorithm): """ + # codespell:ignore-end + def __init__( self, observation_spec, diff --git a/alf/algorithms/taac_algorithm.py b/alf/algorithms/taac_algorithm.py index dc7dbe86f..2c6aa6905 100644 --- a/alf/algorithms/taac_algorithm.py +++ b/alf/algorithms/taac_algorithm.py @@ -230,7 +230,7 @@ class TaacAlgorithmBase(OffPolicyAlgorithm): In a nutsell, for inference TAAC adds a second stage that chooses between a candidate trajectory :math:`\hat{\tau}` output by an SAC actor and the previous trajectory :math:`\tau^-`. For policy evaluation, TAAC uses a compare-through Q - operator for TD backup by re-using state-action sequences that have shared + operator for TD backup by reusing state-action sequences that have shared actions between rollout and training. For policy improvement, the new actor gradient is approximated by multiplying a scaling factor to the :math:`\frac{\partial Q}{\partial a}` term in the original SAC’s actor diff --git a/alf/summary/render.py b/alf/summary/render.py index 5b0653fb4..94e179d26 100644 --- a/alf/summary/render.py +++ b/alf/summary/render.py @@ -266,7 +266,7 @@ def is_rendering_enabled(): def _rendering_wrapper(rendering_func): """A wrapper function to gate the rendering function based on if rendering is enabled, and if yes generate a scoped rendering identifier before - calling the rendering function. It re-uses the scope stack in ``alf.summary.summary_ops.py``. + calling the rendering function. It reuses the scope stack in ``alf.summary.summary_ops.py``. """ @functools.wraps(rendering_func) diff --git a/alf/utils/losses.py b/alf/utils/losses.py index 68521b75e..809f4bc27 100644 --- a/alf/utils/losses.py +++ b/alf/utils/losses.py @@ -129,7 +129,7 @@ def iqn_huber_loss(value: torch.Tensor, is between this and the target. target: the time-major tensor for return, this is used as the target for computing the loss. 
- next_delta_tau: the sampled increments of the probability for the input + next_delta_tau: the sampled increments of the probability for the input of the quantile function of the target critics. fixed_tau: the fixed increments of probability, for non iqn style quantile regression. @@ -166,7 +166,7 @@ def iqn_huber_loss(value: torch.Tensor, error = loss_fn(diff) if iqn_tau: if diff.ndim - tau_hat.ndim > 1: - # For multidimentional reward: + # For multidimensional reward: # diff is of shape [T or T-1, B, reward_dim, n_quantiles, n_quantiles] # while tau_hat and next_delta_tau have shape [T or T-1, B, n_quantiles] tau_hat = tau_hat.unsqueeze(-2)