1010# limitations under the License.
1111
1212import logging
13- from typing import TYPE_CHECKING , Callable , Optional
13+ import warnings
14+ from typing import TYPE_CHECKING , Callable , List , Optional
15+
16+ import torch
1417
1518from monai .data import CSVSaver
1619from monai .handlers .utils import evenly_divisible_all_gather , string_list_all_gather
1720from monai .utils import ImageMetaKey as Key
18- from monai .utils import exact_version , optional_import
21+ from monai .utils import exact_version , issequenceiterable , optional_import
1922
2023idist , _ = optional_import ("ignite" , "0.4.4" , exact_version , "distributed" )
2124Events , _ = optional_import ("ignite.engine" , "0.4.4" , exact_version , "Events" )
@@ -59,13 +62,17 @@ def __init__(
5962 default to 0.
6063
6164 """
62- self ._expected_rank : bool = idist .get_rank () == save_rank
63- self .saver = CSVSaver (output_dir , filename , overwrite )
65+ self .save_rank = save_rank
66+ self .output_dir = output_dir
67+ self .filename = filename
68+ self .overwrite = overwrite
6469 self .batch_transform = batch_transform
6570 self .output_transform = output_transform
6671
6772 self .logger = logging .getLogger (name )
6873 self ._name = name
74+ self ._outputs : List [torch .Tensor ] = []
75+ self ._filenames : List [str ] = []
6976
7077 def attach (self , engine : Engine ) -> None :
7178 """
@@ -74,10 +81,16 @@ def attach(self, engine: Engine) -> None:
7481 """
7582 if self ._name is None :
7683 self .logger = engine .logger
84+ if not engine .has_event_handler (self ._started , Events .EPOCH_STARTED ):
85+ engine .add_event_handler (Events .EPOCH_STARTED , self ._started )
7786 if not engine .has_event_handler (self , Events .ITERATION_COMPLETED ):
7887 engine .add_event_handler (Events .ITERATION_COMPLETED , self )
79- if self ._expected_rank and not engine .has_event_handler (self .saver .finalize , Events .COMPLETED ):
80- engine .add_event_handler (Events .COMPLETED , lambda engine : self .saver .finalize ())
88+ if not engine .has_event_handler (self ._finalize , Events .EPOCH_COMPLETED ):
89+ engine .add_event_handler (Events .EPOCH_COMPLETED , self ._finalize )
90+
91+ def _started (self , engine : Engine ) -> None :
92+ self ._outputs = []
93+ self ._filenames = []
8194
8295 def __call__ (self , engine : Engine ) -> None :
8396 """
@@ -86,12 +99,39 @@ def __call__(self, engine: Engine) -> None:
8699 Args:
87100 engine: Ignite Engine, it can be a trainer, validator or evaluator.
88101 """
89- _meta_data = self .batch_transform (engine .state .batch )
90- if Key .FILENAME_OR_OBJ in _meta_data :
91- # all gather filenames across ranks, only filenames are necessary
92- _meta_data = {Key .FILENAME_OR_OBJ : string_list_all_gather (_meta_data [Key .FILENAME_OR_OBJ ])}
93- # all gather predictions across ranks
94- _engine_output = evenly_divisible_all_gather (self .output_transform (engine .state .output ))
95-
96- if self ._expected_rank :
97- self .saver .save_batch (_engine_output , _meta_data )
102+ filenames = self .batch_transform (engine .state .batch ).get (Key .FILENAME_OR_OBJ )
103+ if issequenceiterable (filenames ):
104+ self ._filenames .extend (filenames )
105+ outputs = self .output_transform (engine .state .output )
106+ if outputs is not None :
107+ self ._outputs .append (outputs )
108+
109+ def _finalize (self , engine : Engine ) -> None :
110+ """
111+ All gather classification results from ranks and save to CSV file.
112+
113+ Args:
114+ engine: Ignite Engine, it can be a trainer, validator or evaluator.
115+ """
116+ ws = idist .get_world_size ()
117+ if self .save_rank >= ws :
118+ raise ValueError ("target save rank is greater than the distributed group size." )
119+
120+ outputs = torch .cat (self ._outputs , dim = 0 )
121+ filenames = self ._filenames
122+ if ws > 1 :
123+ outputs = evenly_divisible_all_gather (outputs )
124+ filenames = string_list_all_gather (filenames )
125+
126+ if len (filenames ) == 0 :
127+ meta_dict = None
128+ else :
129+ if len (filenames ) != len (outputs ):
130+ warnings .warn (f"filenames length: { len (filenames )} doesn't match outputs length: { len (outputs )} ." )
131+ meta_dict = {Key .FILENAME_OR_OBJ : filenames }
132+
133+ # save to CSV file only in the expected rank
134+ if idist .get_rank () == self .save_rank :
135+ saver = CSVSaver (self .output_dir , self .filename , self .overwrite )
136+ saver .save_batch (outputs , meta_dict )
137+ saver .finalize ()
0 commit comments