Skip to content

Commit cbd575c

Browse files
author
Michael Horgan
committed
bug fix for predict_ar
1 parent 4c3f499 commit cbd575c

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

predict_ar.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def forward(self, inputs):
5757
test_series = dset["solutions"][cfgyml.ntrain + cfgyml.nval :]
5858

5959
test_series = test_series[:, ::spacing]
60-
ntest, npts, ndim = test_series.shape
60+
ntest, n_total_pts, ndim = test_series.shape
6161
H = cfgyml.H
6262
L = cfgyml.input_size
63-
npts -= L
63+
npts = n_total_pts - L
6464
Hf = H * ndim
6565
Lf = L * ndim
6666

@@ -70,7 +70,7 @@ def forward(self, inputs):
7070
model.eval()
7171

7272
inputs = torch.from_numpy(test_series[:, :L].reshape(ntest, -1)).type(torch.float32)
73-
outputs = torch.empty((ntest, npts * ndim)).to(device)
73+
outputs = torch.empty((ntest, n_total_pts * ndim)).to(device)
7474

7575
outputs[:, :Lf] = inputs
7676
idx = Lf
@@ -89,11 +89,11 @@ def forward(self, inputs):
8989
inc = args.inc * ndim
9090
else:
9191
inc = Hf
92-
while idx < npts * ndim:
92+
while idx < n_total_pts * ndim:
9393
print(idx, end="\r")
9494
with torch.no_grad():
9595
hpred = model(outputs[:, idx - Lf : idx])
96-
rem = min(inc, npts * ndim - idx)
96+
rem = min(inc, n_total_pts * ndim - idx)
9797
outputs[:, idx : idx + rem] = hpred[:, :rem]
9898
outidx = idx // ndim - L
9999
outwin = outputs[:, idx : idx + rem].reshape(ntest, -1, ndim).cpu().numpy()

0 commit comments

Comments
 (0)