forked from wmtlab/GrainGrasp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGrainGrasp.py
91 lines (80 loc) · 3.42 KB
/
GrainGrasp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import time
import torch
import numpy as np
import open3d as o3d
from utils import annotate
from utils import vis
from utils import tools
from utils import Load_obman
from attrdict import AttrDict
from DCoG import DCoGModel
from PointCVAE import load_model as load_cvae_model
from PointCVAE import inference as cvae_inference
from config import cfgs
class GrainGrasp:
    """Grasp pipeline wrapping a DCoG optimiser and an optional PointCVAE.

    The DCoG model is always constructed. The CVAE is loaded only when a
    checkpoint path is supplied; it is required by `inference_complete`
    (which predicts `obj_cls` from the object point cloud) but not by
    `inference_only_opt` (which takes `obj_cls`, or a hand point cloud to
    derive it from, as input).
    """

    # Default finger selection forwarded to DCoGModel.run. Kept as an
    # immutable tuple and copied into a fresh list on every call, so a callee
    # that mutates the list cannot leak state into later calls.
    _DEFAULT_FINGER_IDX = (1, 2, 3, 4, 5)

    def __init__(self, dcog_config, cvae_path=None, device="cuda"):
        """Build the DCoG model and, optionally, load the CVAE.

        Args:
            dcog_config: config object exposing the attributes read below
                (mano_path, init_handpose_path, init_quat_path,
                finger_index_path, tip_index_path, supnet_path,
                init_move_finger_idx, weights).
            cvae_path: checkpoint path for the PointCVAE. If None, the CVAE
                is not loaded and `inference_complete` raises RuntimeError.
            device: device passed to DCoGModel and used for the CVAE and for
                input tensors.
        """
        self.device = device
        self.dcog_model = DCoGModel(
            dcog_config.mano_path,
            dcog_config.init_handpose_path,
            dcog_config.init_quat_path,
            dcog_config.finger_index_path,
            dcog_config.tip_index_path,
            dcog_config.supnet_path,
            dcog_config.init_move_finger_idx,
            dcog_config.weights,
            device=self.device,
        )
        print("-----------------DCoG Model loaded successfully-----------------")
        if cvae_path is not None:
            # Loaded with requires_grad=False and switched to eval mode:
            # the CVAE is only used for inference here.
            self.cvae_model = load_cvae_model(cvae_path, requires_grad=False)
            self.cvae_model = self.cvae_model.eval().to(self.device)
            print("-----------------CVAE Model loaded successfully-----------------")
        else:
            self.cvae_model = None

    @classmethod
    def _resolve_finger_idx(cls, select_finger_idx):
        """Return `select_finger_idx`, or a fresh list of the defaults if None."""
        if select_finger_idx is None:
            return list(cls._DEFAULT_FINGER_IDX)
        return select_finger_idx

    def _optimize_and_select(self, obj_pc, obj_cls, epochs, select_finger_idx, threshold):
        """Run DCoG for `obj_cls` and attach the candidate picked by get_idx_minEpen.

        Shared tail of both public inference methods. The returned object is
        whatever DCoGModel.run returns, with these attributes added:
        obj_cls (moved to CPU and detached), E_pen, min_idx,
        min_idx_hand_pc and min_idx_record_hand_pc.
        """
        result = self.dcog_model.run(obj_pc, obj_cls, epochs, select_finger_idx)
        # NOTE(review): by its name, get_idx_minEpen presumably returns the
        # penetration energies and the index of the minimum — confirm in DCoG.
        E_pen, min_idx = self.dcog_model.get_idx_minEpen(obj_pc, result.hand_pc, threshold)
        result.obj_cls = obj_cls.cpu().detach()
        result.E_pen = E_pen
        result.min_idx = min_idx
        result.min_idx_hand_pc = result.hand_pc[min_idx]
        # record_hand_pc is indexed on its second axis by the selected candidate.
        result.min_idx_record_hand_pc = result.record_hand_pc[:, min_idx]
        return result

    def inference_complete(self, obj_pc, epochs=300, select_finger_idx=None, threshold=0.1):
        """Predict `obj_cls` with the CVAE, then optimise grasps with DCoG.

        Args:
            obj_pc: Tensor, object point cloud, (N, 3); moved to self.device.
            epochs: optimisation epochs, forwarded to DCoGModel.run.
            select_finger_idx: finger indices forwarded to DCoGModel.run;
                None means [1, 2, 3, 4, 5].
            threshold: forwarded to DCoGModel.get_idx_minEpen.

        Returns:
            The result of DCoGModel.run (documented as an AttrDict), with
            obj_cls, E_pen, min_idx, min_idx_hand_pc and
            min_idx_record_hand_pc attached.

        Raises:
            RuntimeError: if the instance was built with cvae_path=None.
        """
        if self.cvae_model is None:
            # Must be raised, not merely constructed: otherwise execution
            # continues into cvae_inference with a None model.
            raise RuntimeError("You should load the CVAE model if you want to run this function.")
        obj_pc = obj_pc.to(self.device)
        obj_cls = cvae_inference(self.cvae_model, obj_pc)
        return self._optimize_and_select(
            obj_pc, obj_cls, epochs, self._resolve_finger_idx(select_finger_idx), threshold
        )

    def inference_only_opt(
        self,
        obj_pc,
        obj_cls=None,
        hand_pc=None,
        K=50,
        epochs=300,
        select_finger_idx=None,
        threshold=0.1,
    ):
        """Optimise grasps with DCoG from a given or annotated `obj_cls`, skipping the CVAE.

        Args:
            obj_pc: Tensor, object point cloud, (N, 3); moved to self.device.
            obj_cls: Tensor of object class labels consumed by DCoGModel.run.
                If None, it is derived from `hand_pc` via
                annotate.get_obj_cls_and_colors and squeezed.
                NOTE(review): earlier docs gave its shape as (N, 3); not
                verified here.
            hand_pc: Tensor, hand point cloud, (M, 3); only used when
                obj_cls is None.
            K: number of nearest neighbours for the annotation step; only
                used when obj_cls is None.
            epochs: optimisation epochs, forwarded to DCoGModel.run.
            select_finger_idx: finger indices forwarded to DCoGModel.run;
                None means [1, 2, 3, 4, 5].
            threshold: forwarded to DCoGModel.get_idx_minEpen.

        Returns:
            The result of DCoGModel.run (documented as an AttrDict), with
            obj_cls, E_pen, min_idx, min_idx_hand_pc and
            min_idx_record_hand_pc attached.

        Raises:
            RuntimeError: if both obj_cls and hand_pc are None.
        """
        if obj_cls is None and hand_pc is None:
            # Must be raised, not merely constructed: otherwise annotation
            # would be attempted with hand_pc=None.
            raise RuntimeError("If you don't have 'obj_cls', you should at least provide 'hand_pc'.")
        obj_pc = obj_pc.to(self.device)
        if obj_cls is None:
            print(f"'obj_cls' will be generated by the annotation method with K = {K}")
            obj_cls, _ = annotate.get_obj_cls_and_colors(hand_pc, obj_pc, K=K, device=self.device)
            obj_cls = obj_cls.squeeze()
        obj_cls = obj_cls.to(self.device)
        return self._optimize_and_select(
            obj_pc, obj_cls, epochs, self._resolve_finger_idx(select_finger_idx), threshold
        )
if __name__ == "__main__":
    # Intentionally empty: running this file directly does nothing.
    pass