forked from facebookresearch/vissl
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_mlp.py
66 lines (50 loc) · 1.86 KB
/
test_mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from vissl.config import AttrDict
from vissl.models.heads import MLP, LinearEvalMLP
class TestMLP(unittest.TestCase):
    """
    Unit tests verifying the construction of MLP heads and linear
    evaluation MLP heads by checking the shape of their outputs.
    """

    # Head settings shared by every head constructed in these tests.
    MODEL_CONFIG = AttrDict(
        {
            "HEAD": {
                "BATCHNORM_EPS": 1e-6,
                "BATCHNORM_MOMENTUM": 0.99,
                "PARAMS_MULTIPLIER": 1.0,
            }
        }
    )

    def test_mlp(self):
        """A 2048 -> 100 MLP maps (batch, 2048) inputs to (batch, 100)."""
        mlp = MLP(self.MODEL_CONFIG, dims=[2048, 100])

        # Cover both a regular batch and the single-sample case.
        for batch_size in (4, 1):
            with self.subTest(batch_size=batch_size):
                x = torch.randn(size=(batch_size, 2048))
                out = mlp(x)
                # unittest assertions (unlike bare `assert`) are not stripped
                # under `python -O` and report both shapes on failure.
                self.assertEqual(out.shape, torch.Size([batch_size, 100]))

    def test_mlp_reshaping(self):
        """A (1, 2048, 1, 1) input is expected to produce a (1, 100) output."""
        mlp = MLP(self.MODEL_CONFIG, dims=[2048, 100])

        x = torch.randn(size=(1, 2048, 1, 1))
        out = mlp(x)
        self.assertEqual(out.shape, torch.Size([1, 100]))

    def test_mlp_catch_bad_shapes(self):
        """A (1, 2048, 2, 1) input is expected to raise AssertionError."""
        mlp = MLP(self.MODEL_CONFIG, dims=[2048, 100])

        x = torch.randn(size=(1, 2048, 2, 1))
        # assertRaises already fails the test when nothing is raised, so no
        # separate check on the captured exception is needed.
        with self.assertRaises(AssertionError):
            mlp(x)

    def test_eval_mlp_shape(self):
        """LinearEvalMLP maps (batch, 2048, 2, 2) feature maps to (batch, 1000)."""
        eval_mlp = LinearEvalMLP(
            self.MODEL_CONFIG, in_channels=2048, dims=[2048 * 2 * 2, 1000]
        )

        # Cover both a regular batch and the single-sample case.
        for batch_size in (4, 1):
            with self.subTest(batch_size=batch_size):
                resnet_feature_map = torch.randn(size=(batch_size, 2048, 2, 2))
                out = eval_mlp(resnet_feature_map)
                self.assertEqual(out.shape, torch.Size([batch_size, 1000]))