-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMultiLayerNet.py
37 lines (32 loc) · 1.38 KB
/
MultiLayerNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from config import *
import torch
class MultiLayerNet(torch.nn.Module):
    """Fully connected network with three tanh-activated hidden layers.

    Architecture: ``D_in -> H -> H -> H -> D_out``. A tanh non-linearity is
    applied after each of the first three linear layers; the output layer is
    left linear.
    """

    def __init__(self, D_in, H, D_out, init_std=0.1):
        """Create the four linear layers and initialize their parameters.

        Args:
            D_in: Number of input features.
            H: Width of each of the three hidden layers.
            D_out: Number of output features.
            init_std: Standard deviation of the zero-mean normal distribution
                used to initialize every weight matrix. Defaults to 0.1, the
                value previously hard-coded here. All biases start at zero.
        """
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.linear3 = torch.nn.Linear(H, H)
        self.linear4 = torch.nn.Linear(H, D_out)
        # Attribute names linear1..linear4 are kept so that state_dict keys
        # remain compatible with existing checkpoints. Iterating in layer order
        # keeps the RNG draws for the weight init in the same sequence as
        # before (constant_ consumes no random numbers).
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.normal_(layer.weight, mean=0.0, std=init_std)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Input tensor whose last dimension has size ``D_in``
                (``torch.nn.Linear`` acts on the last dimension).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``D_out``.
        """
        h = torch.tanh(self.linear1(x))
        h = torch.tanh(self.linear2(h))
        h = torch.tanh(self.linear3(h))
        return self.linear4(h)