generated from Axect/pytorch_template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
332 lines (284 loc) · 10.4 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import torch
from torch import nn
import math
from mambapy.mamba import Mamba, MambaConfig
def create_net(sizes):
    """Build a plain MLP from a list of layer widths.

    Every consecutive pair ``(sizes[k], sizes[k + 1])`` becomes an
    ``nn.Linear`` layer. A GELU activation is placed between consecutive
    Linear layers, but not after the final one. A ``sizes`` list with fewer
    than two entries yields an empty ``nn.Sequential``.
    """
    final_hidden = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(fan_in, fan_out))
        if idx < final_hidden:
            modules.append(nn.GELU())
    return nn.Sequential(*modules)
@torch.compile
class DeepONet(nn.Module):
    """DeepONet with two output channels.

    A branch MLP embeds the sampled input function ``u``. A trunk MLP embeds
    each scalar query coordinate of ``y``. For each of the two output channels,
    the prediction is the inner product of the branch and trunk embeddings over
    ``p = hparams["branches"]`` basis functions, plus a learned per-channel bias.

    hparams keys:
        nodes: hidden width of both MLPs.
        layers: number of Linear layers in each MLP.
        branches: number of basis functions ``p`` per output channel.
    """

    def __init__(self, hparams):
        super().__init__()
        nodes = hparams["nodes"]
        layers = hparams["layers"]
        branches = hparams["branches"]
        input_size = 100  # samples of u per row expected by the branch net
        coord_size = 1  # each trunk query is a single scalar coordinate

        self.branch_net = create_net(
            [input_size] + [nodes] * (layers - 1) + [2 * branches]
        )
        self.trunk_net = create_net(
            [coord_size] + [nodes] * (layers - 1) + [2 * branches]
        )
        # One bias per output channel (randomly initialised).
        self.bias = nn.Parameter(torch.randn(2), requires_grad=True)

    def forward(self, u, y):
        """
        - u: (B, 100) input-function samples
        - y: (B, W) query coordinates
        - branch_out: (B, p, 2)
        - trunk_out: (B, W, p, 2)
        - returns: two tensors of shape (B, W), one per output channel
        """
        B, _ = u.shape
        window = y.shape[1]
        branch_out = self.branch_net(u).view(B, -1, 2)  # B x p x 2
        # Evaluate the trunk on all W coordinates in one batched call. The
        # previous per-coordinate Python loop was unrolled W times by
        # torch.compile and specialised the compiled graph on W.
        trunk_out = self.trunk_net(y.unsqueeze(-1)).view(B, window, -1, 2)
        pred = torch.einsum("bpq,bwpq->bwq", branch_out, trunk_out)  # B x W x 2
        pred = pred + self.bias
        return pred[:, :, 0], pred[:, :, 1]
class Encoder(nn.Module):
    """Bidirectional LSTM that returns the final (h, c) states of a scalar sequence.

    Args:
        hidden_size: LSTM hidden size H per direction.
        num_layers: number of stacked LSTM layers L.
        dropout: inter-layer dropout probability; only used when num_layers > 1.
    """

    def __init__(self, hidden_size=10, num_layers=1, dropout=0.1):
        super().__init__()
        self.rnn = nn.LSTM(
            input_size=1,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers. With a single
            # layer it has no effect and only triggers a UserWarning.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True,
        )

    def forward(self, x):
        """
        - x: (B, W, 1)
        - h_n: (D * L, B, H) (D = 2 for bidirectional)
        - c_n: (D * L, B, H) (D = 2 for bidirectional)
        """
        _, (h_n, c_n) = self.rnn(x)
        return h_n, c_n
class Decoder(nn.Module):
    """Bidirectional LSTM decoder that maps each query step to one scalar.

    Args:
        hidden_size: LSTM hidden size H per direction.
        num_layers: number of stacked LSTM layers L.
        dropout: inter-layer dropout probability; only used when num_layers > 1.
    """

    def __init__(self, hidden_size=10, num_layers=1, dropout=0.1):
        super().__init__()
        self.rnn = nn.LSTM(
            input_size=1,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers. With a single
            # layer it has no effect and only triggers a UserWarning.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True,
        )
        # Both directions are concatenated, hence 2 * hidden_size input features.
        self.fc = nn.Linear(2 * hidden_size, 1)

    def forward(self, x, h_c):
        """
        - x: (B, W, 1)
        - h_c: (h_0, c_0), each (D * L, B, H) (D = 2 for bidirectional)
        - o: (B, W, D * H) (D = 2 for bidirectional)
        - out: (B, W, 1)
        """
        o, _ = self.rnn(x, h_c)
        out = self.fc(o)
        return out
class VaRONet(nn.Module):
    """Variational recurrent operator network.

    A bidirectional LSTM encodes the input function ``u``. Its final hidden
    state is projected to Gaussian parameters ``mu`` / ``logvar``. A latent
    ``z`` is sampled from them, or ``mu`` is used directly when
    ``self.reparametrize`` is False. Two decoders then read the query
    sequence ``y``. Each decoder is seeded with a hidden state projected from
    ``z`` and with the encoder's cell state, and produces one output channel.

    hparams keys: hidden_size, num_layers, latent_size, dropout, kl_weight.
    ``kl_weight`` is only stored on the module; ``forward`` does not use it.
    """

    def __init__(self, hparams):
        super().__init__()
        hidden = hparams["hidden_size"]
        n_layers = hparams["num_layers"]
        latent = hparams["latent_size"]
        p_drop = hparams["dropout"]
        kl_weight = hparams["kl_weight"]

        # Recurrent encoder and the two per-channel decoders.
        self.branch_net = Encoder(hidden, n_layers, p_drop)
        self.trunk_x_net = Decoder(hidden, n_layers, p_drop)
        self.trunk_p_net = Decoder(hidden, n_layers, p_drop)

        # Hidden state -> latent Gaussian, and latent -> decoder hidden state.
        self.fc_mu = nn.Linear(hidden, latent)
        self.fc_var = nn.Linear(hidden, latent)
        self.fc_z_x = nn.Linear(latent, hidden)
        self.fc_z_p = nn.Linear(latent, hidden)

        self.kl_weight = kl_weight
        # True: sample z = mu + eps * std. False: use z = mu.
        self.reparametrize = True

    def forward(self, u, y):
        """
        - u: (B, W1), reshaped to (B, W1, 1)
        - y: (B, W2), reshaped to (B, W2, 1)
        - mu, logvar: (B, D*L, Z)
        - returns: (o_x, o_p, mu, logvar) with o_x, o_p of shape (B, W2)
        """
        batch, len_u = u.shape
        _, len_y = y.shape
        u = u.view(batch, len_u, 1)
        y = y.view(batch, len_y, 1)

        h_enc, c_enc = self.branch_net(u)  # each (D*L, B, H)

        # Gaussian posterior parameters, converted to batch-first layout.
        pre_mu = self.fc_mu(h_enc)  # D*L, B, Z
        pre_logvar = self.fc_var(h_enc)  # D*L, B, Z
        mu = pre_mu.permute(1, 0, 2).contiguous()  # B, D*L, Z
        logvar = pre_logvar.permute(1, 0, 2).contiguous()  # B, D*L, Z

        if not self.reparametrize:
            z = mu
        else:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mu + eps * std

        # Decoder initial hidden states come from z, returned to (D*L, B, H)
        # layout. Both decoders reuse the encoder's cell state.
        h_x = self.fc_z_x(z).permute(1, 0, 2).contiguous()
        h_p = self.fc_z_p(z).permute(1, 0, 2).contiguous()
        o_x = self.trunk_x_net(y, (h_x, c_enc))  # B, W2, 1
        o_p = self.trunk_p_net(y, (h_p, c_enc))  # B, W2, 1
        return o_x.squeeze(-1), o_p.squeeze(-1), mu, logvar
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first sequence.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000^(-2k / d_model)``. Any ``d_model``
    is supported, including odd values.

    Args:
        d_model: feature width of the sequences this module is applied to.
        max_len: longest sequence length supported by ``forward``.
    """

    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        # Slice the frequencies so the shapes match (previously a crash).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        - x: (B, W, d_model) with W <= max_len
        - self.pe: (1, M, d_model)
        - self.pe[:, :x.size(1), :]: (1, W, d_model)
        - output: (B, W, d_model)
        """
        x = x + self.pe[:, : x.size(1), :]
        return x
class TFEncoder(nn.Module):
    """Transformer encoder over a scalar sequence.

    Each scalar is lifted to ``d_model`` features by a Linear layer and scaled
    by ``sqrt(d_model)``. Sinusoidal positions are added, and the result is
    passed through a ``num_layers`` TransformerEncoder with a final LayerNorm.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Linear(1, d_model)
        # Fixed sinusoidal positions (a learnable variant was left commented out
        # in an earlier version).
        self.pos_encoder = PositionalEncoding(d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        final_norm = nn.LayerNorm(d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers, norm=final_norm
        )

    def forward(self, x):
        """
        - x: (B, W1, 1)
        - embedded: (B, W1, d_model)
        - returns: (B, W1, d_model)
        """
        scale = math.sqrt(self.d_model)
        embedded = self.pos_encoder(self.embedding(x) * scale)
        return self.transformer_encoder(embedded)
class TFDecoder(nn.Module):
    """Transformer decoder that maps a scalar query sequence to two output channels.

    Each query scalar is embedded to ``d_model`` features, scaled by
    ``sqrt(d_model)``, and given sinusoidal positions. The result is decoded
    against ``memory`` by a ``num_layers`` TransformerDecoder with a final
    LayerNorm, then projected to 2 channels by ``self.fc``.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Linear(1, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # self.pos_encoder = LearnablePositionalEncoding(d_model)
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(
            self.decoder_layer, num_layers, norm=nn.LayerNorm(d_model)
        )
        self.fc = nn.Linear(d_model, 2)

    def forward(self, x, memory):
        """
        - x: (B, W2, 1)
        - x (after embedding): (B, W2, d_model)
        - memory: (B, W1, d_model)
        - out: (B, W2, d_model)
        - out (after fc): (B, W2, 2)
        """
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        out = self.transformer_decoder(x, memory)
        # Apply the output projection. Previously self.fc was never called, so
        # callers read raw LayerNorm-ed hidden channels 0/1 and fc's parameters
        # stayed unused.
        return self.fc(out)
@torch.compile
class TraONet(nn.Module):
    """Transformer operator network: a TFEncoder over ``u``, a TFDecoder over ``y``.

    hparams keys: d_model, nhead, num_layers, dim_feedforward, dropout. The
    encoder and the decoder share all five settings.
    """

    def __init__(self, hparams):
        super().__init__()
        shared = (
            hparams["d_model"],
            hparams["nhead"],
            hparams["num_layers"],
            hparams["dim_feedforward"],
            hparams["dropout"],
        )
        self.branch_net = TFEncoder(*shared)
        self.trunk_net = TFDecoder(*shared)

    def forward(self, u, y):
        """
        - u: (B, W1), reshaped to (B, W1, 1)
        - y: (B, W2), reshaped to (B, W2, 1)
        - memory: (B, W1, d_model)
        - returns: channels 0 and 1 of the decoder output, each (B, W2)
        """
        B, W1 = u.shape
        _, W2 = y.shape
        u_seq = u.view(B, W1, 1)
        y_seq = y.view(B, W2, 1)
        memory = self.branch_net(u_seq)
        decoded = self.trunk_net(y_seq, memory)
        return decoded[:, :, 0], decoded[:, :, 1]
# ┌──────────────────────────────────────────────────────────┐
# Mamba
# └──────────────────────────────────────────────────────────┘
class MambaEncoder(nn.Module):
    """Lift a scalar sequence to ``d_model`` features and run a Mamba stack over it.

    Args:
        d_model: width of the embedding and of the Mamba blocks.
        num_layers: number of Mamba layers (``MambaConfig.n_layers``). All other
            MambaConfig fields keep their defaults.
    """

    def __init__(self, d_model, num_layers):
        super().__init__()
        self.embedding = nn.Linear(1, d_model)
        self.mamba = Mamba(MambaConfig(d_model=d_model, n_layers=num_layers))

    def forward(self, x):
        """
        - x: (B, W, 1)
        - embedded: (B, W, d_model)
        - returns: (B, W, d_model)
        """
        return self.mamba(self.embedding(x))
class MambONet(nn.Module):
    """Operator network with a Mamba encoder over ``u`` and a Transformer decoder over ``y``.

    hparams keys:
        d_model: hidden size shared by the encoder and decoder.
        num_layers1: number of Mamba layers in the encoder.
        n_head: number of attention heads in the decoder.
        num_layers2: number of Transformer decoder layers.
        d_ff: feed-forward width of the decoder.

    The decoder is built with dropout 0.0. The Mamba-specific settings
    (d_state, d_conv, expand) are not read from hparams, so the encoder's
    MambaConfig defaults apply.
    """

    def __init__(self, hparams):
        super().__init__()
        d_model = hparams["d_model"]
        n_mamba_layers = hparams["num_layers1"]
        n_head = hparams["n_head"]
        n_decoder_layers = hparams["num_layers2"]
        d_ff = hparams["d_ff"]
        self.encoder = MambaEncoder(d_model, n_mamba_layers)
        self.decoder = TFDecoder(d_model, n_head, n_decoder_layers, d_ff, 0.0)

    def forward(self, u, y):
        """
        - u: (B, W1), reshaped to (B, W1, 1)
        - y: (B, W2), reshaped to (B, W2, 1)
        - memory: (B, W1, d_model)
        - returns: channels 0 and 1 of the decoder output, each (B, W2)
        """
        B, W1 = u.shape
        _, W2 = y.shape
        u_seq = u.view(B, W1, 1)
        y_seq = y.view(B, W2, 1)
        memory = self.encoder(u_seq)
        decoded = self.decoder(y_seq, memory)
        return decoded[:, :, 0], decoded[:, :, 1]