diff --git a/README.md b/README.md index fb3269b9..d7101e64 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,10 @@ See [notebooks/](notebooks/) for visualizations explaining some concepts behind [example.py](example.py) is a self-contained training script for MNIST and CIFAR that imports the standalone S4 file. The default settings `python example.py` reaches 88% accuracy on sequential CIFAR with a very simple S4D model of 200k parameters. This script can be used as an example for using S4 variants in external repositories. +In addition, [`line_forecast.py`](line_forecast.py) demonstrates training S4 on a synthetic linear forecasting task using a one-step prediction horizon, printing validation loss to show the model learning a straight line. After training, it runs autoregressive generation using the model's `step` function so you can observe the predictions converge to the target line. +The script also saves a PNG plot comparing the input sequence, ground-truth future line, and the generated predictions. +For a version that uses the repository's Hydra configuration, run [`run_line_forecast.py`](run_line_forecast.py), which internally calls `train.py` with suitable overrides. 
+
 ### Training with this Repository (Internal Usage)
diff --git a/configs/dataset/line.yaml b/configs/dataset/line.yaml
new file mode 100644
index 00000000..e63ff455
--- /dev/null
+++ b/configs/dataset/line.yaml
@@ -0,0 +1,12 @@
+_name_: line
+seq_len: 24
+pred_len: 12
+n_train: 1000
+n_val: 200
+n_test: 200
+slope_range: [0.1, 1.0]
+intercept_range: [0.0, 1.0]
+noise_std: 0.0
+seed: 0
+__l_max: ${eval:${.seq_len}+${.pred_len}}
+
diff --git a/line_forecast.py b/line_forecast.py
new file mode 100644
index 00000000..90397e64
--- /dev/null
+++ b/line_forecast.py
@@ -0,0 +1,143 @@
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from models.s4.s4 import S4Block
+from src.dataloaders.datasets.line import LineDataset
+
+
+def plot_predictions(x, y, preds, filename="line_forecast_plot.png"):
+    """Plot the input sequence, target, and generated prediction."""
+    x = x.squeeze(-1).cpu().numpy()
+    y = y.squeeze(-1).cpu().numpy()
+    preds = preds.squeeze(-1).cpu().numpy()
+    seq_len = len(x)
+    t_input = list(range(seq_len))
+    t_future = list(range(seq_len, seq_len + len(y)))
+
+    plt.figure()
+    plt.plot(t_input, x, label="input")
+    plt.plot(t_future, y, label="target")
+    plt.plot(t_future, preds, "--", label="generated")
+    plt.legend()
+    plt.xlabel("t")
+    plt.ylabel("value")
+    plt.tight_layout()
+    plt.savefig(filename)
+    plt.close()
+    print(f"Saved plot to {filename}")  # FIX: f-string had a literal placeholder instead of {filename}
+
+
+class ForecastModel(nn.Module):
+    def __init__(self, d_model=64, n_layers=2, dropout=0.0):
+        super().__init__()
+        self.encoder = nn.Linear(1, d_model)  # lift scalar input to d_model
+        self.s4_layers = nn.ModuleList(
+
+            [
+                S4Block(d_model, transposed=False, dropout=dropout)
+                for _ in range(n_layers)
+            ]
+
+        )
+        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
+        self.decoder = nn.Linear(d_model, 1)  # project back to a scalar prediction
+
+
+    def setup_step(self):
+        for layer in self.s4_layers:  # build each layer's recurrent stepping view
+            layer.setup_step()
+
+    def default_state(self, batch_size, device=None):
+        return [
+            layer.default_state(batch_size, device=device) for layer in self.s4_layers
+        ]
+
+    def step(self, x_t, states):
+        x = self.encoder(x_t)  # x_t: one time step per batch element
+        new_states = []
+        for layer, norm, s in zip(self.s4_layers, self.norms, states):
+            y, s = layer.step(x, s)
+            x = norm(x + y)  # residual + LayerNorm, mirrors forward()
+            new_states.append(s)
+        out = self.decoder(x)
+        return out, new_states
+
+    def forward(self, x):
+        x = self.encoder(x)  # x: (batch, seq_len, 1)
+        for layer, norm in zip(self.s4_layers, self.norms):
+            z, _ = layer(x)
+            x = norm(x + z)
+        x_last = x[:, -1]  # one-step forecast read off the final position
+        out = self.decoder(x_last)
+        return out
+
+
+def train_model():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Use a 1-step forecast horizon so the model output matches the target
+    train_dataset = LineDataset(pred_len=1)
+    val_dataset = LineDataset(pred_len=1, seed=1)
+    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
+    val_loader = DataLoader(val_dataset, batch_size=32)
+
+    model = ForecastModel().to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    criterion = nn.MSELoss()
+
+    for epoch in range(10):
+        model.train()
+
+        train_loss = 0.0
+        for x, y in train_loader:
+            x = x.to(device)
+            y = y.to(device).squeeze(1)  # (batch, 1, 1) -> (batch, 1)
+            optimizer.zero_grad()
+            out = model(x)
+            loss = criterion(out, y)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item() * x.size(0)
+        avg_train = train_loss / len(train_loader.dataset)
+
+        model.eval()
+        val_loss = 0.0
+        with torch.no_grad():
+            for x, y in val_loader:
+                x = x.to(device)
+                y = y.to(device).squeeze(1)
+                out = model(x)
+                loss = criterion(out, y)
+                val_loss += loss.item() * x.size(0)
+        avg_val = val_loss / len(val_loader.dataset)
+
+        print(f"Epoch {epoch+1}: train {avg_train:.6f}, val {avg_val:.6f}")
+
+    # Demonstrate autoregressive generation using the trained model
+    model.eval()
+    with torch.no_grad():
+        x, y_true = next(iter(val_loader))
+        x = x.to(device)
+        y_true = y_true.to(device)
+        model.setup_step()
+        state = model.default_state(x.size(0), device=device)
+        for t in range(x.size(1)):  # warm up the recurrent state on the observed input
+            _, state = model.step(x[:, t], state)
+        preds = []
+        x_t = x[:, -1]
+        for _ in range(y_true.size(1)):
+            out, state = model.step(x_t, state)
+            preds.append(out)
+            x_t = out  # feed the prediction back in autoregressively
+        preds = torch.stack(preds, dim=1)
+        print("Target:", y_true[0].squeeze().cpu().numpy())
+        print("Preds :", preds[0].squeeze().cpu().numpy())
+        plot_predictions(x[0], y_true[0], preds[0])  # FIX: moved here from the epoch loop, where y_true/preds were undefined (NameError)
+
+
+if __name__ == "__main__":
+    train_model()
diff --git a/run_line_forecast.py b/run_line_forecast.py
new file mode 100644
index 00000000..f7bfc754
--- /dev/null
+++ b/run_line_forecast.py
@@ -0,0 +1,27 @@
+import os
+
+from hydra import compose, initialize
+from omegaconf import OmegaConf
+
+from train import train
+
+
+def main():
+    """Train S4 on the synthetic line forecasting dataset."""
+    overrides = [
+        "pipeline=informer",
+        "model=s4",
+        "dataset=line",
+        "dataset.pred_len=1",
+        "loader.batch_size=32",
+        "trainer.max_epochs=10",
+        "wandb=null",  # disable wandb logging by default
+    ]
+    with initialize(config_path="configs"):
+        cfg = compose(config_name="config.yaml", overrides=overrides)
+        print(OmegaConf.to_yaml(cfg))
+        train(cfg)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/dataloaders/__init__.py b/src/dataloaders/__init__.py
index 9213cde0..24df09e5 100644
--- a/src/dataloaders/__init__.py
+++ b/src/dataloaders/__init__.py
@@ -1,2 +1,2 @@
-from . import audio, basic, et, lm, lra, synthetic, ts, vision
+from .
import audio, basic, et, lm, lra, synthetic, ts, vision, line
 from .base import SequenceDataset
diff --git a/src/dataloaders/datasets/line.py b/src/dataloaders/datasets/line.py
new file mode 100644
index 00000000..f47fba17
--- /dev/null
+++ b/src/dataloaders/datasets/line.py
@@ -0,0 +1,26 @@
+import torch
+
+class LineDataset(torch.utils.data.TensorDataset):
+    def __init__(self, seq_len=24, pred_len=12, n_samples=1000,
+                 slope_range=(0.1, 1.0), intercept_range=(0.0, 1.0),
+                 noise_std=0.0, seed=0):
+        self.seq_len = seq_len  # length of the observed input window
+        self.pred_len = pred_len  # length of the forecast target
+        self.n_samples = n_samples
+        self.slope_range = slope_range
+        self.intercept_range = intercept_range
+        self.noise_std = noise_std
+        self.seed = seed
+
+        generator = torch.Generator().manual_seed(seed)  # reproducible sampling per seed
+        total_len = seq_len + pred_len
+        t = torch.arange(total_len, dtype=torch.float32)
+        slopes = torch.empty(n_samples).uniform_(slope_range[0], slope_range[1], generator=generator)
+        intercepts = torch.empty(n_samples).uniform_(intercept_range[0], intercept_range[1], generator=generator)
+        lines = slopes[:, None] * t + intercepts[:, None]  # (n_samples, total_len) straight lines
+        if noise_std > 0:
+            lines += noise_std * torch.randn(n_samples, total_len, generator=generator)
+        x = lines[:, :seq_len].unsqueeze(-1)  # inputs: (n_samples, seq_len, 1)
+        y = lines[:, seq_len:].unsqueeze(-1)  # targets: (n_samples, pred_len, 1)
+        super().__init__(x, y)
+        self.forecast_horizon = pred_len  # also set explicitly in Line.setup below
diff --git a/src/dataloaders/line.py b/src/dataloaders/line.py
new file mode 100644
index 00000000..b60473ae
--- /dev/null
+++ b/src/dataloaders/line.py
@@ -0,0 +1,49 @@
+from src.dataloaders.base import SequenceDataset
+from .datasets.line import LineDataset
+
+class Line(SequenceDataset):
+    _name_ = "line"  # matches _name_ in configs/dataset/line.yaml
+    d_input = 1
+    d_output = 1
+
+    @property
+    def init_defaults(self):
+        return {
+            "seq_len": 24,
+            "pred_len": 12,
+            "n_train": 1000,
+            "n_val": 200,
+            "n_test": 200,
+            "slope_range": (0.1, 1.0),
+            "intercept_range": (0.0, 1.0),
+            "noise_std": 0.0,
+            "seed": 0,
+        }
+
+    @property
+    def l_output(self):
+        return self.pred_len  # output length equals the forecasting horizon
+
+    def setup(self):
+        self.dataset_train = LineDataset(
+            self.seq_len, self.pred_len, self.n_train,
+            self.slope_range, self.intercept_range,
+            self.noise_std, seed=self.seed
+        )
+        self.dataset_val = LineDataset(
+            self.seq_len, self.pred_len, self.n_val,
+            self.slope_range, self.intercept_range,
+            self.noise_std, seed=self.seed + 1
+        )
+        self.dataset_test = LineDataset(  # distinct seeds keep the three splits disjoint
+            self.seq_len, self.pred_len, self.n_test,
+            self.slope_range, self.intercept_range,
+            self.noise_std, seed=self.seed + 2
+        )
+        # forecast horizon property used by forecasting task
+        self.dataset_train.forecast_horizon = self.pred_len
+        self.dataset_val.forecast_horizon = self.pred_len
+        self.dataset_test.forecast_horizon = self.pred_len
+
+    def __str__(self):
+        return f"line{self.seq_len}_{self.pred_len}"