-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
68 lines (56 loc) · 1.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import argparse
from cvae.build import train_vae_grasp, train_gcnn_grasp
from cvae.build import train_vae_keypoint, train_discr_keypoint
from cvae.build import inference_grasp, inference_keypoint
# Command-line interface. ``--mode`` selects which cvae.build routine runs;
# the other options are forwarded to that routine as keyword arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--mode',
    type=str,
    default=None,
    # Listing the accepted modes makes argparse reject typos up front and
    # show the valid values in --help. The default (None) is not checked
    # against ``choices``, so omitting --mode behaves as before.
    choices=['vae_grasp', 'gcnn_grasp', 'inference_grasp',
             'vae_keypoint', 'discr_keypoint', 'inference_keypoint'],
    help='which training or inference routine to run')
parser.add_argument(
    '--model_path',
    type=str,
    default=None,
    help='pretrained model')
parser.add_argument(
    '--task_name',
    type=str,
    default='task',
    help='task name, passed to the vae_keypoint and discr_keypoint modes')
parser.add_argument(
    '--gpu',
    type=str,
    default='0',
    help='gpu to use')
parser.add_argument(
    '--data_path',
    type=str,
    default='./data/data.hdf5',
    help='path to data in hdf5 format')
args = parser.parse_args()
# Restrict the process to the GPU(s) given by --gpu, with devices numbered
# in PCI bus order.
# NOTE(review): these variables only take effect if CUDA has not been
# initialised yet; cvae.build is imported above, so confirm it does not touch
# the GPU at import time.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Map each --mode value to the cvae.build routine it runs and the keyword
# arguments that routine receives. Only the keypoint *training* routines
# take task_name; every routine gets data_path and model_path.
_base_kwargs = dict(data_path=args.data_path, model_path=args.model_path)
_task_kwargs = dict(_base_kwargs, task_name=args.task_name)
_MODES = {
    'vae_grasp': (train_vae_grasp, _base_kwargs),
    'gcnn_grasp': (train_gcnn_grasp, _base_kwargs),
    'inference_grasp': (inference_grasp, _base_kwargs),
    'vae_keypoint': (train_vae_keypoint, _task_kwargs),
    'discr_keypoint': (train_discr_keypoint, _task_kwargs),
    'inference_keypoint': (inference_keypoint, _base_kwargs),
}

if args.mode not in _MODES:
    # Name the offending value and the accepted ones so a typo or a missing
    # --mode (args.mode is None) is easy to diagnose.
    raise NotImplementedError(
        'unknown --mode {!r}; expected one of: {}'.format(
            args.mode, ', '.join(_MODES)))

_routine, _kwargs = _MODES[args.mode]
_routine(**_kwargs)