2 changes: 1 addition & 1 deletion docs/source/tutorials/ptf_V2_example.ipynb
@@ -493,7 +493,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
124 changes: 124 additions & 0 deletions pytorch_forecasting/data/data_module.py
@@ -339,6 +339,81 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
            else torch.zeros((features.shape[0], 0))
        )

        if self._scalers and self.continuous_indices:
            for i, orig_idx in enumerate(self.continuous_indices):
                col_name = self.time_series_metadata["cols"]["x"][orig_idx]
                if col_name in self._scalers:
                    scaler = self._scalers[col_name]
                    feature_data = continuous[:, i]
                    try:
                        if isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
                            continuous[:, i] = scaler.transform(feature_data)
                        elif isinstance(scaler, (StandardScaler, RobustScaler)):
                            # sklearn scalers expect a 2D numpy array as input
                            requires_grad = feature_data.requires_grad
                            device = feature_data.device
                            feature_data_np = (
                                feature_data.cpu().detach().numpy().reshape(-1, 1)
                            )  # noqa: E501

Member: I have a doubt: wouldn't calling .detach() again detach the tensor from the computation graph? Wouldn't that lead to the same issue?

Contributor Author: As far as my knowledge of PyTorch goes, it is good practice to call .detach() before converting a tensor to a NumPy array. In any case, the NumPy array will not track gradients, so this will not matter.
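A quick way to see why the answer above holds: calling .numpy() on a tensor that requires grad raises a RuntimeError, so .detach() is required, and gradients cannot flow through a NumPy round trip in any case; the requires_grad_(True) call later in this diff creates a fresh leaf tensor rather than reconnecting to the original graph. A small standalone sketch, not part of the PR:

```python
import torch

x = torch.ones(3, requires_grad=True)

# Without .detach(), the conversion fails outright:
# RuntimeError: Can't call numpy() on Tensor that requires grad.
try:
    x.numpy()
except RuntimeError as err:
    print(err)

# With .detach(), the conversion works, but the NumPy array (and any tensor
# rebuilt from it) is no longer connected to the autograd graph of `x`.
x_np = x.detach().cpu().numpy()
y = torch.from_numpy(x_np).requires_grad_(True)
print(y.is_leaf)  # True: a fresh leaf, not a node downstream of `x`
```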
                            scaled_feature_np = scaler.transform(feature_data_np)
                            scaled_tensor = torch.from_numpy(
                                scaled_feature_np.flatten()
                            ).to(device)
                            if requires_grad:
                                scaled_tensor = scaled_tensor.requires_grad_(True)
                            continuous[:, i] = scaled_tensor
                    except Exception as e:
                        import warnings

                        warnings.warn(
                            f"Failed to transform feature '{col_name}' with scaler: {e}. "  # noqa: E501
                            f"Using unscaled values.",
                            UserWarning,
                        )
                        continue

        if self._target_normalizer is not None:
            try:
                if isinstance(
                    self._target_normalizer, (TorchNormalizer, EncoderNormalizer)
                ):
                    # automatically handle multiple targets.
                    target = self._target_normalizer.transform(target)
                elif isinstance(
                    self._target_normalizer, (StandardScaler, RobustScaler)
                ):
                    requires_grad = target.requires_grad
                    device = target.device
                    if target.ndim == 2:  # (seq_len, n_targets)
                        target_scaled = []
                        for i in range(target.shape[1]):
                            target_np = (
                                target[:, i].detach().cpu().numpy().reshape(-1, 1)
                            )  # noqa: E501
                            scaled = self._target_normalizer.transform(target_np)
                            scaled_tensor = torch.from_numpy(scaled.flatten()).to(
                                device
                            )  # noqa: E501
                            if requires_grad:
                                scaled_tensor = scaled_tensor.requires_grad_(True)
                            target_scaled.append(scaled_tensor)
                        target = torch.stack(target_scaled, dim=1)
                    else:
                        target_np = target.detach().cpu().numpy().reshape(-1, 1)
                        target_scaled = self._target_normalizer.transform(target_np)
                        target = torch.from_numpy(target_scaled.flatten()).to(device)
                        if requires_grad:
                            target = target.requires_grad_(True)
            except Exception as e:
                import warnings

                warnings.warn(
                    f"Failed to transform target with scaler: {e}. "  # noqa: E501
                    f"Using unscaled values.",
                    UserWarning,
                )
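One reason for the reshape(-1, 1) / flatten() round trips in both branches above: sklearn scalers operate on 2D arrays of shape (n_samples, n_features), while each feature or target column here is a 1D tensor. A minimal standalone sketch of that round trip, with made-up values:

```python
import torch
from sklearn.preprocessing import StandardScaler

column = torch.tensor([1.0, 2.0, 3.0, 4.0])  # one continuous column or target

scaler = StandardScaler()
# sklearn expects (n_samples, n_features), hence the single-column reshape ...
scaled_np = scaler.fit_transform(column.numpy().reshape(-1, 1))
# ... and flatten() when converting back to a 1D tensor column.
scaled = torch.from_numpy(scaled_np.flatten())
print(scaled)  # zero mean, unit (population) variance
```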

        return {
            "features": {"categorical": categorical, "continuous": continuous},
            "target": target,
@@ -623,6 +698,52 @@ def _create_windows(self, indices: torch.Tensor) -> list[tuple[int, int, int, in

        return windows

    def _fit_scalers(self, train_indices: torch.Tensor):
        """Fit feature scalers and the target normalizer on the training series.

        Parameters
        ----------
        train_indices : torch.Tensor
            Indices of the training time series in `time_series_dataset`.
        """

        train_targets = []
        train_features = []

        for series_idx in train_indices:
            sample = self.time_series_dataset[series_idx.item()]
            target = sample["y"]
            features = sample["x"]

            train_targets.append(target)
            train_features.append(features)

        train_targets = torch.cat(train_targets, dim=0)
        train_features = torch.cat(train_features, dim=0)

        if self._target_normalizer is not None:
            if isinstance(
                self._target_normalizer, (TorchNormalizer, EncoderNormalizer)
            ):
                self._target_normalizer.fit(train_targets)
            elif isinstance(self._target_normalizer, (StandardScaler, RobustScaler)):
                target_np = train_targets.numpy()
                if target_np.ndim == 1:
                    target_np = target_np.reshape(-1, 1)
                self._target_normalizer.fit(target_np)

        if self._scalers and self.continuous_indices:
            for i, orig_idx in enumerate(self.continuous_indices):
                col_name = self.time_series_metadata["cols"]["x"][orig_idx]
                if col_name in self._scalers:
                    scaler = self._scalers[col_name]
                    feature_data = train_features[:, orig_idx]

                    if isinstance(scaler, (StandardScaler, RobustScaler)):
                        feature_data = feature_data.numpy().reshape(-1, 1)

                    scaler.fit(feature_data)
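The point of fitting here rather than in `_preprocess_data` is that scaler statistics come from the training series only and are then reused when transforming validation and test data, which avoids leaking future statistics into training. A standalone sketch of that pattern with sklearn; the array and split below are made up for illustration:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

values = np.arange(100, dtype=np.float64).reshape(-1, 1)  # toy single feature
train, val = values[:80], values[80:]                      # chronological split

scaler = StandardScaler()
scaler.fit(train)                       # statistics from the training split only
train_scaled = scaler.transform(train)
val_scaled = scaler.transform(val)      # reuses the training mean and std
```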

    def setup(self, stage: Optional[str] = None):
        """Prepare the datasets for training, validation, testing, or prediction.

@@ -647,6 +768,9 @@ def setup(self, stage: Optional[str] = None):
        ]
        self._test_indices = self._split_indices[self._train_size + self._val_size :]

        if (stage is None or stage == "fit") and len(self._train_indices) > 0:
            self._fit_scalers(self._train_indices)

        if stage is None or stage == "fit":
            if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
                self.train_windows = self._create_windows(self._train_indices)
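Assuming the usual LightningDataModule lifecycle, setup("fit") runs before the training dataloaders are requested, so `_fit_scalers` sees only the training indices computed just above; a `stage` of None means the module is being prepared for all stages. A toy, self-contained sketch of that contract (the class, data, and import path for lightning 2.x are assumptions for illustration, not part of this PR):

```python
import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyDataModule(pl.LightningDataModule):
    """Toy stand-in showing when Lightning invokes setup(stage)."""

    def setup(self, stage=None):
        # The Trainer calls this with stage="fit" before training begins,
        # which is the hook point this PR uses to fit the scalers.
        print(f"setup called with stage={stage!r}")
        self.ds = TensorDataset(torch.randn(8, 3), torch.randn(8, 1))

    def train_dataloader(self):
        return DataLoader(self.ds, batch_size=4)


dm = ToyDataModule()
dm.setup("fit")  # what the Trainer does automatically at fit time
```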