-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathfileio_utils.py
49 lines (40 loc) · 1.81 KB
/
fileio_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os

import numpy as np
import torch
def save_int(t: torch.Tensor, scaling_factor: int, path):
    """Quantize ``t`` to int32 and dump it to ``path`` as a raw binary file.

    Every element is multiplied by ``scaling_factor``, rounded with
    ``torch.round`` (ties go to the even integer), cast to int32 and written
    with ``numpy.ndarray.tofile``: no header, no shape, C order.

    Args:
        t: tensor to quantize; may live on any device and may require grad.
        scaling_factor: multiplier applied before rounding.
        path: destination (``str``, ``bytes`` or ``os.PathLike``); must end
            with ``.bin``.

    Raises:
        ValueError: if ``path`` does not end with ``.bin``.

    Note:
        Scaled values outside the int32 range are not checked.
    """
    if not os.fsdecode(path).endswith('.bin'):
        raise ValueError('Path must end with .bin')
    # Scale in float64: a float32 product keeps only a 24-bit mantissa and
    # would drop digits that the int32 result can represent.
    scaled = t.detach().cpu().to(torch.float64) * scaling_factor
    torch.round(scaled).to(torch.int32).numpy().tofile(path)
def save_long(t: torch.Tensor, scaling_factor: int, path):
    """Quantize ``t`` to int64 and dump it to ``path`` as a raw binary file.

    Every element is multiplied by ``scaling_factor`` in float64, rounded
    with ``torch.round`` (ties go to the even integer), cast to int64 and
    written with ``numpy.ndarray.tofile``: no header, no shape, C order.

    Args:
        t: tensor to quantize; may live on any device and may require grad.
        scaling_factor: multiplier applied before rounding.
        path: destination (``str``, ``bytes`` or ``os.PathLike``); must end
            with ``.bin``.

    Raises:
        ValueError: if ``path`` does not end with ``.bin``.
    """
    if not os.fsdecode(path).endswith('.bin'):
        raise ValueError('Path must end with .bin')
    # Scale in float64 (as to_int64 does): a float32 product keeps only a
    # 24-bit mantissa, far less than the int64 output can hold.
    scaled = t.detach().cpu().to(torch.float64) * scaling_factor
    torch.round(scaled).to(torch.int64).numpy().tofile(path)
def load_int(path, device = 0):
    """Read a raw int32 ``.bin`` file into a tensor on ``device``.

    Args:
        path: source file (``str``, ``bytes`` or ``os.PathLike``); must end
            with ``.bin``.
        device: passed to ``Tensor.to``. The default ``0`` is CUDA device 0,
            so pass ``'cpu'`` on machines without a GPU.

    Returns:
        A 1-D int32 tensor with every value in the file. The file stores no
        shape, so callers must reshape themselves.

    Raises:
        ValueError: if ``path`` does not end with ``.bin``.
    """
    if not os.fsdecode(path).endswith('.bin'):
        raise ValueError('Path must end with .bin')
    values = np.fromfile(path, dtype=np.int32)
    return torch.from_numpy(values).to(device)
def load_long(path, device = 0):
    """Read a raw int64 ``.bin`` file into a tensor on ``device``.

    Args:
        path: source file (``str``, ``bytes`` or ``os.PathLike``); must end
            with ``.bin``.
        device: passed to ``Tensor.to``. The default ``0`` is CUDA device 0,
            so pass ``'cpu'`` on machines without a GPU.

    Returns:
        A 1-D int64 tensor with every value in the file. The file stores no
        shape, so callers must reshape themselves.

    Raises:
        ValueError: if ``path`` does not end with ``.bin``.
    """
    if not os.fsdecode(path).endswith('.bin'):
        raise ValueError('Path must end with .bin')
    values = np.fromfile(path, dtype=np.int64)
    return torch.from_numpy(values).to(device)
def to_int64(tensor: torch.Tensor, log_sf: int):
    """Encode ``tensor`` as int64 fixed point with ``log_sf`` fractional bits.

    The input is promoted to float64, multiplied by ``2**log_sf`` and rounded
    with ``torch.round`` (ties go to the nearest even integer), then cast to
    int64.
    """
    scale = 1 << log_sf
    return tensor.double().mul(scale).round().long()
def to_float(tensor: torch.Tensor, log_sf: int, to_type: torch.dtype = torch.float32):
    """Decode a fixed-point tensor with ``log_sf`` fractional bits.

    Args:
        tensor: fixed-point values (typically int64, as produced by ``to_int64``).
        log_sf: number of fractional bits; the result is ``tensor / 2**log_sf``.
        to_type: floating dtype of the returned tensor.

    Returns:
        ``tensor / 2**log_sf`` cast to ``to_type``.
    """
    # Divide in float64. True division of an integer tensor otherwise yields
    # torch's default dtype (float32), which rounds int64 values to a 24-bit
    # mantissa before the cast, even when to_type is float64.
    return (tensor.to(torch.float64) / (1 << log_sf)).to(to_type)
def rescale(tensor: torch.Tensor, log_sf: int):
    """Divide an int64 fixed-point tensor by ``2**log_sf``, rounding half away from zero.

    Used to drop ``log_sf`` fractional bits, e.g. after multiplying two
    fixed-point values. The input tensor is not modified.

    Args:
        tensor: int64 tensor.
        log_sf: number of bits to shift out; ``0`` returns a copy of ``tensor``.

    Returns:
        A new int64 tensor ``sign(x) * ((|x| + 2**(log_sf-1)) >> log_sf)``.

    Raises:
        AssertionError: if ``tensor`` is not int64.
        ValueError: if ``log_sf`` is negative.

    Note:
        Magnitudes within ``2**(log_sf-1)`` of the int64 limits are not
        guarded against overflow.
    """
    assert tensor.dtype == torch.int64
    if log_sf < 0:
        raise ValueError('log_sf must be non-negative')
    if log_sf == 0:
        # No bits to drop; 1 << -1 would not be a valid half-step.
        return tensor.clone()
    half = 1 << (log_sf - 1)
    # Round the magnitude, then restore the sign, so ties move away from zero.
    magnitude = (tensor.abs() + half) >> log_sf
    return tensor.sign() * magnitude
# kill ours
def fromto_int64(tensor: torch.Tensor, log_sf: int, float_dtype: torch.dtype = torch.float64):
    """Round-trip ``tensor`` through int64 fixed point with ``log_sf`` fractional bits.

    The tensor is encoded with ``to_int64`` (scale by ``2**log_sf`` and round)
    and decoded with ``to_float``. The result is the value ``tensor`` takes
    after quantization.

    Args:
        tensor: values to quantize.
        log_sf: number of fractional bits.
        float_dtype: dtype of the returned tensor (default float64).

    Returns:
        The quantized-then-dequantized tensor in ``float_dtype``.
    """
    return to_float(to_int64(tensor, log_sf), log_sf, float_dtype)
def compare_q(t: torch.Tensor, t_q: torch.Tensor, log_sf: int):
    """Measure how far the fixed-point tensor ``t_q`` is from the reference ``t``.

    ``t_q`` is decoded with ``to_float`` (``log_sf`` fractional bits, float64)
    before the comparison.

    Returns:
        ``(max_abs_err, mean_abs_err)`` as Python numbers.
    """
    decoded = to_float(t_q, log_sf, torch.float64)
    abs_err = (t - decoded).abs()
    return abs_err.max().item(), abs_err.mean().item()