Skip to content

Commit 9e1a7d4

Browse files
authored
2095 fix multi-gpu issue in ClassificationSaver (#2096)
* [DLMED] fix classification issue Signed-off-by: Nic Ma <[email protected]>
1 parent 5eb9198 commit 9e1a7d4

File tree

3 files changed

+72
-22
lines changed

3 files changed

+72
-22
lines changed

monai/handlers/classification_saver.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
# limitations under the License.
1111

1212
import logging
13-
from typing import TYPE_CHECKING, Callable, Optional
13+
import warnings
14+
from typing import TYPE_CHECKING, Callable, List, Optional
15+
16+
import torch
1417

1518
from monai.data import CSVSaver
1619
from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather
1720
from monai.utils import ImageMetaKey as Key
18-
from monai.utils import exact_version, optional_import
21+
from monai.utils import exact_version, issequenceiterable, optional_import
1922

2023
idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
2124
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
@@ -59,13 +62,17 @@ def __init__(
5962
default to 0.
6063
6164
"""
62-
self._expected_rank: bool = idist.get_rank() == save_rank
63-
self.saver = CSVSaver(output_dir, filename, overwrite)
65+
self.save_rank = save_rank
66+
self.output_dir = output_dir
67+
self.filename = filename
68+
self.overwrite = overwrite
6469
self.batch_transform = batch_transform
6570
self.output_transform = output_transform
6671

6772
self.logger = logging.getLogger(name)
6873
self._name = name
74+
self._outputs: List[torch.Tensor] = []
75+
self._filenames: List[str] = []
6976

7077
def attach(self, engine: Engine) -> None:
7178
"""
@@ -74,10 +81,16 @@ def attach(self, engine: Engine) -> None:
7481
"""
7582
if self._name is None:
7683
self.logger = engine.logger
84+
if not engine.has_event_handler(self._started, Events.EPOCH_STARTED):
85+
engine.add_event_handler(Events.EPOCH_STARTED, self._started)
7786
if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
7887
engine.add_event_handler(Events.ITERATION_COMPLETED, self)
79-
if self._expected_rank and not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
80-
engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize())
88+
if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED):
89+
engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize)
90+
91+
def _started(self, engine: Engine) -> None:
92+
self._outputs = []
93+
self._filenames = []
8194

8295
def __call__(self, engine: Engine) -> None:
8396
"""
@@ -86,12 +99,39 @@ def __call__(self, engine: Engine) -> None:
8699
Args:
87100
engine: Ignite Engine, it can be a trainer, validator or evaluator.
88101
"""
89-
_meta_data = self.batch_transform(engine.state.batch)
90-
if Key.FILENAME_OR_OBJ in _meta_data:
91-
# all gather filenames across ranks, only filenames are necessary
92-
_meta_data = {Key.FILENAME_OR_OBJ: string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])}
93-
# all gather predictions across ranks
94-
_engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output))
95-
96-
if self._expected_rank:
97-
self.saver.save_batch(_engine_output, _meta_data)
102+
filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ)
103+
if issequenceiterable(filenames):
104+
self._filenames.extend(filenames)
105+
outputs = self.output_transform(engine.state.output)
106+
if outputs is not None:
107+
self._outputs.append(outputs)
108+
109+
def _finalize(self, engine: Engine) -> None:
110+
"""
111+
All gather classification results from ranks and save to CSV file.
112+
113+
Args:
114+
engine: Ignite Engine, it can be a trainer, validator or evaluator.
115+
"""
116+
ws = idist.get_world_size()
117+
if self.save_rank >= ws:
118+
raise ValueError("target save rank is greater than the distributed group size.")
119+
120+
outputs = torch.cat(self._outputs, dim=0)
121+
filenames = self._filenames
122+
if ws > 1:
123+
outputs = evenly_divisible_all_gather(outputs)
124+
filenames = string_list_all_gather(filenames)
125+
126+
if len(filenames) == 0:
127+
meta_dict = None
128+
else:
129+
if len(filenames) != len(outputs):
130+
warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.")
131+
meta_dict = {Key.FILENAME_OR_OBJ: filenames}
132+
133+
# save to CSV file only in the expected rank
134+
if idist.get_rank() == self.save_rank:
135+
saver = CSVSaver(self.output_dir, self.filename, self.overwrite)
136+
saver.save_batch(outputs, meta_dict)
137+
saver.finalize()

monai/handlers/metrics_saver.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from monai.handlers.utils import string_list_all_gather, write_metrics_reports
1515
from monai.utils import ImageMetaKey as Key
16-
from monai.utils import ensure_tuple, exact_version, optional_import
16+
from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import
1717

1818
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
1919
idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
@@ -86,7 +86,7 @@ def attach(self, engine: Engine) -> None:
8686
Args:
8787
engine: Ignite Engine, it can be a trainer, validator or evaluator.
8888
"""
89-
engine.add_event_handler(Events.STARTED, self._started)
89+
engine.add_event_handler(Events.EPOCH_STARTED, self._started)
9090
engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)
9191
engine.add_event_handler(Events.EPOCH_COMPLETED, self)
9292

@@ -95,8 +95,9 @@ def _started(self, engine: Engine) -> None:
9595

9696
def _get_filenames(self, engine: Engine) -> None:
9797
if self.metric_details is not None:
98-
_filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)[Key.FILENAME_OR_OBJ]))
99-
self._filenames += _filenames
98+
filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ)
99+
if issequenceiterable(filenames):
100+
self._filenames.extend(filenames)
100101

101102
def __call__(self, engine: Engine) -> None:
102103
"""
@@ -105,7 +106,7 @@ def __call__(self, engine: Engine) -> None:
105106
"""
106107
ws = idist.get_world_size()
107108
if self.save_rank >= ws:
108-
raise ValueError("target rank is greater than the distributed group size.")
109+
raise ValueError("target save rank is greater than the distributed group size.")
109110

110111
# all gather file names across ranks
111112
_images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames
@@ -123,7 +124,7 @@ def __call__(self, engine: Engine) -> None:
123124

124125
write_metrics_reports(
125126
save_dir=self.save_dir,
126-
images=_images,
127+
images=None if len(_images) == 0 else _images,
127128
metrics=_metrics,
128129
metric_details=_metric_details,
129130
summary_ops=self.summary_ops,

tests/test_handler_classification_saver_dist.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ def _train_func(engine, batch):
4747
"data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))],
4848
}
4949
]
50+
# rank 1 has more iterations
51+
if rank == 1:
52+
data.append(
53+
{
54+
"filename_or_obj": ["testfile" + str(i) for i in range(18, 28)],
55+
"data_shape": [(1, 1) for _ in range(18, 28)],
56+
}
57+
)
58+
5059
engine.run(data, max_epochs=1)
5160
filepath = os.path.join(tempdir, "predictions.csv")
5261
if rank == 1:
@@ -58,7 +67,7 @@ def _train_func(engine, batch):
5867
self.assertEqual(row[0], "testfile" + str(i))
5968
self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)
6069
i += 1
61-
self.assertEqual(i, 18)
70+
self.assertEqual(i, 28)
6271

6372

6473
if __name__ == "__main__":

0 commit comments

Comments
 (0)