-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
147 lines (116 loc) · 6.33 KB
/
Copy pathtrain.py
File metadata and controls
147 lines (116 loc) · 6.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from datetime import datetime
import logging
import pprint
import os
import argparse
from modules.models.vqvae import VQVAE
from modules.models.forecasting import Forecaster
from modules.utils.training import train_vqvae, train_forecaster
from modules.utils.visualize import plot_loss_curve, plot_perplexity_curve, visualize_reconstructions, visualize_codebook, visualize_codebook_utilization, visualize_forecasts
from modules.data.preprocessing import TimeSeriesDataset, ForecastDataset
from modules.utils.helpers import count_parameters, load_configuration
def _parse_args() -> argparse.Namespace:
    """Build this script's command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description='Train VQVAE and/or Forecaster')
    parser.add_argument(
        "--config", type=str, default="config/default.yaml",
        help="Path to configuration file (default: config/default.yaml)",
    )
    parser.add_argument(
        "--train", type=str, choices=["vqvae", "forecaster", "both"], default='both',
        help="Which models to train (default: both)",
    )
    parser.add_argument("--vqvae", type=str, default=None, help="Path to pretrained VQVAE model weights")
    parser.add_argument("--forecaster", type=str, default=None, help="Path to pretrained Forecaster model weights")
    return parser.parse_args()


def _setup_logging(save_dir: str) -> None:
    """Route INFO-level log records to ``<save_dir>/training.log`` (file only, no console handler)."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
        handlers=[logging.FileHandler(os.path.join(save_dir, "training.log"))],
    )


def _train_vqvae_stage(vqvae, config, data_folder, num_workers, save_dir: str, device: str) -> None:
    """Train ``vqvae`` with the ``config["vqvae"]["training"]`` settings, then write diagnostic plots to ``save_dir``."""
    # Read every setting up front so a missing key fails before the dataset is built.
    timesteps = config["vqvae"]["training"]["timesteps"]
    batch_size = config["vqvae"]["training"]["batch_size"]
    learning_rate = config["vqvae"]["training"]["learning_rate"]
    num_epochs = config["vqvae"]["training"]["num_epochs"]

    train_dataset = TimeSeriesDataset(timesteps, folder_path=data_folder)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    optimizer = optim.Adam(vqvae.parameters(), lr=learning_rate)
    # T_max=num_epochs assumes train_vqvae steps the scheduler once per epoch — TODO confirm.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    losses, perplexities = train_vqvae(
        vqvae, optimizer, train_dataloader, num_epochs,
        save_dir=save_dir, scheduler=scheduler, device=device,
    )

    plot_loss_curve(losses, save_dir, title="VQVAE Loss Curve")
    plot_perplexity_curve(perplexities, save_dir, title="VQVAE Perplexity Curve")
    visualize_reconstructions(vqvae, train_dataloader, save_dir, device)
    visualize_codebook(vqvae, save_dir)
    visualize_codebook_utilization(vqvae, train_dataloader, vqvae.num_embeddings, save_dir, device)


def _train_forecaster_stage(forecaster, vqvae, config, input_length, output_length,
                            data_folder, num_workers, save_dir: str, device: str) -> None:
    """Train ``forecaster`` on top of ``vqvae`` with ``config["forecaster"]["training"]``, then plot results."""
    batch_size = config["forecaster"]["training"]["batch_size"]
    learning_rate = config["forecaster"]["training"]["learning_rate"]
    num_epochs = config["forecaster"]["training"]["num_epochs"]

    train_dataset = ForecastDataset(input_length, output_length, folder_path=data_folder)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Only the Forecaster's parameters go to the optimizer; the VQVAE is passed to
    # train_forecaster separately and is not updated by this optimizer.
    optimizer = optim.Adam(forecaster.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    losses = train_forecaster(
        forecaster, optimizer, vqvae, train_dataloader, num_epochs, save_dir,
        scheduler=scheduler, device=device,
    )

    plot_loss_curve(losses, save_dir, title="Forecaster Loss Curve")
    visualize_forecasts(forecaster, vqvae, train_dataloader, save_dir, device)


def main() -> None:
    """Train the VQVAE and/or the Forecaster according to CLI arguments and a config file.

    Command-line arguments:
        --config      Path to the configuration file read by ``load_configuration``.
        --train       Stage(s) to run: ``"vqvae"``, ``"forecaster"`` or ``"both"`` (default).
        --vqvae       Optional path to pretrained VQVAE weights (loaded via ``load_state_dict``).
        --forecaster  Optional path to pretrained Forecaster weights.

    Each run writes its log file and plots into a fresh
    ``experiments/run_<YYYYmmdd_HHMMSS>`` directory.

    Raises:
        ValueError: if only the Forecaster is to be trained but no ``--vqvae``
            weights were given, or if the Forecaster ``input_length`` is not
            divisible by the VQVAE ``compression_factor``.
    """
    args = _parse_args()

    # Check argument combinations before any side effects (run directory, log file,
    # model construction) so a bad invocation fails fast and leaves nothing behind.
    # The Forecaster needs a VQVAE that is either loaded (--vqvae) or trained in this run.
    if args.train == "forecaster" and not args.vqvae:
        raise ValueError("Forecaster training requires a pretrained VQVAE. Specify with --vqvae or train with --train=both")

    # Logging
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"experiments/run_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)
    _setup_logging(save_dir)
    logging.info(f"Command-line arguments: {vars(args)}")

    config = load_configuration(args.config)
    logging.info(f"CONFIGURATION\n {pprint.pformat(config)}")

    # Configuration (all keys read before any model is built)
    num_embeddings = config["vqvae"]["num_embeddings"]
    embedding_dim = config["vqvae"]["embedding_dim"]
    commitment_cost = config["vqvae"]["commitment_cost"]
    hidden_channels = config["vqvae"]["hidden_channels"]
    compression_factor = config["vqvae"]["compression_factor"]
    input_length = config["forecaster"]["input_length"]
    output_length = config["forecaster"]["output_length"]
    d_model = config["forecaster"]["d_model"]
    num_heads = config["forecaster"]["num_heads"]
    num_encoder_layers = config["forecaster"]["num_encoder_layers"]
    ff_dim = config["forecaster"]["ff_dim"]
    dropout = config["forecaster"]["dropout"]
    data_folder = config["data"]["data_path"]
    num_workers = config["data"]["num_workers"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize models
    vqvae = VQVAE(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        commitment_cost=commitment_cost,
        hidden_channels=hidden_channels,
        compression_factor=compression_factor,
    ).to(device)
    if args.vqvae:
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted
        # sources (consider weights_only=True if the installed torch supports it).
        vqvae.load_state_dict(torch.load(args.vqvae, map_location=device))
        logging.info(f"Loaded pretrained VQVAE weights from {args.vqvae}")
    logging.info(f"VQVAE number of parameters: {count_parameters(vqvae)}")

    # Explicit raise instead of `assert`: asserts are stripped under `python -O`, and
    # this is real config validation. The Forecaster's context length is
    # input_length // compression_factor, so the division must be exact.
    if input_length % vqvae.compression_factor != 0:
        raise ValueError(
            f"The Forecaster `input_length` ({input_length}) must be divisible by the "
            f"`compression_factor` ({vqvae.compression_factor}) of the VQVAE"
        )

    forecaster = Forecaster(
        context_length=input_length // vqvae.compression_factor,
        input_length=input_length,
        output_length=output_length,
        vocab_size=vqvae.num_embeddings,
        d_model=d_model,
        num_heads=num_heads,
        num_encoder_layers=num_encoder_layers,
        ff_dim=ff_dim,
        dropout=dropout,
    ).to(device)
    if args.forecaster:
        # NOTE(review): same trust caveat as the VQVAE checkpoint above.
        forecaster.load_state_dict(torch.load(args.forecaster, map_location=device))
        logging.info(f"Loaded pretrained Forecaster weights from {args.forecaster}")
    logging.info(f"Forecaster number of parameters: {count_parameters(forecaster)}")

    # VQVAE Training
    if args.train in {"vqvae", "both"}:
        _train_vqvae_stage(vqvae, config, data_folder, num_workers, save_dir, device)

    # Forecaster Training (the VQVAE requirement was already enforced above)
    if args.train in {"forecaster", "both"}:
        _train_forecaster_stage(
            forecaster, vqvae, config, input_length, output_length,
            data_folder, num_workers, save_dir, device,
        )
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()