-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit b862e1c
Showing
59 changed files
with
1,828 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
GENERAL: | ||
experiment: all | ||
seed: 0 | ||
|
||
DATA: | ||
# data path | ||
root_dir: F:/dataset/Teeth3DS/data | ||
split_dir: F:/dataset/Teeth3DS/split | ||
# batch_size per gpu | ||
batch_size: 2 | ||
# sample | ||
num_points: 10000 | ||
# augmentation | ||
augmentation: True | ||
# upper/lower | ||
|
||
STRUCTURE: | ||
k: 20 | ||
input_channels: 15 | ||
output_channels: 17 | ||
query_num: 80 | ||
n_edgeconvs_backbone: 5 | ||
emb_dims: 1024 | ||
global_pool_backbone: avg # max or avg | ||
norm: instance | ||
use_stn: True # False # spatial transformer network | ||
dynamic: False | ||
dropout: 0. | ||
|
||
TRAIN: | ||
max_epochs: 200 | ||
weight_decay: 0.0001 | ||
delta: 0.1667 | ||
load_from_checkpoint: | ||
resume_from_checkpoint: | ||
|
||
# one cycle lr scheduler | ||
lr_max: 0.001 | ||
pct_start: 0.1 # percentage of the cycle spent increasing lr | ||
div_factor: 25 # determine the initial lr (lr_max / div_factor) | ||
final_div_factor: 1e4 # determine the final lr (lr_max / final_div_factor) | ||
start_epoch: 0 | ||
|
||
train_file: training_all.txt | ||
train_workers: 1 | ||
|
||
val_workers: 1 | ||
val_file: testing_all.txt | ||
|
||
test_workers: 1 | ||
test_file: testing_all.txt |
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import torch | ||
|
||
|
||
def calc_features(vs, ts): | ||
""" | ||
:param vs: (nv, 3), float | ||
:param ts: (nf, 3), long | ||
:return: fea: (15, nf), float | ||
""" | ||
nf = ts.shape[0] | ||
fea = torch.empty((15, nf), dtype=torch.float32).to(vs.device) | ||
|
||
vs_in_ts = vs[ts] | ||
|
||
fea[:3, :] = vs_in_ts.mean(1).T[None] # centers , 3 | ||
fea[3:6, :] = calc_normals(vs, ts).T[None] # normal, 3 | ||
fea[6:15, :] = (vs_in_ts - vs_in_ts.mean(1, keepdim=True)).reshape((nf, -1)).T[None] | ||
return fea | ||
|
||
|
||
def calc_normals(vs: torch.Tensor, ts: torch.Tensor): | ||
""" | ||
:param vs: (n_v, 3) | ||
:param ts: (n_f, 3), long | ||
:return normals: (n_f, 3), float | ||
""" | ||
normals = torch.cross(vs[ts[:, 1]] - vs[ts[:, 0]], | ||
vs[ts[:, 2]] - vs[ts[:, 0]]) | ||
normals /= torch.sum(normals ** 2, 1, keepdims=True) ** 0.5 + 1e-9 | ||
return normals |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import os | ||
import numpy as np | ||
import random | ||
import json | ||
|
||
import torch | ||
import trimesh | ||
from torch.utils.data import Dataset | ||
|
||
from data.common import calc_features | ||
from utils.data_utils import augment, get_offsets, get_centroids, get_masks, get_bmap | ||
|
||
|
||
class Teeth3DS(Dataset): | ||
def __init__(self, args, split_file: str, train: bool): | ||
# args | ||
self.args = args | ||
self.root_dir = args.root_dir | ||
self.num_points = args.num_points | ||
self.augmentation = args.augmentation if train else False | ||
# files | ||
self.files = [] | ||
with open(os.path.join(args.split_dir, split_file)) as f: | ||
for line in f: | ||
filename = line.strip().split('_')[0] | ||
category = line.strip().split('_')[1] | ||
root = os.path.join(self.root_dir, category, filename) | ||
obj_file = os.path.join(root, f'{line.strip()}_sim.off') | ||
json_file = os.path.join(root, f'{line.strip()}_sim_re.txt') | ||
dmap_file = os.path.join(self.root_dir, category + '_fdmap', f'{line.strip()}_sim.txt') | ||
bmap_file = os.path.join(self.root_dir, category + '_fbmap', f'{line.strip()}_sim.txt') | ||
if os.path.exists(obj_file) and os.path.exists(json_file): | ||
self.files.append((obj_file, json_file, dmap_file, bmap_file)) | ||
random.shuffle(self.files) | ||
|
||
def __len__(self): | ||
return len(self.files) | ||
|
||
def __getitem__(self, idx): | ||
obj_file, json_file, dmap_file, bmap_file = self.files[idx] | ||
|
||
mesh = trimesh.load(obj_file) | ||
vs, fs = mesh.vertices, mesh.faces | ||
labels = np.loadtxt(json_file, dtype=np.int32) | ||
dmap = np.loadtxt(dmap_file, dtype=np.float32) # [-1 1] | ||
bmap = np.loadtxt(bmap_file, dtype=np.float32) # boundary 1 other 0 | ||
# 边界点为1 牙齿点为2 牙龈点为0 | ||
b_idx = np.argwhere(bmap == 1)[:, 0] | ||
bmap = np.ones(bmap.shape, dtype='int64') | ||
gum_idx = np.argwhere(labels == 0)[:, 0] | ||
bmap[gum_idx] = -1 | ||
bmap[b_idx] = 0 | ||
bmap = bmap + 1 | ||
|
||
# augmentation | ||
if self.augmentation: | ||
vs, fs = augment(vs, fs) | ||
# sample | ||
_, fids = trimesh.sample.sample_surface_even(mesh, self.num_points) | ||
|
||
fs, labels = fs[fids], labels[fids] | ||
bmap, dmap = bmap[fids], dmap[np.newaxis, fids] | ||
# extract input features | ||
vs = torch.tensor(vs, dtype=torch.float32) | ||
vs = vs - vs.mean(0) # preprocess | ||
fs = torch.tensor(fs, dtype=torch.long) | ||
features = calc_features(vs, fs) # (15, nf) | ||
labels = np.array(labels, dtype='float64').squeeze() | ||
|
||
cs = np.array(features.T[:, :3]) | ||
ins_masks, ins_labels, ins_xyz = get_masks(cs, labels) # label:[] xyz:[] | ||
|
||
return features, torch.tensor(labels, dtype=torch.long), \ | ||
torch.tensor(bmap, dtype=torch.long), torch.tensor(dmap, dtype=torch.float32), \ | ||
torch.tensor(ins_masks, dtype=torch.float32), torch.tensor(ins_labels, dtype=torch.float32) | ||
|
||
|
||
if __name__ == '__main__': | ||
class Args(object): | ||
def __init__(self): | ||
self.root_dir = 'F:/dataset/Teeth3DS/data' | ||
self.split_dir = 'F:/dataset/Teeth3DS/split' | ||
self.num_points = 10000 | ||
self.augmentation = True | ||
|
||
|
||
data = Teeth3DS(Args(), 'training_upper.txt', True) | ||
i = 0 | ||
for f, l, b, d, ins_m, ins_l in data: | ||
print(f.shape) | ||
print(b.shape) | ||
print(d.shape) | ||
i += 1 | ||
|
||
print(i) |
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
import torch | ||
import torch.nn as nn | ||
from .dgcnn_utils import STN, Backbone, SharedMLP1d, EdgeConv, knn | ||
from .fps_utils import center_fps | ||
# from pointnet2_ops.pointnet2_utils import furthest_point_sample | ||
|
||
|
||
class CBANet(nn.Module): | ||
def __init__(self, args): | ||
super(CBANet, self).__init__() | ||
self.num_points = args.num_points | ||
self.query_num = args.query_num | ||
self.out_channels = args.output_channels | ||
self.k = args.k | ||
if args.use_stn: | ||
self.stn = STN(args.k, args.norm) | ||
self.backbone = Backbone(args) | ||
|
||
self.bmap_decoder = SharedMLP1d([1344, 256], args.norm) | ||
self.bmap_head = nn.Sequential(SharedMLP1d([256, 256], args.norm), | ||
nn.Dropout(args.dropout), | ||
SharedMLP1d([256, 128], args.norm), | ||
nn.Conv1d(128, 3, kernel_size=1)) | ||
self.dmap_decoder = SharedMLP1d([1344, 256], args.norm) | ||
self.dmap_head = nn.Sequential(SharedMLP1d([256, 256], args.norm), | ||
nn.Dropout1d(args.dropout), | ||
SharedMLP1d([256, 128], args.norm), | ||
nn.Conv1d(128, 1, kernel_size=1)) | ||
|
||
self.decoder = SharedMLP1d([1344, 256], args.norm) | ||
self.mask_decoder = SharedMLP1d([256, 256], args.norm) | ||
|
||
self.homo_conv = EdgeConv([256 * 2, 256], self.k, args.norm) | ||
self.sp_conv = EdgeConv([256 * 2, 256], self.k, args.norm) | ||
|
||
self.class_head = nn.Sequential(SharedMLP1d([256, 256], args.norm), | ||
nn.Dropout(args.dropout), | ||
SharedMLP1d([256, 128], args.norm), | ||
nn.Conv1d(128, args.output_channels, kernel_size=1)) | ||
self.score_head = nn.Sequential(SharedMLP1d([256, 256], args.norm), | ||
nn.Dropout(args.dropout), | ||
SharedMLP1d([256, 128], args.norm), | ||
nn.Conv1d(128, 1, kernel_size=1)) | ||
self.mask_head = nn.Sequential(SharedMLP1d([self.query_num, self.query_num], args.norm), | ||
nn.Dropout(args.dropout), | ||
nn.Conv1d(self.query_num, self.query_num, kernel_size=1), | ||
nn.Sigmoid()) | ||
|
||
def forward(self, x, eid): | ||
device = x.device | ||
batch_size = x.size(0) | ||
p = x[:, :3, :].contiguous() # xyz | ||
|
||
if hasattr(self, "stn"): | ||
if not hasattr(self, 'c'): | ||
self.c = torch.zeros((x.shape[0], 15, 15), dtype=torch.float32, device=device) | ||
for i in range(0, 15, 3): | ||
self.c[:, i:i + 3, i:i + 3] = 1 | ||
|
||
t = self.stn(x[:, :3, :].contiguous()) | ||
t = t.repeat(1, 5, 5) # (batch_size, 15, 15) | ||
t1 = self.c * t | ||
x = torch.bmm(t1, x) | ||
else: | ||
t = torch.ones((1, 1), device=device) | ||
|
||
feats = self.backbone(x) # (b, 1344, 10000) | ||
|
||
bmap_feats = self.bmap_decoder(feats) | ||
dmap_feats = self.dmap_decoder(feats) | ||
bmap_out = self.bmap_head(bmap_feats) | ||
dmap_out = self.dmap_head(dmap_feats) | ||
|
||
# #==========================mask stage=============================================== | ||
sp_masks, sp_probs, g_ps = torch.zeros([batch_size, self.num_points, self.query_num]), torch.zeros([batch_size, 17, self.query_num]), torch.zeros([batch_size, 3, self.query_num]) | ||
score = torch.zeros([batch_size, 1, self.query_num]) | ||
all_idx = torch.zeros([batch_size, self.query_num]) # sampled idx | ||
if eid > 19: | ||
|
||
feats = self.decoder(feats) # (b, 256, 10000) | ||
mask_feats = self.mask_decoder(feats) # (b, 256 * 3, 10000) | ||
|
||
dmap = dmap_out.detach() | ||
bmap = bmap_out.detach().argmax(1) # (b, 1, 10000) | ||
|
||
g_fs = [] | ||
g_ps = [] | ||
all_idx = [] | ||
for i in range(0, batch_size): | ||
dm = dmap[i].squeeze() | ||
t_idx = torch.nonzero(dm > 0.2).squeeze() # M tooth points | ||
ps = p[i].T[t_idx] # (M, 3) | ||
fs = feats[i].T[t_idx] # (M, 256) | ||
ps_idx = knn(ps.T.unsqueeze(0).contiguous(), self.k) # 1, 20, M | ||
fs = self.homo_conv(fs.T.unsqueeze(0).contiguous(), ps_idx) # 1, 256, M | ||
|
||
# select_idx = furthest_point_sample(ps.unsqueeze(0).contiguous(), self.query_num).squeeze().to(torch.long) | ||
select_idx = center_fps(dm[t_idx], ps, self.query_num) | ||
select_ps = ps[select_idx].T.unsqueeze(0) # 1, 3, 100 | ||
select_fs = fs[:, :, select_idx] # 1, 256, 100 | ||
|
||
g_fs.append(select_fs) | ||
g_ps.append(select_ps) | ||
sampl_idx = t_idx[select_idx] | ||
all_idx.append(sampl_idx) | ||
|
||
g_fs = torch.concat(g_fs, dim=0) # b, 256, 100 | ||
g_ps = torch.concat(g_ps, dim=0) # b, 3, 100 | ||
|
||
idx = knn(g_ps.contiguous(), self.k) | ||
g_feats = self.sp_conv(g_fs, idx) # b, 256, 100 | ||
|
||
mask_feats = torch.einsum('bdn,bdm->bnm', g_feats, mask_feats) # b, 100, 10000 | ||
|
||
sp_probs = self.class_head(g_feats) # b, 17, 100 | ||
score = self.score_head(g_feats) # b, 1, 100 | ||
sp_masks = self.mask_head(mask_feats) | ||
|
||
return bmap_out, dmap_out, sp_masks, sp_probs, score, all_idx | ||
|
||
def inference(self, x): | ||
device = x.device | ||
batch_size = x.size(0) | ||
p = x[:, :3, :].contiguous() # xyz | ||
|
||
if hasattr(self, "stn"): | ||
if not hasattr(self, 'c'): | ||
self.c = torch.zeros((x.shape[0], 15, 15), dtype=torch.float32, device=device) | ||
for i in range(0, 15, 3): | ||
self.c[:, i:i + 3, i:i + 3] = 1 | ||
|
||
t = self.stn(x[:, :3, :].contiguous()) | ||
t = t.repeat(1, 5, 5) # (batch_size, 15, 15) | ||
t1 = self.c * t | ||
x = torch.bmm(t1, x) | ||
else: | ||
t = torch.ones((1, 1), device=device) | ||
|
||
feats = self.backbone(x) # (b, 1344, 10000) | ||
|
||
bmap_feats = self.bmap_decoder(feats) | ||
dmap_feats = self.dmap_decoder(feats) | ||
bmap_out = self.bmap_head(bmap_feats) | ||
dmap_out = self.dmap_head(dmap_feats) | ||
|
||
# #==========================mask stage=============================================== | ||
|
||
feats = self.decoder(feats) # (b, 256, 10000) | ||
mask_feats = self.mask_decoder(feats) # (b, 256 * 3, 10000) | ||
|
||
dmap = dmap_out.detach() | ||
bmap = bmap_out.detach().argmax(1) # (b, 1, 10000) | ||
|
||
g_fs = [] | ||
g_ps = [] | ||
all_idx = [] | ||
for i in range(0, batch_size): | ||
dm = dmap[i].squeeze() | ||
t_idx = torch.nonzero(dm > 0.2).squeeze() # M tooth points | ||
ps = p[i].T[t_idx] # (M, 3) | ||
fs = feats[i].T[t_idx] # (M, 256) | ||
ps_idx = knn(ps.T.unsqueeze(0).contiguous(), self.k) # 1, 30, M | ||
fs = self.homo_conv(fs.T.unsqueeze(0).contiguous(), ps_idx) # 1, 256, M | ||
|
||
# select_idx = furthest_point_sample(ps.unsqueeze(0).contiguous(), self.query_num).squeeze().to(torch.long) | ||
select_idx = center_fps(dm[t_idx], ps, self.query_num) | ||
select_ps = ps[select_idx].T.unsqueeze(0) # 1, 3, 100 | ||
select_fs = fs[:, :, select_idx] # 1, 256, 100 | ||
|
||
g_fs.append(select_fs) | ||
g_ps.append(select_ps) | ||
sampl_idx = t_idx[select_idx] | ||
all_idx.append(sampl_idx) | ||
|
||
g_fs = torch.concat(g_fs, dim=0) # b, 256, 100 | ||
g_ps = torch.concat(g_ps, dim=0) # b, 3, 100 | ||
|
||
idx = knn(g_ps.contiguous(), self.k) | ||
g_feats = self.sp_conv(g_fs, idx) # b, 256, 100 | ||
|
||
mask_feats = torch.einsum('bdn,bdm->bnm', g_feats, mask_feats) # b, 100, 10000 | ||
|
||
sp_probs = self.class_head(g_feats) # b, 17, 100 | ||
score = self.score_head(g_feats) # b, 1, 100 | ||
sp_masks = self.mask_head(mask_feats) | ||
|
||
return bmap, dmap_out, sp_masks, sp_probs, score | ||
|
||
|
||
|
Oops, something went wrong.