-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
152 lines (123 loc) · 4.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import torch.nn as nn
import numpy as np
import argparse
from tqdm import tqdm
import time
import os
from functools import partial
from torch.optim.lr_scheduler import StepLR, OneCycleLR
import psutil, platform, subprocess, re, shutil
from time import time
from contextlib import contextmanager
import math
class Trainer:
    """Base wrapper around the training / evaluation of a neural operator.

    The constructor only stores ``config`` and ``args`` on the instance;
    concrete trainers are expected to override ``training_loop`` and
    ``test_loop``, which raise ``NotImplementedError`` here.
    """

    def __init__(self, config, args):
        self.config, self.args = config, args

    def training_loop(self):
        """Run training; must be provided by a subclass."""
        raise NotImplementedError()

    def test_loop(self):
        """Run evaluation; must be provided by a subclass."""
        raise NotImplementedError()
def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into ``argparse.Namespace``.

    Values that are themselves dicts become nested namespaces; every other
    value is stored as-is. Keys become attribute names in insertion order.
    """
    converted = {
        key: dict2namespace(val) if isinstance(val, dict) else val
        for key, val in config.items()
    }
    return argparse.Namespace(**converted)
class Colors:
    """ANSI escape sequences used to color text printed to a terminal.

    Every attribute is a plain ``str``: prefix text with one of the colors
    and append ``end`` to restore the terminal's default styling.
    """

    # SGR codes 91-95: bright foreground colors.
    red = "\033[91m"
    green = "\033[92m"
    yellow = "\033[93m"
    blue = "\033[94m"
    magenta = "\033[95m"

    # SGR code 0: reset all attributes.
    end = "\033[0m"
def color(string: str, color: str = Colors.yellow) -> str:
    """Wrap ``string`` in an ANSI color escape followed by a reset code.

    Args:
        string: Text to colorize (interpolated with an f-string).
        color: Escape sequence to prefix, normally one of the ``Colors``
            attributes. Those are plain ``str`` values, so the parameter is
            typed ``str`` rather than ``Colors``. Defaults to ``Colors.yellow``.
            (The name shadows this function inside its body; kept for
            backward compatibility with keyword callers.)

    Returns:
        ``color + string + Colors.end``.
    """
    return f"{color}{string}{Colors.end}"
def index_points(points, idx):
    """Gather rows of ``points`` per batch element using ``idx``.

    :param points: input points data, [B, N, C]
    :param idx: sample index data, [B, S] (extra trailing index dims are
        allowed; the output then has shape idx.shape + [C])
    :return: new_points: indexed points data, [B, S, C]
    """
    batch_size = points.shape[0]
    # Shape (idx.shape[0], 1, 1, ...) with the same rank as idx; the view
    # fails if idx's batch dimension does not match points'.
    broadcast_shape = [idx.shape[0]] + [1] * (len(idx.shape) - 1)
    batch_indices = (
        torch.arange(batch_size, dtype=torch.long, device=points.device)
        .view(broadcast_shape)
        .expand_as(idx)
    )
    # Pair every entry of idx with its own batch index for advanced indexing.
    return points[batch_indices, idx, :]
def save_checkpoint(state, save_path: str, is_best: bool = False, max_keep: int = None):
    """Saves torch model to checkpoint file.

    Besides writing ``state`` with ``torch.save``, this maintains
    ``latest_checkpoint.txt`` next to ``save_path``: one checkpoint basename
    per line, newest first, with no duplicates.

    Args:
        state (torch model state): State of a torch Neural Network
        save_path (str): Destination path for saving checkpoint
        is_best (bool): If ``True`` creates additional copy
            ``best_model.ckpt`` in the same directory
        max_keep (int): Specifies the max amount of checkpoints to keep;
            files for older entries beyond this count are deleted.
            ``None`` keeps everything. Note that ``0`` also deletes the
            checkpoint that was just written.
    """
    # save checkpoint
    torch.save(state, save_path)

    save_dir = os.path.dirname(save_path)
    list_path = os.path.join(save_dir, 'latest_checkpoint.txt')
    # Keep the full save_path intact: it is still needed for the is_best copy.
    ckpt_name = os.path.basename(save_path)

    # Read existing history (newest first). rstrip instead of [:-1] so a
    # final line without a trailing newline is not truncated.
    old_names = []
    if os.path.exists(list_path):
        with open(list_path) as f:
            old_names = [line.rstrip('\r\n') for line in f]

    # Drop blank lines and any stale entry for this same file. Otherwise
    # re-saving to an existing path would leave a duplicate entry that the
    # pruning below would delete, removing the checkpoint just written.
    ckpt_names = [ckpt_name] + [n for n in old_names if n and n != ckpt_name]

    # deal with max_keep
    if max_keep is not None:
        for stale in ckpt_names[max_keep:]:
            stale_path = os.path.join(save_dir, stale)
            if os.path.exists(stale_path):
                os.remove(stale_path)
        del ckpt_names[max_keep:]

    with open(list_path, 'w') as f:
        f.writelines(name + '\n' for name in ckpt_names)

    # copy best (from the real path, not the bare basename)
    if is_best:
        shutil.copyfile(save_path, os.path.join(save_dir, 'best_model.ckpt'))
def load_checkpoint(ckpt_dir_or_file: str, map_location=None, load_best=False):
    """Loads torch model from checkpoint file.

    Args:
        ckpt_dir_or_file (str): Path to a checkpoint file, or to a directory
            containing ``latest_checkpoint.txt`` / ``best_model.ckpt``.
        map_location: Forwarded unchanged to ``torch.load``; can be used to
            load directly to a specific device.
        load_best (bool): Only used when a directory is given. If True,
            loads ``best_model.ckpt`` from it (no fallback: ``torch.load``
            raises if that file is missing). Otherwise loads the checkpoint
            named on the first line of ``latest_checkpoint.txt``.

    Returns:
        Whatever ``torch.load`` returns for the resolved path.

    Raises:
        FileNotFoundError: If ``latest_checkpoint.txt`` names no checkpoint.
    """
    if os.path.isdir(ckpt_dir_or_file):
        if load_best:
            ckpt_path = os.path.join(ckpt_dir_or_file, 'best_model.ckpt')
        else:
            list_path = os.path.join(ckpt_dir_or_file, 'latest_checkpoint.txt')
            with open(list_path) as f:
                # rstrip instead of [:-1]: a first line without a trailing
                # newline would otherwise lose its last character.
                latest = f.readline().rstrip('\r\n')
            if not latest:
                # An empty list would resolve to the directory itself and
                # fail later with a confusing error from torch.load.
                raise FileNotFoundError(f'No checkpoint listed in {list_path}')
            ckpt_path = os.path.join(ckpt_dir_or_file, latest)
    else:
        ckpt_path = ckpt_dir_or_file
    # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=map_location)
    print(' [*] Loading checkpoint from %s succeed!' % ckpt_path)
    return ckpt
def ensure_dir(dir_name: str):
    """Creates folder (and any missing parents) if it does not exist.

    An existing path is left untouched, even if it is a regular file.
    """
    if not os.path.exists(dir_name):
        # exist_ok=True: another process may create the directory between the
        # exists() check and makedirs(); that race should not raise.
        os.makedirs(dir_name, exist_ok=True)
import timeit
class Timer:
    """Context manager that measures elapsed time of its body.

    Uses ``timeit.default_timer``. On exit (including exit via an exception,
    which is not suppressed) the instance holds ``t_start``, ``t_end`` and
    ``dt = t_end - t_start``.
    """

    def __enter__(self):
        self.t_start = timeit.default_timer()
        return self

    def __exit__(self, _1, _2, _3):
        stop = timeit.default_timer()
        self.t_end = stop
        self.dt = stop - self.t_start