forked from henry123-boy/Level-S2FM_official
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train.py
23 lines (22 loc) · 765 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import os,sys
import torch
import importlib
from utils import options
from utils.util import log
def main():
    """Parse command-line options and run the configured training pipeline.

    Steps, in order:
      1. Log the current process id and a title banner naming the script.
      2. Parse ``sys.argv[1:]`` with ``options.parse_arguments``, finalize the
         result with ``options.set`` and persist it with
         ``options.save_options_file``.
      3. Inside ``torch.cuda.device(opt.device)``, import the module
         ``pipelines.<opt.pipeline>``, build its ``Model(opt)`` and call
         ``load_dataset``, ``restore_checkpoint``, ``setup_visualizer`` and
         ``train`` on it, each with ``opt``.
    """
    log.process(os.getpid())
    script_name = sys.argv[0]
    log.title(f"[{script_name}] (PyTorch code for training Level-S2fM)")

    # Only the arguments that follow the script name are option flags.
    cli_args = sys.argv[1:]
    opt = options.set(opt_cmd=options.parse_arguments(cli_args))
    options.save_options_file(opt)

    # NOTE(review): the original left a
    # `torch.set_default_dtype(getattr(torch, opt.prec))` call disabled here;
    # re-enable it if the precision option is meant to take effect.

    # The whole pipeline (including the module import) runs under
    # torch.cuda.device(opt.device).
    with torch.cuda.device(opt.device):
        pipeline_module = importlib.import_module(f"pipelines.{opt.pipeline}")
        model = pipeline_module.Model(opt)
        model.load_dataset(opt)
        model.restore_checkpoint(opt)
        model.setup_visualizer(opt)
        model.train(opt)
if __name__ == "__main__":
    # Start training only when this file is executed as a script,
    # not when it is imported as a module.
    main()