7
7
8
8
class TestOptim (unittest .TestCase ):
9
9
def test_SGD (self ):
10
- optim = SGD (torch .nn .Linear (10 , 3 ).parameters ())
10
+ optim = SGD (model_params = torch .nn .Linear (10 , 3 ).parameters ())
11
11
self .assertTrue ("lr" in optim .__dict__ ["settings" ])
12
12
self .assertTrue ("momentum" in optim .__dict__ ["settings" ])
13
13
res = optim .construct_from_pytorch (torch .nn .Linear (10 , 3 ).parameters ())
@@ -22,13 +22,18 @@ def test_SGD(self):
22
22
self .assertEqual (optim .__dict__ ["settings" ]["lr" ], 0.002 )
23
23
self .assertEqual (optim .__dict__ ["settings" ]["momentum" ], 0.989 )
24
24
25
- with self .assertRaises (RuntimeError ):
25
+ optim = SGD (0.001 )
26
+ self .assertEqual (optim .__dict__ ["settings" ]["lr" ], 0.001 )
27
+ res = optim .construct_from_pytorch (torch .nn .Linear (10 , 3 ).parameters ())
28
+ self .assertTrue (isinstance (res , torch .optim .SGD ))
29
+
30
+ with self .assertRaises (TypeError ):
26
31
_ = SGD ("???" )
27
- with self .assertRaises (RuntimeError ):
32
+ with self .assertRaises (TypeError ):
28
33
_ = SGD (0.001 , lr = 0.002 )
29
34
30
35
def test_Adam (self ):
31
- optim = Adam (torch .nn .Linear (10 , 3 ).parameters ())
36
+ optim = Adam (model_params = torch .nn .Linear (10 , 3 ).parameters ())
32
37
self .assertTrue ("lr" in optim .__dict__ ["settings" ])
33
38
self .assertTrue ("weight_decay" in optim .__dict__ ["settings" ])
34
39
res = optim .construct_from_pytorch (torch .nn .Linear (10 , 3 ).parameters ())
@@ -42,3 +47,8 @@ def test_Adam(self):
42
47
optim = Adam (lr = 0.002 , weight_decay = 0.989 )
43
48
self .assertEqual (optim .__dict__ ["settings" ]["lr" ], 0.002 )
44
49
self .assertEqual (optim .__dict__ ["settings" ]["weight_decay" ], 0.989 )
50
+
51
+ optim = Adam (0.001 )
52
+ self .assertEqual (optim .__dict__ ["settings" ]["lr" ], 0.001 )
53
+ res = optim .construct_from_pytorch (torch .nn .Linear (10 , 3 ).parameters ())
54
+ self .assertTrue (isinstance (res , torch .optim .Adam ))
0 commit comments