Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
bibi547 committed Dec 29, 2024
0 parents commit b862e1c
Show file tree
Hide file tree
Showing 59 changed files with 1,828 additions and 0 deletions.
51 changes: 51 additions & 0 deletions config/all_tooth.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
GENERAL:
experiment: all
seed: 0

DATA:
# data path
root_dir: F:/dataset/Teeth3DS/data
split_dir: F:/dataset/Teeth3DS/split
# batch_size per gpu
batch_size: 2
# sample
num_points: 10000
# augmentation
augmentation: True
# upper/lower

STRUCTURE:
k: 20
input_channels: 15
output_channels: 17
query_num: 80
n_edgeconvs_backbone: 5
emb_dims: 1024
global_pool_backbone: avg # max or avg
norm: instance
use_stn: True # False # spatial transformer network
dynamic: False
dropout: 0.

TRAIN:
max_epochs: 200
weight_decay: 0.0001
delta: 0.1667
load_from_checkpoint:
resume_from_checkpoint:

# one cycle lr scheduler
lr_max: 0.001
pct_start: 0.1 # percentage of the cycle spent increasing lr
div_factor: 25 # determine the initial lr (lr_max / div_factor)
final_div_factor: 1e4 # determine the final lr (lr_max / final_div_factor)
start_epoch: 0

train_file: training_all.txt
train_workers: 1

val_workers: 1
val_file: testing_all.txt

test_workers: 1
test_file: testing_all.txt
Empty file added data/__init__.py
Empty file.
Binary file added data/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added data/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added data/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file added data/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added data/__pycache__/common.cpython-310.pyc
Binary file not shown.
Binary file added data/__pycache__/common.cpython-37.pyc
Binary file not shown.
Binary file added data/__pycache__/common.cpython-38.pyc
Binary file not shown.
Binary file added data/__pycache__/common.cpython-39.pyc
Binary file not shown.
Binary file added data/__pycache__/human.cpython-37.pyc
Binary file not shown.
Binary file added data/__pycache__/st_data.cpython-310.pyc
Binary file not shown.
Binary file added data/__pycache__/st_data.cpython-37.pyc
Binary file not shown.
Binary file added data/__pycache__/st_data.cpython-38.pyc
Binary file not shown.
Binary file added data/__pycache__/st_data.cpython-39.pyc
Binary file not shown.
30 changes: 30 additions & 0 deletions data/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch


def calc_features(vs, ts):
"""
:param vs: (nv, 3), float
:param ts: (nf, 3), long
:return: fea: (15, nf), float
"""
nf = ts.shape[0]
fea = torch.empty((15, nf), dtype=torch.float32).to(vs.device)

vs_in_ts = vs[ts]

fea[:3, :] = vs_in_ts.mean(1).T[None] # centers , 3
fea[3:6, :] = calc_normals(vs, ts).T[None] # normal, 3
fea[6:15, :] = (vs_in_ts - vs_in_ts.mean(1, keepdim=True)).reshape((nf, -1)).T[None]
return fea


def calc_normals(vs: torch.Tensor, ts: torch.Tensor):
"""
:param vs: (n_v, 3)
:param ts: (n_f, 3), long
:return normals: (n_f, 3), float
"""
normals = torch.cross(vs[ts[:, 1]] - vs[ts[:, 0]],
vs[ts[:, 2]] - vs[ts[:, 0]])
normals /= torch.sum(normals ** 2, 1, keepdims=True) ** 0.5 + 1e-9
return normals
95 changes: 95 additions & 0 deletions data/st_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
import numpy as np
import random
import json

import torch
import trimesh
from torch.utils.data import Dataset

from data.common import calc_features
from utils.data_utils import augment, get_offsets, get_centroids, get_masks, get_bmap


class Teeth3DS(Dataset):
def __init__(self, args, split_file: str, train: bool):
# args
self.args = args
self.root_dir = args.root_dir
self.num_points = args.num_points
self.augmentation = args.augmentation if train else False
# files
self.files = []
with open(os.path.join(args.split_dir, split_file)) as f:
for line in f:
filename = line.strip().split('_')[0]
category = line.strip().split('_')[1]
root = os.path.join(self.root_dir, category, filename)
obj_file = os.path.join(root, f'{line.strip()}_sim.off')
json_file = os.path.join(root, f'{line.strip()}_sim_re.txt')
dmap_file = os.path.join(self.root_dir, category + '_fdmap', f'{line.strip()}_sim.txt')
bmap_file = os.path.join(self.root_dir, category + '_fbmap', f'{line.strip()}_sim.txt')
if os.path.exists(obj_file) and os.path.exists(json_file):
self.files.append((obj_file, json_file, dmap_file, bmap_file))
random.shuffle(self.files)

def __len__(self):
return len(self.files)

def __getitem__(self, idx):
obj_file, json_file, dmap_file, bmap_file = self.files[idx]

mesh = trimesh.load(obj_file)
vs, fs = mesh.vertices, mesh.faces
labels = np.loadtxt(json_file, dtype=np.int32)
dmap = np.loadtxt(dmap_file, dtype=np.float32) # [-1 1]
bmap = np.loadtxt(bmap_file, dtype=np.float32) # boundary 1 other 0
# 边界点为1 牙齿点为2 牙龈点为0
b_idx = np.argwhere(bmap == 1)[:, 0]
bmap = np.ones(bmap.shape, dtype='int64')
gum_idx = np.argwhere(labels == 0)[:, 0]
bmap[gum_idx] = -1
bmap[b_idx] = 0
bmap = bmap + 1

# augmentation
if self.augmentation:
vs, fs = augment(vs, fs)
# sample
_, fids = trimesh.sample.sample_surface_even(mesh, self.num_points)

fs, labels = fs[fids], labels[fids]
bmap, dmap = bmap[fids], dmap[np.newaxis, fids]
# extract input features
vs = torch.tensor(vs, dtype=torch.float32)
vs = vs - vs.mean(0) # preprocess
fs = torch.tensor(fs, dtype=torch.long)
features = calc_features(vs, fs) # (15, nf)
labels = np.array(labels, dtype='float64').squeeze()

cs = np.array(features.T[:, :3])
ins_masks, ins_labels, ins_xyz = get_masks(cs, labels) # label:[] xyz:[]

return features, torch.tensor(labels, dtype=torch.long), \
torch.tensor(bmap, dtype=torch.long), torch.tensor(dmap, dtype=torch.float32), \
torch.tensor(ins_masks, dtype=torch.float32), torch.tensor(ins_labels, dtype=torch.float32)


if __name__ == '__main__':
class Args(object):
def __init__(self):
self.root_dir = 'F:/dataset/Teeth3DS/data'
self.split_dir = 'F:/dataset/Teeth3DS/split'
self.num_points = 10000
self.augmentation = True


data = Teeth3DS(Args(), 'training_upper.txt', True)
i = 0
for f, l, b, d, ins_m, ins_l in data:
print(f.shape)
print(b.shape)
print(d.shape)
i += 1

print(i)
Empty file added models/__init__.py
Empty file.
Binary file added models/__pycache__/GCN_REG3_N.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added models/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added models/__pycache__/cbanet.cpython-310.pyc
Binary file not shown.
Binary file added models/__pycache__/dgcnn_utils.cpython-310.pyc
Binary file not shown.
Binary file added models/__pycache__/fps_utils.cpython-310.pyc
Binary file not shown.
190 changes: 190 additions & 0 deletions models/cbanet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
from .dgcnn_utils import STN, Backbone, SharedMLP1d, EdgeConv, knn
from .fps_utils import center_fps
# from pointnet2_ops.pointnet2_utils import furthest_point_sample


class CBANet(nn.Module):
def __init__(self, args):
super(CBANet, self).__init__()
self.num_points = args.num_points
self.query_num = args.query_num
self.out_channels = args.output_channels
self.k = args.k
if args.use_stn:
self.stn = STN(args.k, args.norm)
self.backbone = Backbone(args)

self.bmap_decoder = SharedMLP1d([1344, 256], args.norm)
self.bmap_head = nn.Sequential(SharedMLP1d([256, 256], args.norm),
nn.Dropout(args.dropout),
SharedMLP1d([256, 128], args.norm),
nn.Conv1d(128, 3, kernel_size=1))
self.dmap_decoder = SharedMLP1d([1344, 256], args.norm)
self.dmap_head = nn.Sequential(SharedMLP1d([256, 256], args.norm),
nn.Dropout1d(args.dropout),
SharedMLP1d([256, 128], args.norm),
nn.Conv1d(128, 1, kernel_size=1))

self.decoder = SharedMLP1d([1344, 256], args.norm)
self.mask_decoder = SharedMLP1d([256, 256], args.norm)

self.homo_conv = EdgeConv([256 * 2, 256], self.k, args.norm)
self.sp_conv = EdgeConv([256 * 2, 256], self.k, args.norm)

self.class_head = nn.Sequential(SharedMLP1d([256, 256], args.norm),
nn.Dropout(args.dropout),
SharedMLP1d([256, 128], args.norm),
nn.Conv1d(128, args.output_channels, kernel_size=1))
self.score_head = nn.Sequential(SharedMLP1d([256, 256], args.norm),
nn.Dropout(args.dropout),
SharedMLP1d([256, 128], args.norm),
nn.Conv1d(128, 1, kernel_size=1))
self.mask_head = nn.Sequential(SharedMLP1d([self.query_num, self.query_num], args.norm),
nn.Dropout(args.dropout),
nn.Conv1d(self.query_num, self.query_num, kernel_size=1),
nn.Sigmoid())

def forward(self, x, eid):
device = x.device
batch_size = x.size(0)
p = x[:, :3, :].contiguous() # xyz

if hasattr(self, "stn"):
if not hasattr(self, 'c'):
self.c = torch.zeros((x.shape[0], 15, 15), dtype=torch.float32, device=device)
for i in range(0, 15, 3):
self.c[:, i:i + 3, i:i + 3] = 1

t = self.stn(x[:, :3, :].contiguous())
t = t.repeat(1, 5, 5) # (batch_size, 15, 15)
t1 = self.c * t
x = torch.bmm(t1, x)
else:
t = torch.ones((1, 1), device=device)

feats = self.backbone(x) # (b, 1344, 10000)

bmap_feats = self.bmap_decoder(feats)
dmap_feats = self.dmap_decoder(feats)
bmap_out = self.bmap_head(bmap_feats)
dmap_out = self.dmap_head(dmap_feats)

# #==========================mask stage===============================================
sp_masks, sp_probs, g_ps = torch.zeros([batch_size, self.num_points, self.query_num]), torch.zeros([batch_size, 17, self.query_num]), torch.zeros([batch_size, 3, self.query_num])
score = torch.zeros([batch_size, 1, self.query_num])
all_idx = torch.zeros([batch_size, self.query_num]) # sampled idx
if eid > 19:

feats = self.decoder(feats) # (b, 256, 10000)
mask_feats = self.mask_decoder(feats) # (b, 256 * 3, 10000)

dmap = dmap_out.detach()
bmap = bmap_out.detach().argmax(1) # (b, 1, 10000)

g_fs = []
g_ps = []
all_idx = []
for i in range(0, batch_size):
dm = dmap[i].squeeze()
t_idx = torch.nonzero(dm > 0.2).squeeze() # M tooth points
ps = p[i].T[t_idx] # (M, 3)
fs = feats[i].T[t_idx] # (M, 256)
ps_idx = knn(ps.T.unsqueeze(0).contiguous(), self.k) # 1, 20, M
fs = self.homo_conv(fs.T.unsqueeze(0).contiguous(), ps_idx) # 1, 256, M

# select_idx = furthest_point_sample(ps.unsqueeze(0).contiguous(), self.query_num).squeeze().to(torch.long)
select_idx = center_fps(dm[t_idx], ps, self.query_num)
select_ps = ps[select_idx].T.unsqueeze(0) # 1, 3, 100
select_fs = fs[:, :, select_idx] # 1, 256, 100

g_fs.append(select_fs)
g_ps.append(select_ps)
sampl_idx = t_idx[select_idx]
all_idx.append(sampl_idx)

g_fs = torch.concat(g_fs, dim=0) # b, 256, 100
g_ps = torch.concat(g_ps, dim=0) # b, 3, 100

idx = knn(g_ps.contiguous(), self.k)
g_feats = self.sp_conv(g_fs, idx) # b, 256, 100

mask_feats = torch.einsum('bdn,bdm->bnm', g_feats, mask_feats) # b, 100, 10000

sp_probs = self.class_head(g_feats) # b, 17, 100
score = self.score_head(g_feats) # b, 1, 100
sp_masks = self.mask_head(mask_feats)

return bmap_out, dmap_out, sp_masks, sp_probs, score, all_idx

def inference(self, x):
device = x.device
batch_size = x.size(0)
p = x[:, :3, :].contiguous() # xyz

if hasattr(self, "stn"):
if not hasattr(self, 'c'):
self.c = torch.zeros((x.shape[0], 15, 15), dtype=torch.float32, device=device)
for i in range(0, 15, 3):
self.c[:, i:i + 3, i:i + 3] = 1

t = self.stn(x[:, :3, :].contiguous())
t = t.repeat(1, 5, 5) # (batch_size, 15, 15)
t1 = self.c * t
x = torch.bmm(t1, x)
else:
t = torch.ones((1, 1), device=device)

feats = self.backbone(x) # (b, 1344, 10000)

bmap_feats = self.bmap_decoder(feats)
dmap_feats = self.dmap_decoder(feats)
bmap_out = self.bmap_head(bmap_feats)
dmap_out = self.dmap_head(dmap_feats)

# #==========================mask stage===============================================

feats = self.decoder(feats) # (b, 256, 10000)
mask_feats = self.mask_decoder(feats) # (b, 256 * 3, 10000)

dmap = dmap_out.detach()
bmap = bmap_out.detach().argmax(1) # (b, 1, 10000)

g_fs = []
g_ps = []
all_idx = []
for i in range(0, batch_size):
dm = dmap[i].squeeze()
t_idx = torch.nonzero(dm > 0.2).squeeze() # M tooth points
ps = p[i].T[t_idx] # (M, 3)
fs = feats[i].T[t_idx] # (M, 256)
ps_idx = knn(ps.T.unsqueeze(0).contiguous(), self.k) # 1, 30, M
fs = self.homo_conv(fs.T.unsqueeze(0).contiguous(), ps_idx) # 1, 256, M

# select_idx = furthest_point_sample(ps.unsqueeze(0).contiguous(), self.query_num).squeeze().to(torch.long)
select_idx = center_fps(dm[t_idx], ps, self.query_num)
select_ps = ps[select_idx].T.unsqueeze(0) # 1, 3, 100
select_fs = fs[:, :, select_idx] # 1, 256, 100

g_fs.append(select_fs)
g_ps.append(select_ps)
sampl_idx = t_idx[select_idx]
all_idx.append(sampl_idx)

g_fs = torch.concat(g_fs, dim=0) # b, 256, 100
g_ps = torch.concat(g_ps, dim=0) # b, 3, 100

idx = knn(g_ps.contiguous(), self.k)
g_feats = self.sp_conv(g_fs, idx) # b, 256, 100

mask_feats = torch.einsum('bdn,bdm->bnm', g_feats, mask_feats) # b, 100, 10000

sp_probs = self.class_head(g_feats) # b, 17, 100
score = self.score_head(g_feats) # b, 1, 100
sp_masks = self.mask_head(mask_feats)

return bmap, dmap_out, sp_masks, sp_probs, score



Loading

0 comments on commit b862e1c

Please sign in to comment.