|
1 |
| -import torch |
2 |
| - |
3 |
| -from math import sqrt |
4 | 1 | from torch import Tensor, exp, log, nn
|
5 | 2 | from torch.nn.parameter import Parameter
|
6 | 3 | from torch.nn.init import xavier_uniform_
|
7 |
| -from torch.nn.functional import tanh, sigmoid, linear |
8 |
| -from NeuralAccumulator import NeuralAccumulator |
| 4 | +from torch.nn.functional import sigmoid, linear |
| 5 | +from .nac_cell import NacCell |
9 | 6 |
|
10 | 7 |
|
11 |
| -class NALU(nn.Module): |
| 8 | +class NaluCell(nn.Module): |
12 | 9 | """Basic NALU unit implementation
|
13 | 10 | from https://arxiv.org/pdf/1808.00508.pdf
|
14 | 11 | """
|
15 | 12 |
|
16 |
| - def __init__(self, inputs, outputs): |
| 13 | + def __init__(self, in_shape, out_shape): |
17 | 14 | """
|
18 |
| - inputs: input sample size |
19 |
| - outputs: output sample size |
| 15 | + in_shape: input sample dimension |
| 16 | + out_shape: output sample dimension |
20 | 17 | """
|
21 | 18 | super().__init__()
|
22 |
| - self.inputs = inputs |
23 |
| - self.outputs = outputs |
24 |
| - self.G = Parameter(Tensor(outputs, inputs)) |
25 |
| - self.W = Parameter(Tensor(outputs, inputs)) |
26 |
| - self.nac = NeuralAccumulator(outputs, inputs) |
| 19 | + self.in_shape = in_shape |
| 20 | + self.out_shape = out_shape |
| 21 | + self.G = Parameter(Tensor(out_shape, in_shape)) |
| 22 | + self.W = Parameter(Tensor(out_shape, in_shape)) |
| 23 | + self.nac = NacCell(out_shape, in_shape) |
| 24 | + xavier_uniform_(self.G), xavier_uniform_(self.W) |
27 | 25 | self.eps = 1e-5
|
28 | 26 | self.register_parameter('bias', None)
|
29 |
| - xavier_uniform_(self.G), xavier_uniform_(self.W) |
30 | 27 |
|
31 | 28 | def forward(self, input):
|
32 | 29 | a = self.nac(input)
|
|
0 commit comments