File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change 77
88import torch
99import typing
10- from typing import Set
10+ from typing import Any , Set , Dict
1111
1212from returnn .log import log
1313from returnn .util .basic import RefIdEq
@@ -35,7 +35,8 @@ def _init_optimizer_classes_dict():
3535
3636def get_optimizer_class (class_name ):
3737 """
38- :param str|function|type[torch.optim.Optimizer] class_name: Optimizer data, e.g. "adam", torch.optim.Adam...
38+ :param str|()->torch.optim.Optimizer|type[torch.optim.Optimizer] class_name:
39+ Optimizer data, e.g. "adam", torch.optim.Adam...
3940 :return: Optimizer class
4041 :rtype: type[torch.optim.Optimizer]
4142 """
@@ -156,7 +157,7 @@ def _create_optimizer(self, optimizer_opts):
156157 if isinstance (optimizer_opts , torch .optim .Optimizer ):
157158 return optimizer_opts
158159 elif callable (optimizer_opts ):
159- optimizer_opts = {"class" : optimizer_opts }
160+ optimizer_opts : Dict [ str , Any ] = {"class" : optimizer_opts }
160161 else :
161162 if not isinstance (optimizer_opts , dict ):
162163 raise ValueError ("'optimizer' must of type dict, callable or torch.optim.Optimizer instance." )
You can’t perform that action at this time.
0 commit comments