Skip to content

Commit e9a9920

Browse files
committed
PT engine, cleanup types
1 parent 3d8c8ef commit e9a9920

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

returnn/torch/updater.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import torch
99
import typing
10-
from typing import Set
10+
from typing import Any, Set, Dict
1111

1212
from returnn.log import log
1313
from returnn.util.basic import RefIdEq
@@ -35,7 +35,8 @@ def _init_optimizer_classes_dict():
3535

3636
def get_optimizer_class(class_name):
3737
"""
38-
:param str|function|type[torch.optim.Optimizer] class_name: Optimizer data, e.g. "adam", torch.optim.Adam...
38+
:param str|()->torch.optim.Optimizer|type[torch.optim.Optimizer] class_name:
39+
Optimizer data, e.g. "adam", torch.optim.Adam...
3940
:return: Optimizer class
4041
:rtype: type[torch.optim.Optimizer]
4142
"""
@@ -156,7 +157,7 @@ def _create_optimizer(self, optimizer_opts):
156157
if isinstance(optimizer_opts, torch.optim.Optimizer):
157158
return optimizer_opts
158159
elif callable(optimizer_opts):
159-
optimizer_opts = {"class": optimizer_opts}
160+
optimizer_opts: Dict[str, Any] = {"class": optimizer_opts}
160161
else:
161162
if not isinstance(optimizer_opts, dict):
162163
raise ValueError("'optimizer' must of type dict, callable or torch.optim.Optimizer instance.")

0 commit comments

Comments
 (0)