Skip to content

Commit

Permalink
Add missing folder, fix #1
Browse files Browse the repository at this point in the history
  • Loading branch information
nachovizzo committed Jul 14, 2022
1 parent 0da82e8 commit 8dd65be
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 0 deletions.
Empty file added models/.gitkeep
Empty file.
3 changes: 3 additions & 0 deletions src/make_it_dense/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .atlas import *
from .blocks3d import *
from .completion_net import *
49 changes: 49 additions & 0 deletions src/make_it_dense/models/atlas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Inspired by https://github.com/magicleap/Atlas
import torch
import torch.nn as nn
import torch.nn.functional as F

from make_it_dense.models.blocks3d import FeatureExtractor, Unet3D
from make_it_dense.utils.config import MkdConfig


class AtlasNet(nn.Module):
def __init__(self, config: MkdConfig):
super().__init__()
self.config = config
self.voxel_sizes = self.config.fusion.voxel_sizes
self.occ_th = self.config.model.occ_th
self.f_maps = self.config.model.f_maps
self.layers_down = self.config.model.layers_down
self.layers_up = self.config.model.layers_up

# Network
self.feature_extractor = FeatureExtractor(channels=self.f_maps[0])
self.unet = Unet3D(
channels=self.f_maps,
layers_down=self.layers_down,
layers_up=self.layers_up,
)
self.decoders = nn.ModuleList(
[nn.Conv3d(c, 1, 1, bias=False) for c in self.f_maps[:-1]][::-1]
)

def forward(self, xs):
feats = self.feature_extractor(xs)
out = self.unet(feats)

output = {}
mask_occupied = []
for i, (decoder, x) in enumerate(zip(self.decoders, out)):
# regress the TSDF
tsdf = torch.tanh(decoder(x)) * 1.05

# use previous scale to sparsify current scale
if i > 0:
tsdf_prev = output[f"out_tsdf_{self.voxel_sizes[i - 1]}"]
tsdf_prev = F.interpolate(tsdf_prev, scale_factor=2)
mask_truncated = tsdf_prev.abs() >= self.occ_th[i - 1]
tsdf[mask_truncated] = tsdf_prev[mask_truncated].sign()
mask_occupied.append(~mask_truncated)
output[f"out_tsdf_{ self.voxel_sizes[i]}"] = tsdf
return output, mask_occupied
93 changes: 93 additions & 0 deletions src/make_it_dense/models/blocks3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import List

import torch.nn as nn


class ResNetBlock3d(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm3d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm3d(out_channels)

def forward(self, x):
identity = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += identity
out = self.relu(out)

return out


class Unet3D(nn.Module):
def __init__(self, channels: List[int], layers_down: List[int], layers_up: List[int]):
super().__init__()

self.layers_down = nn.ModuleList()
self.layers_down.append(ResNetBlock3d(channels[0], channels[0]))
for i in range(1, len(channels)):
layer = [
nn.Conv3d(
channels[i - 1], channels[i], kernel_size=3, stride=2, padding=1, bias=False
),
nn.BatchNorm3d(channels[i]),
nn.ReLU(inplace=True),
]
# Do we need 4 resnet blocks here?
layer += [ResNetBlock3d(channels[i], channels[i]) for _ in range(layers_down[i])]
self.layers_down.append(nn.Sequential(*layer))

channels = channels[::-1]
self.layers_up_conv = nn.ModuleList()
for i in range(1, len(channels)):
self.layers_up_conv.append(
nn.Sequential(
nn.ConvTranspose3d(
channels[i - 1], channels[i], kernel_size=2, stride=2, bias=False
),
nn.BatchNorm3d(channels[i]),
nn.ReLU(inplace=True),
nn.Conv3d(channels[i], channels[i], kernel_size=3, padding=1, bias=False),
nn.BatchNorm3d(channels[i]),
nn.ReLU(inplace=True),
)
)

self.layers_up_res = nn.ModuleList()
for i in range(1, len(channels)):
layer = [ResNetBlock3d(channels[i], channels[i]) for _ in range(layers_up[i - 1])]
self.layers_up_res.append(nn.Sequential(*layer))

def forward(self, x):
xs = []
for layer in self.layers_down:
x = layer(x)
xs.append(x)

xs.reverse()
out = []
for i in range(len(self.layers_up_conv)):
x = self.layers_up_conv[i](x)
x = (x + xs[i + 1]) / 2.0
x = self.layers_up_res[i](x)
out.append(x)

return out


class FeatureExtractor(nn.Module):
"""Extract features from a TSDF volume withouth chaning the size."""

def __init__(self, channels=4):
super().__init__()
self.model = ResNetBlock3d(1, channels)

def forward(self, x):
return self.model(x)
64 changes: 64 additions & 0 deletions src/make_it_dense/models/completion_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytorch_lightning as pl
import torch

from make_it_dense.loss import SDFLoss
from make_it_dense.models.atlas import AtlasNet
from make_it_dense.utils.config import MkdConfig


class CompletionNet(pl.LightningModule):
def __init__(self, config: MkdConfig):
super().__init__()
self.config = config
self.model = AtlasNet(self.config)
self.loss = SDFLoss(self.config)
self.lr = self.config.optimization.lr
self.voxel_sizes = self.config.fusion.voxel_sizes
self.voxel_trunc = self.config.fusion.voxel_trunc
self.save_hyperparameters(config)

def forward(self, x):
return self.model(x)

def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.lr,
weight_decay=self.config.optimization.weight_decay,
)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return {
"optimizer": optimizer,
"lr_scheduler": lr_scheduler,
"monitor": "train/train_loss",
}

def step(self, batch, batch_idx, mode: str):
inputs, targets = batch
outputs, masks = self(inputs["nodes"])
pred_tsdf_t = outputs["out_tsdf_10"]

# Compute Loss function
losses = self.loss(outputs, masks, targets)
loss = sum(losses.values())

self.log(mode + "/train_loss", loss)
for key in losses.keys():
self.log(mode + "/losses/" + key, losses[key])

# Log some metrics
self.log(mode + "/metrics/max_sdf", pred_tsdf_t.max())
self.log(mode + "/metrics/min_sdf", pred_tsdf_t.min())
self.log(mode + "/metrics/mean_sdf", pred_tsdf_t.mean())

return loss

def training_step(self, train_batch, batch_idx):
return self.step(train_batch, batch_idx, mode="train")

def validation_step(self, val_batch, batch_idx):
self.step(val_batch, batch_idx, mode="val")

@torch.no_grad()
def predict_step(self, input_tsdf_t):
return self(input_tsdf_t)[0]

0 comments on commit 8dd65be

Please sign in to comment.