Commit 8dd65be (1 parent: 0da82e8)
Showing 5 changed files with 209 additions and 0 deletions.
Empty file.
make_it_dense/models/__init__.py
@@ -0,0 +1,3 @@
from .atlas import *
from .blocks3d import *
from .completion_net import *
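These wildcard imports re-export the public names of the three modules at the package level. A minimal sketch of the resulting import surface (not part of the diff; the names are taken from the files added in this commit):

# Hypothetical usage, not part of this commit:
from make_it_dense.models import AtlasNet, CompletionNet, FeatureExtractor, Unet3D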
make_it_dense/models/atlas.py
@@ -0,0 +1,49 @@
# Inspired by https://github.com/magicleap/Atlas
import torch
import torch.nn as nn
import torch.nn.functional as F

from make_it_dense.models.blocks3d import FeatureExtractor, Unet3D
from make_it_dense.utils.config import MkdConfig


class AtlasNet(nn.Module):
    def __init__(self, config: MkdConfig):
        super().__init__()
        self.config = config
        self.voxel_sizes = self.config.fusion.voxel_sizes
        self.occ_th = self.config.model.occ_th
        self.f_maps = self.config.model.f_maps
        self.layers_down = self.config.model.layers_down
        self.layers_up = self.config.model.layers_up

        # Network
        self.feature_extractor = FeatureExtractor(channels=self.f_maps[0])
        self.unet = Unet3D(
            channels=self.f_maps,
            layers_down=self.layers_down,
            layers_up=self.layers_up,
        )
        self.decoders = nn.ModuleList(
            [nn.Conv3d(c, 1, 1, bias=False) for c in self.f_maps[:-1]][::-1]
        )

    def forward(self, xs):
        feats = self.feature_extractor(xs)
        out = self.unet(feats)

        output = {}
        mask_occupied = []
        for i, (decoder, x) in enumerate(zip(self.decoders, out)):
            # regress the TSDF
            tsdf = torch.tanh(decoder(x)) * 1.05

            # use previous scale to sparsify current scale
            if i > 0:
                tsdf_prev = output[f"out_tsdf_{self.voxel_sizes[i - 1]}"]
                tsdf_prev = F.interpolate(tsdf_prev, scale_factor=2)
                mask_truncated = tsdf_prev.abs() >= self.occ_th[i - 1]
                tsdf[mask_truncated] = tsdf_prev[mask_truncated].sign()
                mask_occupied.append(~mask_truncated)
            output[f"out_tsdf_{self.voxel_sizes[i]}"] = tsdf
        return output, mask_occupied
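A minimal smoke-test sketch of AtlasNet, not part of the diff. All concrete config values below are assumptions chosen to be mutually consistent (f_maps defines the Unet3D levels, whose three decoder outputs must match the three 1x1x1 decoders, occ_th thresholds, and voxel_sizes); a real run would use a populated MkdConfig.

# Hypothetical usage with stand-in config values (assumptions, not
# values from the repository):
from types import SimpleNamespace

import torch

config = SimpleNamespace(
    fusion=SimpleNamespace(voxel_sizes=[40, 20, 10]),
    model=SimpleNamespace(
        f_maps=[4, 8, 16, 32],    # one entry per resolution level
        occ_th=[0.95, 0.95],      # sparsification threshold per coarse scale
        layers_down=[1, 1, 1, 1],
        layers_up=[1, 1, 1],
    ),
)
net = AtlasNet(config)
x = torch.rand(1, 1, 32, 32, 32)  # (B, 1, D, H, W) input TSDF volume
output, mask_occupied = net(x)
print(sorted(output))             # ['out_tsdf_10', 'out_tsdf_20', 'out_tsdf_40']

The returned dictionary holds one TSDF grid per scale, coarsest first as produced; each finer scale is sparsified wherever the previous (coarser) scale was already beyond its truncation threshold.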
make_it_dense/models/blocks3d.py
@@ -0,0 +1,93 @@
from typing import List

import torch.nn as nn


class ResNetBlock3d(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # NOTE: the residual add relies on channel broadcasting when
        # in_channels == 1 (as used by FeatureExtractor); otherwise
        # in_channels must equal out_channels.
        out += identity
        out = self.relu(out)

        return out


class Unet3D(nn.Module):
    def __init__(self, channels: List[int], layers_down: List[int], layers_up: List[int]):
        super().__init__()

        self.layers_down = nn.ModuleList()
        self.layers_down.append(ResNetBlock3d(channels[0], channels[0]))
        for i in range(1, len(channels)):
            layer = [
                nn.Conv3d(
                    channels[i - 1], channels[i], kernel_size=3, stride=2, padding=1, bias=False
                ),
                nn.BatchNorm3d(channels[i]),
                nn.ReLU(inplace=True),
            ]
            # Do we need 4 resnet blocks here?
            layer += [ResNetBlock3d(channels[i], channels[i]) for _ in range(layers_down[i])]
            self.layers_down.append(nn.Sequential(*layer))

        channels = channels[::-1]
        self.layers_up_conv = nn.ModuleList()
        for i in range(1, len(channels)):
            self.layers_up_conv.append(
                nn.Sequential(
                    nn.ConvTranspose3d(
                        channels[i - 1], channels[i], kernel_size=2, stride=2, bias=False
                    ),
                    nn.BatchNorm3d(channels[i]),
                    nn.ReLU(inplace=True),
                    nn.Conv3d(channels[i], channels[i], kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm3d(channels[i]),
                    nn.ReLU(inplace=True),
                )
            )

        self.layers_up_res = nn.ModuleList()
        for i in range(1, len(channels)):
            layer = [ResNetBlock3d(channels[i], channels[i]) for _ in range(layers_up[i - 1])]
            self.layers_up_res.append(nn.Sequential(*layer))

    def forward(self, x):
        xs = []
        for layer in self.layers_down:
            x = layer(x)
            xs.append(x)

        xs.reverse()
        out = []
        for i in range(len(self.layers_up_conv)):
            x = self.layers_up_conv[i](x)
            x = (x + xs[i + 1]) / 2.0
            x = self.layers_up_res[i](x)
            out.append(x)

        return out


class FeatureExtractor(nn.Module):
    """Extract features from a TSDF volume without changing its size."""

    def __init__(self, channels=4):
        super().__init__()
        self.model = ResNetBlock3d(1, channels)

    def forward(self, x):
        return self.model(x)
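A shape-flow sketch for Unet3D, not part of the diff; the channel and layer counts are illustrative assumptions. Each encoder stage halves the resolution with a stride-2 convolution, so the input volume must be divisible by 2**(len(channels) - 1), and the decoder returns one feature map per upsampling stage, coarsest first, which is exactly what AtlasNet's per-scale decoders consume.

# Illustrative shapes only; channels/layers here are assumptions.
import torch

net = Unet3D(channels=[4, 8, 16, 32], layers_down=[1, 1, 1, 1], layers_up=[1, 1, 1])
x = torch.rand(2, 4, 32, 32, 32)  # (B, channels[0], D, H, W)
outs = net(x)
for o in outs:
    print(tuple(o.shape))
# (2, 16, 8, 8, 8)     coarsest decoder output
# (2, 8, 16, 16, 16)
# (2, 4, 32, 32, 32)   finest, back at the input resolution

Note the skip connections are averaged, (x + xs[i + 1]) / 2.0, rather than concatenated, which keeps the channel counts of the up path identical to the mirrored down path.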
make_it_dense/models/completion_net.py
@@ -0,0 +1,64 @@
import pytorch_lightning as pl
import torch

from make_it_dense.loss import SDFLoss
from make_it_dense.models.atlas import AtlasNet
from make_it_dense.utils.config import MkdConfig


class CompletionNet(pl.LightningModule):
    def __init__(self, config: MkdConfig):
        super().__init__()
        self.config = config
        self.model = AtlasNet(self.config)
        self.loss = SDFLoss(self.config)
        self.lr = self.config.optimization.lr
        self.voxel_sizes = self.config.fusion.voxel_sizes
        self.voxel_trunc = self.config.fusion.voxel_trunc
        self.save_hyperparameters(config)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.config.optimization.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "train/train_loss",
        }

    def step(self, batch, batch_idx, mode: str):
        inputs, targets = batch
        outputs, masks = self(inputs["nodes"])
        pred_tsdf_t = outputs["out_tsdf_10"]

        # Compute loss function
        losses = self.loss(outputs, masks, targets)
        loss = sum(losses.values())

        self.log(mode + "/train_loss", loss)
        for key in losses.keys():
            self.log(mode + "/losses/" + key, losses[key])

        # Log some metrics
        self.log(mode + "/metrics/max_sdf", pred_tsdf_t.max())
        self.log(mode + "/metrics/min_sdf", pred_tsdf_t.min())
        self.log(mode + "/metrics/mean_sdf", pred_tsdf_t.mean())

        return loss

    def training_step(self, train_batch, batch_idx):
        return self.step(train_batch, batch_idx, mode="train")

    def validation_step(self, val_batch, batch_idx):
        self.step(val_batch, batch_idx, mode="val")

    @torch.no_grad()
    def predict_step(self, input_tsdf_t):
        return self(input_tsdf_t)[0]
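A hedged sketch of how CompletionNet plugs into a standard Lightning loop; config and the data loaders are assumed to come from elsewhere in the repository and are not part of this commit.

# Hypothetical training entry point (config, train_loader, and
# val_loader are assumptions, defined elsewhere in the repository):
import pytorch_lightning as pl

model = CompletionNet(config)  # config: a populated MkdConfig
trainer = pl.Trainer(max_epochs=100)
trainer.fit(model, train_loader, val_loader)

Since configure_optimizers returns a ReduceLROnPlateau scheduler with monitor set to "train/train_loss", the learning rate decays whenever the logged training loss plateaus.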