Skip to content

Commit 076022a

Browse files
committed
black
1 parent 1506fc7 commit 076022a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytorch_forecasting/data/timeseries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,9 +1567,9 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
15671567

15681568
# switch some variables to nan if encode length is 0
15691569
if encoder_length == 0 and len(self.dropout_categoricals) > 0:
1570-
data_cat[
1571-
:, [self.flat_categoricals.index(c) for c in self.dropout_categoricals]
1572-
] = 0 # zero is encoded nan
1570+
fc = self.flat_categoricals
1571+
dc = self.dropout_categoricals
1572+
data_cat[:, [fc.index(c) for c in dc]] = 0 # zero is encoded nan
15731573

15741574
assert decoder_length > 0, "Decoder length should be greater than 0"
15751575
assert encoder_length >= 0, "Encoder length should be at least 0"

0 commit comments

Comments
 (0)