predict_ar.py
import argparse
import os

import numpy as np
import torch

from config import get_config
from nhits import NHITS


class Wrapper(torch.nn.Module):
    """Wrap NHITS so a checkpoint state_dict whose keys are prefixed with
    "model." can be loaded directly, then move the model to the target device."""

    def __init__(self, H, L, nhits_params, state_dict, device):
        super().__init__()
        self.model = NHITS(H, L, **nhits_params)
        self.load_state_dict(state_dict)
        self.to(device)

    def forward(self, inputs):
        return self.model(inputs)

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Generate predictions auto-regressively",
)
parser.add_argument(
    "--cfg", required=True, help="yaml config file that was used to train the model"
)
parser.add_argument("--cp", required=True, help="checkpoint file")
parser.add_argument(
    "--outfn",
    required=True,
    help="output directory for the prediction files (ytrue.npy, yhat.npy, md.npy)",
)
parser.add_argument(
    "--inc",
    default=None,
    type=int,
    help="number of predicted samples (at most H) to keep per model call; 'None' means use the entire horizon",
)
parser.add_argument("--gpu", action="store_true", help="use gpu")
parser.add_argument(
    "--npy",
    default=None,
    help="optional npy file containing trajectories to predict (overrides datafile from cfg)",
)
args = parser.parse_args()
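# Example invocation (file paths are hypothetical, shown only for illustration):
#   python predict_ar.py --cfg cfgs/train.yaml --cp checkpoints/last.ckpt \
#       --outfn preds_ar --inc 1 --gpu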
cfgyml = get_config(args.cfg)
spacing = getattr(cfgyml, "spacing", 1)
cp = torch.load(args.cp, map_location="cpu")

# Load test trajectories: either an explicit npy file, or the tail of the
# training datafile (everything after the train/val split). Both are then
# subsampled in time by `spacing`.
if args.npy is not None:
    dset = np.load(args.npy, allow_pickle=True).item()
    test_series = dset["solutions"]
else:
    dset = np.load(cfgyml.datafile, allow_pickle=True).item()
    test_series = dset["solutions"][cfgyml.ntrain + cfgyml.nval :]
test_series = test_series[:, ::spacing]

ntest, n_total_pts, ndim = test_series.shape
H = cfgyml.H  # forecast horizon (time steps)
L = cfgyml.input_size  # lookback window (time steps)
npts = n_total_pts - L  # predicted time steps per trajectory
Hf = H * ndim  # horizon in flattened (time * dim) samples
Lf = L * ndim  # lookback in flattened samples

device = "cuda" if args.gpu else "cpu"
model = Wrapper(Hf, Lf, cfgyml.nhits_params, cp["state_dict"], device)
model.eval()
# Seed the output buffer with the observed lookback window; predictions are
# written after it in flattened (time * dim) order.
inputs = torch.from_numpy(test_series[:, :L].reshape(ntest, -1)).type(torch.float32)
outputs = torch.empty((ntest, n_total_pts * ndim)).to(device)
outputs[:, :Lf] = inputs
idx = Lf

# Raw float32 memmaps for ground truth and predictions; their shape is
# recorded in md.npy below.
os.makedirs(args.outfn, exist_ok=True)
yt_map = np.memmap(
    f"{args.outfn}/ytrue.npy", mode="w+", dtype="float32", shape=(ntest, npts, ndim)
)
yt_map[:] = test_series[:, L:]
yh_map = np.memmap(
    f"{args.outfn}/yhat.npy", mode="w+", dtype="float32", shape=(ntest, npts, ndim)
)

# Step size of the auto-regressive loop: keep `inc` flattened samples per model
# call, or the full horizon if --inc is not given.
if args.inc is not None:
    inc = args.inc * ndim
else:
    inc = Hf
# Auto-regressive rollout: predict H steps from the latest L-step window, keep
# the first `inc` flattened samples, append them to the buffer, and repeat.
while idx < n_total_pts * ndim:
    print(idx, end="\r")
    with torch.no_grad():
        hpred = model(outputs[:, idx - Lf : idx])
    rem = min(inc, n_total_pts * ndim - idx)
    outputs[:, idx : idx + rem] = hpred[:, :rem]
    outidx = idx // ndim - L
    outwin = outputs[:, idx : idx + rem].reshape(ntest, -1, ndim).cpu().numpy()
    yh_map[:, outidx : outidx + outwin.shape[1]] = outwin
    idx += inc
print()

# Save metadata needed to interpret the raw memmaps (mode, timestep, config, shape).
np.save(
    f"{args.outfn}/md.npy",
    {"mode": "ar", "dt": dset["dt"], "config": cfgyml, "shape": (ntest, npts, ndim)},
    allow_pickle=True,
)
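# --- Usage sketch (illustrative, not part of the script) ----------------------
# ytrue.npy and yhat.npy are raw float32 memmaps (no npy header despite the
# extension), so they are read back with np.memmap using the shape stored in
# md.npy; "preds_ar" below is a hypothetical --outfn value.
#
#   md = np.load("preds_ar/md.npy", allow_pickle=True).item()
#   shape = md["shape"]
#   yhat = np.memmap("preds_ar/yhat.npy", mode="r", dtype="float32", shape=shape)
#   ytrue = np.memmap("preds_ar/ytrue.npy", mode="r", dtype="float32", shape=shape)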