-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMLP.py
22 lines (17 loc) · 813 Bytes
/
MLP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from torch import nn
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    Stacks ``num_layers`` ``nn.Linear`` layers. Every layer except the last is
    followed by ``nn.BatchNorm1d`` and a ReLU; the last layer is a plain linear
    projection with no normalization or activation.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of every hidden layer.
        output_dim: Size of the last dimension of the output tensor.
        num_layers: Number of linear layers. Values below 1 behave like 1,
            because ``[hidden_dim] * (num_layers - 1)`` is then empty.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.output_dim = output_dim
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + h
        out_dims = h + [output_dim]
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip(in_dims, out_dims))
        # One BatchNorm1d per linear layer, including the last one. forward()
        # never applies the final BN; it is kept so the module's parameters and
        # state_dict keys are unchanged (presumably for checkpoint
        # compatibility -- confirm before removing it).
        self.bns = nn.ModuleList(nn.BatchNorm1d(k) for k in out_dims)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``.

        Args:
            x: Tensor of shape ``(*, input_dim)``. Typically ``(B, N, input_dim)``,
                but any number of leading dimensions (at least zero) is accepted.

        Returns:
            Tensor of shape ``(*, output_dim)`` with the same leading dimensions
            as ``x``.
        """
        *lead, in_dim = x.shape
        # BatchNorm1d expects (rows, features), so flatten every leading
        # dimension into one row axis; in training mode the batch statistics
        # are therefore computed over all of those rows together.
        x = x.reshape(-1, in_dim)
        last = len(self.layers) - 1
        for i, (bn, layer) in enumerate(zip(self.bns, self.layers)):
            x = layer(x)
            if i < last:
                x = nn.functional.relu(bn(x))
        return x.reshape(*lead, self.output_dim)