diff --git a/.github/workflows/merge-request.yaml b/.github/workflows/merge-request.yaml index 8e227ad..c759041 100644 --- a/.github/workflows/merge-request.yaml +++ b/.github/workflows/merge-request.yaml @@ -11,6 +11,6 @@ jobs: test: uses: ./.github/workflows/run-tests.yaml with: - install_string: .[develop] + install_string: .[develop,diptv3] build_site: uses: ./.github/workflows/build-site.yaml diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 5426c1d..909cf33 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -28,9 +28,18 @@ jobs: python-version: '3.10' ############################################## + - name: Install torch + run: pip install torch==2.5.1 torchvision==0.20.1 + + - name: Install torch_scatter + run: pip install torch_scatter -f https://data.pyg.org/whl/torch-2.5.0+cpu.html + - name: Install package run: pip install "${{ inputs.install_string }}" + - name: Install spconv. + run: pip install spconv + - name: Code Quality run: python -m black src/ tests/ --check diff --git a/pyproject.toml b/pyproject.toml index 4ed75c6..9221902 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ explicit_package_bases = true [[tool.mypy.overrides]] module = [ "addict.*", + "DiT.*", "PointTransformerV3.*", "spconv.*", "timm.*", @@ -77,4 +78,4 @@ module = [ ignore_missing_imports = true [tool.pyright] -extraPaths = ['./third_party'] +extraPaths = ['./third_party', './src'] diff --git a/src/rpad/nets/diptv3.py b/src/rpad/nets/diptv3.py index 6c79295..d6ededa 100644 --- a/src/rpad/nets/diptv3.py +++ b/src/rpad/nets/diptv3.py @@ -1,10 +1,12 @@ import math from functools import partial +from typing import Optional import spconv.pytorch as spconv import torch import torch_scatter from addict import Dict +from DiT.models import TimestepEmbedder from PointTransformerV3.model import ( MLP, Embedding, @@ -19,7 +21,7 @@ from torch import nn -class SerializedPooling(PointModule): +class DiPTv3SerializedPooling(PointModule): def __init__( self, in_channels, @@ -60,7 +62,7 @@ def forward(self, point: Point): "serialized_depth", }.issubset( point.keys() - ), "Run point.serialization() point cloud before SerializedPooling" + ), "Run point.serialization() point cloud before DiPTv3SerializedPooling" code = point.serialized_code >> pooling_depth * 3 code_, cluster, counts = torch.unique( @@ -106,6 +108,10 @@ def forward(self, point: Point): serialized_inverse=inverse, serialized_depth=point.serialized_depth - pooling_depth, batch=point.batch[head_indices], + # For DiPTv3: keep track of these as well. + t=point.t, + N=point.N, + t_emb=point.t_emb, ) if "condition" in point.keys(): @@ -125,12 +131,13 @@ def forward(self, point: Point): return point -class Block(PointModule): +class DiPTv3Block(PointModule): def __init__( self, channels, num_heads, patch_size=48, + t_hidden_size=128, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, @@ -150,6 +157,7 @@ def __init__( super().__init__() self.channels = channels self.pre_norm = pre_norm + assert pre_norm, "Only support pre-normalization" self.cpe = PointSequential( spconv.SubMConv3d( @@ -192,30 +200,79 @@ def __init__( DropPath(drop_path) if drop_path > 0.0 else nn.Identity() ) + # Adapted from the DiT paper - add adaLN modulation. 
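+        # The time embedding is mapped to 6 * channels values per block: a
+        # shift, scale, and gate for both the attention and MLP branches
+        # (adaLN-Zero). These layers are zero-initialized in DiPTv3Adapter, so
+        # both branches initially contribute nothing to the residual stream.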
+ self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(t_hidden_size, 6 * channels, bias=True), + ) + self.t_hidden_size = t_hidden_size + + @staticmethod + def modulate(x, shift, scale): + return x * (1 + scale) + shift + def forward(self, point: Point): + # Time embedding should still be batched. + assert len(point.t_emb.shape) == 2 + + # First, compute adaLN modulation. + adaLN_params = self.adaLN_modulation(point.t_emb) + + # Then, repeat with batching in mind. + adaLN_params = adaLN_params[point.batch] + + # Compute the adaLN parameters. + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = adaLN_params.chunk(6, dim=1) + shortcut = point.feat point = self.cpe(point) point.feat = shortcut + point.feat shortcut = point.feat if self.pre_norm: point = self.norm1(point) - point = self.drop_path(self.attn(point)) + + # Apply adaLN modulation. + point.feat = self.modulate(point.feat, shift_msa, scale_msa) + + # Attention. + point = self.attn(point) + + # Modulation. + point.feat = gate_msa * point.feat + + # Residual connection. + point = self.drop_path(point) point.feat = shortcut + point.feat - if not self.pre_norm: - point = self.norm1(point) shortcut = point.feat if self.pre_norm: point = self.norm2(point) - point = self.drop_path(self.mlp(point)) + + # Apply adaLN modulation. + point.feat = self.modulate(point.feat, shift_mlp, scale_mlp) + + # MLP. + point = self.mlp(point) + + # Modulation. + point.feat = gate_mlp * point.feat + + # Residual connection. + point = self.drop_path(point) point.feat = shortcut + point.feat - if not self.pre_norm: - point = self.norm2(point) + point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat) return point -class PointTransformerV3(PointModule): +class DiPTv3(PointModule): def __init__( self, in_channels=6, @@ -229,6 +286,7 @@ def __init__( dec_channels=(64, 64, 128, 256), dec_num_head=(4, 4, 8, 16), dec_patch_size=(1024, 1024, 1024, 1024), + t_hidden_size=128, mlp_ratio=4, qkv_bias=True, qk_scale=None, @@ -254,6 +312,7 @@ def __init__( self.order = [order] if isinstance(order, str) else order self.cls_mode = cls_mode self.shuffle_orders = shuffle_orders + self.t_hidden_size = t_hidden_size assert self.num_stages == len(stride) + 1 assert self.num_stages == len(enc_depths) @@ -298,6 +357,8 @@ def __init__( act_layer=act_layer, ) + self.t_embedder = TimestepEmbedder(t_hidden_size) + # encoder enc_drop_path = [ x.item() for x in torch.linspace(0, drop_path, sum(enc_depths)) @@ -310,7 +371,7 @@ def __init__( enc = PointSequential() if s > 0: enc.add( - SerializedPooling( + DiPTv3SerializedPooling( in_channels=enc_channels[s - 1], out_channels=enc_channels[s], stride=stride[s - 1], @@ -321,10 +382,11 @@ def __init__( ) for i in range(enc_depths[s]): enc.add( - Block( + DiPTv3Block( channels=enc_channels[s], num_heads=enc_num_head[s], patch_size=enc_patch_size[s], + t_hidden_size=t_hidden_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, @@ -371,10 +433,11 @@ def __init__( ) for i in range(dec_depths[s]): dec.add( - Block( + DiPTv3Block( channels=dec_channels[s], num_heads=dec_num_head[s], patch_size=dec_patch_size[s], + t_hidden_size=t_hidden_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, @@ -408,7 +471,171 @@ def forward(self, data_dict): point.sparsify() point = self.embedding(point) + point.t_emb = self.t_embedder(point.t) point = self.enc(point) if not self.cls_mode: point = self.dec(point) return point + + +class FinalLayer(nn.Module): + """ + The final 
layer of DiT. + """ + + def __init__(self, hidden_size, t_hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(t_hidden_size, 2 * hidden_size, bias=True) + ) + + @staticmethod + def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + def forward(self, x, c) -> torch.Tensor: + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = self.modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x # type: ignore + + +class DiPTv3Adapter(nn.Module): + def __init__(self, model, grid_size=0.001, final_dimension=6): + super().__init__() + self.model = model + self.grid_size = grid_size + + final_input_size = model.dec[-1][-1].mlp[0].fc2.out_features + self.final_layer = FinalLayer( + final_input_size, model.t_hidden_size, final_dimension + ) + + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + # w = self.x_embedder.proj.weight.data + # nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + # nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + # nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.model.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.model.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for enc in self.model.enc: + for block in enc: + if hasattr(block, "adaLN_modulation"): + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + for dec in self.model.dec: + for block in dec: + if hasattr(block, "adaLN_modulation"): + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + x0: torch.Tensor, + y: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Denoise the input (a point cloud). + """ + + # Permute to (B, N, C) + assert len(x.shape) == 3 + assert x.shape[1] == 3 + + x = x.permute(0, 2, 1) # type: ignore + x0 = x0.permute(0, 2, 1) # type: ignore + if y is not None: + y = y.permute(0, 2, 1) + B, N, C_in = x.shape + B0, N0, D = x0.shape + B1, N1, D1 = y.shape if y is not None else (0, 0, 0) + + assert t.shape[0] == B + + if y is not None: + # Concatenate y and x0 on the second-to-last dim (N) + x0 = torch.cat([x0, y], dim=1) + # Add zeros to the x tensor. + x = torch.cat([x, torch.zeros(B, N1, C_in, device=x.device)], dim=1) + + # Create a labels tensor (one-hot encoded) for x0 and y. + # This is a tensor of shape (B, N0 + N1, 2). 
+            labels = torch.cat(
+                [
+                    torch.zeros(B, N0, 1, device=x.device),
+                    torch.ones(B, N1, 1, device=x.device),
+                ],
+                dim=1,
+            )
+            # Make one-hot
+            labels = torch.cat([1 - labels, labels], dim=-1)
+        else:
+            labels = None
+
+        # Reshape to (BxN, C), and create a batch vector with the indices.
+        # Right now assuming that the batch has the same number of points for each example.
+        x0_flat = x0.reshape(-1, D)
+        x_flat = x.reshape(-1, C_in)
+        if labels is not None:
+            labels = labels.reshape(-1, labels.shape[-1])
+        batch_ixs = torch.repeat_interleave(
+            torch.arange(B, device=x0.device), x0.shape[1]
+        )
+
+        if y is not None:
+            cat_feat = torch.cat([x0_flat, x_flat, labels], dim=-1)
+        else:
+            cat_feat = torch.cat([x0_flat, x_flat], dim=-1)
+
+        data = Point(
+            coord=x0_flat,
+            feat=cat_feat,
+            batch=batch_ixs,
+            grid_size=self.grid_size,  # Not sure what to do here...
+            t=t,
+            N=N0 + N1,
+        )
+        pred = self.model(data)
+        C_out = pred.feat.shape[-1]
+
+        # Only need the action points.
+        # Reshape back to (B, N, C)
+        feats = pred.feat.reshape(B, -1, C_out)
+
+        # Final linear layer to get the output.
+        feats = self.final_layer(feats, pred.t_emb)
+
+        feats = feats[:, :N, :]
+
+        # Permute back to what we expect.
+        feats = feats.permute(0, 2, 1)
+        return feats  # type: ignore
diff --git a/tests/diptv3_test.py b/tests/diptv3_test.py
new file mode 100644
index 0000000..2918c2b
--- /dev/null
+++ b/tests/diptv3_test.py
@@ -0,0 +1,33 @@
+import pytest
+import torch
+
+from rpad.nets.diptv3 import DiPTv3, DiPTv3Adapter
+
+
+# Skip if cuda not detected.
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_diptv3():
+    # Create a model
+    final_dimension = 11
+    B = 3
+    N = 256
+    model = DiPTv3Adapter(
+        model=DiPTv3(in_channels=6), final_dimension=final_dimension
+    ).cuda()
+
+    # Create a random input tensor
+    x = torch.rand(B, 3, N).cuda().float()
+    t = (
+        torch.rand(
+            B,
+        )
+        .cuda()
+        .float()
+    )
+    x0 = torch.rand(B, 3, N).cuda().float()
+
+    # Run the model
+    output = model(x, t, x0)
+
+    # Check the output shape
+    assert output.shape == (B, final_dimension, N)
diff --git a/tests/simple_test.py b/tests/simple_test.py
deleted file mode 100644
index f010262..0000000
--- a/tests/simple_test.py
+++ /dev/null
@@ -1,2 +0,0 @@
-def test_simple():
-    assert 1 == 1
diff --git a/third_party/DiT/__init__.py b/third_party/DiT/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/third_party/DiT/models.py b/third_party/DiT/models.py
new file mode 100644
index 0000000..c90eeba
--- /dev/null
+++ b/third_party/DiT/models.py
@@ -0,0 +1,370 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import numpy as np +import math +from timm.models.vision_transformer import PatchEmbed, Attention, Mlp + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +################################################################################# +# Core DiT Model # +################################################################################# + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
+ """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList([ + DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, 
std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def forward(self, x, t, y): + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(t) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # (N, D) + for block in self.blocks: + x = block(x, c) # (N, T, D) + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_cfg(self, x, t, y, cfg_scale): + """ + Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, y) + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. 
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# DiT Configs # +################################################################################# + +def DiT_XL_2(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) + +def DiT_XL_4(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) + +def DiT_XL_8(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) + +def DiT_L_2(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) + +def DiT_L_4(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) + +def DiT_L_8(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) + +def DiT_B_2(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) + +def DiT_B_4(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) + +def DiT_B_8(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) + +def DiT_S_2(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) + +def DiT_S_4(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) + +def DiT_S_8(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) + + +DiT_models = { + 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, + 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, + 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, + 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, +} diff --git a/third_party/PointTransformerV3/serialization/hilbert.py b/third_party/PointTransformerV3/serialization/hilbert.py index 682be19..56762da 100644 --- a/third_party/PointTransformerV3/serialization/hilbert.py +++ b/third_party/PointTransformerV3/serialization/hilbert.py @@ -140,7 +140,7 @@ def encode(locs, num_dims, num_bits): # Treat the location integers as 64-bit unsigned and then split them up into # a sequence of uint8s. Preserve the association by dimension. - locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1) + locs_uint8 = locs.long().contiguous().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1) # Now turn these into bits and truncate to num_bits. gray = (