@@ -57,10 +57,10 @@ def forward(self, inputs):
57
57
test_series = dset ["solutions" ][cfgyml .ntrain + cfgyml .nval :]
58
58
59
59
test_series = test_series [:, ::spacing ]
60
- ntest , npts , ndim = test_series .shape
60
+ ntest , n_total_pts , ndim = test_series .shape
61
61
H = cfgyml .H
62
62
L = cfgyml .input_size
63
- npts -= L
63
+ npts = n_total_pts - L
64
64
Hf = H * ndim
65
65
Lf = L * ndim
66
66
@@ -70,7 +70,7 @@ def forward(self, inputs):
70
70
model .eval ()
71
71
72
72
inputs = torch .from_numpy (test_series [:, :L ].reshape (ntest , - 1 )).type (torch .float32 )
73
- outputs = torch .empty ((ntest , npts * ndim )).to (device )
73
+ outputs = torch .empty ((ntest , n_total_pts * ndim )).to (device )
74
74
75
75
outputs [:, :Lf ] = inputs
76
76
idx = Lf
@@ -89,11 +89,11 @@ def forward(self, inputs):
89
89
inc = args .inc * ndim
90
90
else :
91
91
inc = Hf
92
- while idx < npts * ndim :
92
+ while idx < n_total_pts * ndim :
93
93
print (idx , end = "\r " )
94
94
with torch .no_grad ():
95
95
hpred = model (outputs [:, idx - Lf : idx ])
96
- rem = min (inc , npts * ndim - idx )
96
+ rem = min (inc , n_total_pts * ndim - idx )
97
97
outputs [:, idx : idx + rem ] = hpred [:, :rem ]
98
98
outidx = idx // ndim - L
99
99
outwin = outputs [:, idx : idx + rem ].reshape (ntest , - 1 , ndim ).cpu ().numpy ()
0 commit comments