From 49c3ab4e9f8e646b7e536981badd4e28ee28c784 Mon Sep 17 00:00:00 2001
From: PranavBhatP
Date: Sun, 12 Oct 2025 11:46:59 +0530
Subject: [PATCH 1/8] add feature scaling to d2

---
 pytorch_forecasting/data/data_module.py | 84 +++++++++++++++++++++++++
 1 file changed, 84 insertions(+)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 34aa145e7..9d7739da9 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -339,6 +339,41 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
             else torch.zeros((features.shape[0], 0))
         )
 
+        if self._scalers and self.continuous_indices:
+            for i, orig_idx in enumerate(self.continuous_indices):
+                col_name = self.time_series_metadata["cols"]["x"][orig_idx]
+                if col_name in self._scalers:
+                    scaler = self._scalers[col_name]
+                    feature_data = continuous[:, i]
+
+                    if isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
+                        continuous[:, i] = scaler.transform(feature_data)
+                    elif isinstance(scaler, (StandardScaler, RobustScaler)):
+                        # if scaler is a sklearn scaler, we need to
+                        # input numpy np.array
+                        feature_data_np = feature_data.numpy().reshape(-1, 1)
+                        scaled_feature_np = scaler.transform(feature_data_np)
+                        continuous[:, i] = torch.tensor(scaled_feature_np.flatten())
+
+        if self._target_normalizer is not None:
+            if isinstance(
+                self._target_normalizer, (TorchNormalizer, EncoderNormalizer)
+            ):
+                # automatically handle multiple targets.
+                target = self._target_normalizer.transform(target)
+            elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)):
+                if target.ndim == 2:  # (seq_len, n_targets)
+                    target_scaled = []
+                    for i in range(target.shape[1]):
+                        target_np = target[:, i].numpy().reshape(-1, 1)
+                        scaled = self._target_normalizer.transform(target_np)
+                        target_scaled.append(torch.tensor(scaled.flatten()))
+                    target = torch.stack(target_scaled, dim=1)
+                else:
+                    target_np = target.numpy().reshape(-1, 1)
+                    target_scaled = self._target_normalizer.transform(target_np)
+                    target = torch.tensor(target_scaled.flatten())
+
         return {
             "features": {"categorical": categorical, "continuous": continuous},
             "target": target,
@@ -623,6 +658,52 @@ def _create_windows(self, indices: torch.Tensor) -> list[tuple[int, int, int, int]]:
 
         return windows
 
+    def _fit_scalers(self, train_indices: torch.Tensor):
+        """Fit scaler on the training dataset.
+
+        Parameters
+        ----------
+        train_indices : torch.Tensor
+            Indices of the training time series in `time_series_dataset`.
+        """
+
+        train_targets = []
+        train_features = []
+
+        for series_idx in train_indices:
+            sample = self.time_series_dataset[series_idx.item()]
+            target = sample["y"]
+            features = sample["x"]
+
+            train_targets.append(target)
+            train_features.append(features)
+
+        train_targets = torch.cat(train_targets, dim=0)
+        train_features = torch.cat(train_features, dim=0)
+
+        if self._target_normalizer is not None:
+            if isinstance(
+                self._target_normalizer, (TorchNormalizer, EncoderNormalizer)
+            ):
+                self._target_normalizer.fit(train_targets)
+            elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)):
+                target_np = train_targets.numpy()
+                if target_np.ndim == 1:
+                    target_np = target_np.reshape(-1, 1)
+                self._target_normalizer.fit(target_np)
+
+        if self._scaler and self.continuous_indices:
+            for i, orig_idx in enumerate(self.continuous_indices):
+                col_name = self.time_series_metadata["cols"]["x"]["orig_idx"]
+                if col_name in self._scalers:
+                    scaler = self._scalers[col_name]
+                    feature_data = train_features[:, orig_idx]
+
+                    if isinstance(scaler, (StandardScaler, RobustScaler)):
+                        feature_data = feature_data.numpy().reshape(-1, 1)
+
+                    scaler.fit(feature_data)
+
     def setup(self, stage: Optional[str] = None):
         """Prepare the datasets for training, validation, testing, or prediction.
 
@@ -647,6 +728,9 @@ def setup(self, stage: Optional[str] = None):
         ]
         self._test_indices = self._split_indices[self._train_size + self._val_size :]
 
+        if stage is None or stage == "fit":
+            self._fit_scalers(self._train_indices)
+
         if stage is None or stage == "fit":
            if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
                 self.train_windows = self._create_windows(self._train_indices)
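Note on the design above: scalers are fit once, on the training split only (`_fit_scalers`), and then applied to every series in `_preprocess_data`. The sklearn scalers operate on 2D numpy arrays of shape (n_samples, n_features), which is what the repeated `reshape(-1, 1)` round-trips are for. A minimal standalone sketch of that contract (the variable names are illustrative, not module API):

    import torch
    from sklearn.preprocessing import StandardScaler

    # Fit on training values only; sklearn expects (n_samples, n_features).
    train_values = torch.randn(100)
    scaler = StandardScaler()
    scaler.fit(train_values.numpy().reshape(-1, 1))

    # Transform any split with the same reshape, then back to a 1D tensor.
    val_values = torch.randn(20)
    scaled_np = scaler.transform(val_values.numpy().reshape(-1, 1))
    scaled = torch.tensor(scaled_np.flatten())  # later patches switch to from_numpy
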
From 403e0f20e34272ab2a1dbb8706c7219b748f9a56 Mon Sep 17 00:00:00 2001
From: PranavBhatP
Date: Sun, 12 Oct 2025 11:55:02 +0530
Subject: [PATCH 2/8] fix incorrect orig_idx index

---
 pytorch_forecasting/data/data_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 9d7739da9..91f8a53b7 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -694,7 +694,7 @@ def _fit_scalers(self, train_indices: torch.Tensor):
 
         if self._scaler and self.continuous_indices:
             for i, orig_idx in enumerate(self.continuous_indices):
-                col_name = self.time_series_metadata["cols"]["x"]["orig_idx"]
+                col_name = self.time_series_metadata["cols"]["x"][orig_idx]
                 if col_name in self._scalers:
                     scaler = self._scalers[col_name]
                     feature_data = train_features[:, orig_idx]

From 5661be4433de3c45b47ddd3ddee2cc9d6e2a66b2 Mon Sep 17 00:00:00 2001
From: PranavBhatP
Date: Sun, 12 Oct 2025 13:11:33 +0530
Subject: [PATCH 3/8] fix incorrect attribute

---
 pytorch_forecasting/data/data_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 91f8a53b7..ae4faedef 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -692,7 +692,7 @@ def _fit_scalers(self, train_indices: torch.Tensor):
                     target_np = target_np.reshape(-1, 1)
                 self._target_normalizer.fit(target_np)
 
-        if self._scaler and self.continuous_indices:
+        if self._scalers and self.continuous_indices:
             for i, orig_idx in enumerate(self.continuous_indices):
                 col_name = self.time_series_metadata["cols"]["x"][orig_idx]
                 if col_name in self._scalers:
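Note: if `setup()` never runs a "fit" stage (for example, a predict-only workflow, or an empty training split), the configured scalers are still unfitted when `_preprocess_data` calls `transform`, and sklearn raises `NotFittedError`. The next patch guards both transform paths with try/except, falls back to unscaled values with a warning, and skips `_fit_scalers` when there are no training indices. A sketch of the failure mode being guarded against (a bare scaler, outside the data module):

    from sklearn.exceptions import NotFittedError
    from sklearn.preprocessing import StandardScaler

    scaler = StandardScaler()
    try:
        scaler.transform([[1.0], [2.0]])  # fit() was never called
    except NotFittedError as e:
        print(f"Failed to transform: {e}")
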
From 38cefe4aaf13d4fecdf81b43289330315527694d Mon Sep 17 00:00:00 2001
From: Pranav Bhat
Date: Thu, 16 Oct 2025 23:43:31 +0530
Subject: [PATCH 4/8] handle unfitted scalers

---
 pytorch_forecasting/data/data_module.py | 74 ++++++++++++++++---------
 1 file changed, 47 insertions(+), 27 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index ae4faedef..4d7d9147c 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -345,34 +345,54 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
                 if col_name in self._scalers:
                     scaler = self._scalers[col_name]
                     feature_data = continuous[:, i]
-
-                    if isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
-                        continuous[:, i] = scaler.transform(feature_data)
-                    elif isinstance(scaler, (StandardScaler, RobustScaler)):
-                        # if scaler is a sklearn scaler, we need to
-                        # input numpy np.array
-                        feature_data_np = feature_data.numpy().reshape(-1, 1)
-                        scaled_feature_np = scaler.transform(feature_data_np)
-                        continuous[:, i] = torch.tensor(scaled_feature_np.flatten())
+                    try:
+                        if isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
+                            continuous[:, i] = scaler.transform(feature_data)
+                        elif isinstance(scaler, (StandardScaler, RobustScaler)):
+                            # if scaler is a sklearn scaler, we need to
+                            # input numpy np.array
+                            feature_data_np = feature_data.numpy().reshape(-1, 1)
+                            scaled_feature_np = scaler.transform(feature_data_np)
+                            continuous[:, i] = torch.tensor(scaled_feature_np.flatten())
+                    except Exception as e:
+                        import warnings
+
+                        warnings.warn(
+                            f"Failed to transform feature '{col_name}' with scaler: {e}. "  # noqa: E501
+                            f"Using unscaled values.",
+                            UserWarning,
+                        )
+                        continue
 
         if self._target_normalizer is not None:
-            if isinstance(
-                self._target_normalizer, (TorchNormalizer, EncoderNormalizer)
-            ):
-                # automatically handle multiple targets.
-                target = self._target_normalizer.transform(target)
-            elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)):
-                if target.ndim == 2:  # (seq_len, n_targets)
-                    target_scaled = []
-                    for i in range(target.shape[1]):
-                        target_np = target[:, i].numpy().reshape(-1, 1)
-                        scaled = self._target_normalizer.transform(target_np)
-                        target_scaled.append(torch.tensor(scaled.flatten()))
-                    target = torch.stack(target_scaled, dim=1)
-                else:
-                    target_np = target.numpy().reshape(-1, 1)
-                    target_scaled = self._target_normalizer.transform(target_np)
-                    target = torch.tensor(target_scaled.flatten())
+            try:
+                if isinstance(
+                    self._target_normalizer, (TorchNormalizer, EncoderNormalizer)
+                ):
+                    # automatically handle multiple targets.
+                    target = self._target_normalizer.transform(target)
+                elif isinstance(
+                    self._target_normalizer, (StandardScaler, RobustScaler)
+                ):
+                    if target.ndim == 2:  # (seq_len, n_targets)
+                        target_scaled = []
+                        for i in range(target.shape[1]):
+                            target_np = target[:, i].numpy().reshape(-1, 1)
+                            scaled = self._target_normalizer.transform(target_np)
+                            target_scaled.append(torch.tensor(scaled.flatten()))
+                        target = torch.stack(target_scaled, dim=1)
+                    else:
+                        target_np = target.numpy().reshape(-1, 1)
+                        target_scaled = self._target_normalizer.transform(target_np)
+                        target = torch.tensor(target_scaled.flatten())
+            except Exception as e:
+                import warnings
+
+                warnings.warn(
+                    f"Failed to transform target with scaler: {e}. "  # noqa: E501
+                    f"Using unscaled values.",
+                    UserWarning,
+                )
 
         return {
             "features": {"categorical": categorical, "continuous": continuous},
@@ -728,7 +748,7 @@ def setup(self, stage: Optional[str] = None):
         ]
         self._test_indices = self._split_indices[self._train_size + self._val_size :]
 
-        if stage is None or stage == "fit":
+        if (stage is None or stage == "fit") and len(self._train_indices) > 0:
             self._fit_scalers(self._train_indices)
 
         if stage is None or stage == "fit":

From f24229050f2c508528db2998b59d87f5537e7c24 Mon Sep 17 00:00:00 2001
From: PranavBhatP
Date: Fri, 17 Oct 2025 18:12:51 +0530
Subject: [PATCH 5/8] change accelerator to cpu in v2 notebook cell 10

---
 docs/source/tutorials/ptf_V2_example.ipynb | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb
index 81313151d..a502ff165 100644
--- a/docs/source/tutorials/ptf_V2_example.ipynb
+++ b/docs/source/tutorials/ptf_V2_example.ipynb
@@ -493,7 +493,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 10,
+    "execution_count": null,
     "metadata": {
      "colab": {
       "base_uri": "https://localhost:8080/",
@@ -681,7 +681,7 @@
     "print(\"\\nTraining model...\")\n",
     "trainer = Trainer(\n",
     "    max_epochs=5,\n",
-    "    accelerator=\"auto\",\n",
+    "    accelerator=\"cpu\",\n",
     "    devices=1,\n",
     "    enable_progress_bar=True,\n",
     "    log_every_n_steps=10,\n",
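Note: the next patch replaces `torch.tensor(...)` with `torch.from_numpy(...)` for the numpy-to-torch conversions. `torch.tensor` always copies its input, while `torch.from_numpy` wraps the existing numpy buffer, so the result shares memory with the array and saves one allocation per feature per series. A small standalone demonstration of the aliasing difference:

    import numpy as np
    import torch

    arr = np.zeros(3)
    copied = torch.tensor(arr)      # independent copy
    shared = torch.from_numpy(arr)  # view over the same buffer
    arr[0] = 1.0
    print(copied[0].item(), shared[0].item())  # 0.0 1.0
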
" # noqa: E501 + f"Using unscaled values.", + UserWarning, + ) return { "features": {"categorical": categorical, "continuous": continuous}, @@ -728,7 +748,7 @@ def setup(self, stage: Optional[str] = None): ] self._test_indices = self._split_indices[self._train_size + self._val_size :] - if stage is None or stage == "fit": + if (stage is None or stage == "fit") and len(self._train_indices) > 0: self._fit_scalers(self._train_indices) if stage is None or stage == "fit": From f24229050f2c508528db2998b59d87f5537e7c24 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Fri, 17 Oct 2025 18:12:51 +0530 Subject: [PATCH 5/8] change accelerator to cpu in v2 notebook cell 10 --- docs/source/tutorials/ptf_V2_example.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb index 81313151d..a502ff165 100644 --- a/docs/source/tutorials/ptf_V2_example.ipynb +++ b/docs/source/tutorials/ptf_V2_example.ipynb @@ -493,7 +493,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -681,7 +681,7 @@ "print(\"\\nTraining model...\")\n", "trainer = Trainer(\n", " max_epochs=5,\n", - " accelerator=\"auto\",\n", + " accelerator=\"cpu\",\n", " devices=1,\n", " enable_progress_bar=True,\n", " log_every_n_steps=10,\n", From 54da1c4ffbf9d3feb217c57a4b53be40aee5e7d8 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Sat, 18 Oct 2025 14:16:32 +0530 Subject: [PATCH 6/8] use torch.from_numpy instead of torch.tensor for numpy to torch conversisons --- pytorch_forecasting/data/data_module.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 4d7d9147c..7575d9bb7 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -353,7 +353,9 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]: # input numpy np.array feature_data_np = feature_data.numpy().reshape(-1, 1) scaled_feature_np = scaler.transform(feature_data_np) - continuous[:, i] = torch.tensor(scaled_feature_np.flatten()) + continuous[:, i] = torch.from_numpy( + scaled_feature_np.flatten() + ) # noqa: E501 except Exception as e: import warnings @@ -379,12 +381,12 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]: for i in range(target.shape[1]): target_np = target[:, i].numpy().reshape(-1, 1) scaled = self._target_normalizer.transform(target_np) - target_scaled.append(torch.tensor(scaled.flatten())) + target_scaled.append(torch.from_numpy(scaled.flatten())) target = torch.stack(target_scaled, dim=1) else: target_np = target.numpy().reshape(-1, 1) target_scaled = self._target_normalizer.transform(target_np) - target = torch.tensor(target_scaled.flatten()) + target = torch.from_numpy(target_scaled.flatten()) except Exception as e: import warnings From c145e9bd3825170fd24c2d121d82b4fd043c6611 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Sat, 18 Oct 2025 14:18:42 +0530 Subject: [PATCH 7/8] revert accelerator mode to auto from cpu for example notebook training script --- docs/source/tutorials/ptf_V2_example.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb index a502ff165..108f42e1b 100644 --- a/docs/source/tutorials/ptf_V2_example.ipynb +++ 
From ec4cf0385b043eaac95472883fe1d5fe6817b00c Mon Sep 17 00:00:00 2001
From: Pranav Bhat
Date: Tue, 21 Oct 2025 10:25:36 +0530
Subject: [PATCH 8/8] potential fix for issue in training of v2

---
 pytorch_forecasting/data/data_module.py | 32 +++++++++++++++++++------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 7575d9bb7..c51c582c7 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -351,11 +351,18 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
                         elif isinstance(scaler, (StandardScaler, RobustScaler)):
                             # if scaler is a sklearn scaler, we need to
                             # input numpy np.array
-                            feature_data_np = feature_data.numpy().reshape(-1, 1)
+                            requires_grad = feature_data.requires_grad
+                            device = feature_data.device
+                            feature_data_np = (
+                                feature_data.cpu().detach().numpy().reshape(-1, 1)
+                            )  # noqa: E501
                             scaled_feature_np = scaler.transform(feature_data_np)
-                            continuous[:, i] = torch.from_numpy(
+                            scaled_tensor = torch.from_numpy(
                                 scaled_feature_np.flatten()
-                            )  # noqa: E501
+                            ).to(device)
+                            if requires_grad:
+                                scaled_tensor = scaled_tensor.requires_grad_(True)
+                            continuous[:, i] = scaled_tensor
                     except Exception as e:
                         import warnings
 
@@ -376,17 +383,28 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
                 elif isinstance(
                     self._target_normalizer, (StandardScaler, RobustScaler)
                 ):
+                    requires_grad = target.requires_grad
+                    device = target.device
                     if target.ndim == 2:  # (seq_len, n_targets)
                         target_scaled = []
                         for i in range(target.shape[1]):
-                            target_np = target[:, i].numpy().reshape(-1, 1)
+                            target_np = (
+                                target[:, i].detach().cpu().numpy().reshape(-1, 1)
+                            )  # noqa: E501
                             scaled = self._target_normalizer.transform(target_np)
-                            target_scaled.append(torch.from_numpy(scaled.flatten()))
+                            scaled_tensor = torch.from_numpy(scaled.flatten()).to(
+                                device
+                            )  # noqa: E501
+                            if requires_grad:
+                                scaled_tensor = scaled_tensor.requires_grad_(True)
+                            target_scaled.append(scaled_tensor)
                         target = torch.stack(target_scaled, dim=1)
                     else:
-                        target_np = target.numpy().reshape(-1, 1)
+                        target_np = target.detach().cpu().numpy().reshape(-1, 1)
                         target_scaled = self._target_normalizer.transform(target_np)
-                        target = torch.from_numpy(target_scaled.flatten())
+                        target = torch.from_numpy(target_scaled.flatten()).to(device)
+                        if requires_grad:
+                            target = target.requires_grad_(True)
             except Exception as e:
                 import warnings
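Taken together, the series leaves the data module fitting scalers once per training split during `setup("fit")` and transforming features and targets on access, with a warning-and-fallback path when a scaler cannot transform. Wiring it up would look roughly like the sketch below; the constructor keywords are inferred from the `_scalers` and `_target_normalizer` attributes used in the patches and may not match the final signature:

    from sklearn.preprocessing import RobustScaler, StandardScaler

    # Hypothetical configuration: one scaler per continuous feature column,
    # keyed by the column names in time_series_metadata["cols"]["x"], plus
    # a single normalizer for the target.
    scalers = {"price": StandardScaler(), "volume": RobustScaler()}
    target_normalizer = StandardScaler()

    # datamodule = DataModule(..., scalers=scalers, target_normalizer=target_normalizer)
    # datamodule.setup("fit")  # fits the scalers on the training indices only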