File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -84,7 +84,7 @@ class ModelWrapper(Stateful):
8484 def __init__ (self , model : Union [nn .Module , List [nn .Module ]]) -> None :
8585 self .model = [model ] if isinstance (model , nn .Module ) else model
8686
87- def state_dict (self ) -> None :
87+ def state_dict (self ) -> Dict [ str , Any ] :
8888 return {
8989 k : v for sd in map (get_model_state_dict , self .model ) for k , v in sd .items ()
9090 }
@@ -107,7 +107,7 @@ def __init__(
107107 self .model = [model ] if isinstance (model , nn .Module ) else model
108108 self .optim = [optim ] if isinstance (optim , torch .optim .Optimizer ) else optim
109109
110- def state_dict (self ) -> None :
110+ def state_dict (self ) -> Dict [ str , Any ] :
111111 func = functools .partial (
112112 get_optimizer_state_dict ,
113113 options = StateDictOptions (flatten_optimizer_state_dict = True ),
You can’t perform that action at this time.
0 commit comments