Skip to content

Commit 67fcc15

Browse files
sanandaraj5597 (Selvaraj Anandaraj), pre-commit-ci[bot], Selvaraj Anandaraj, and pggPL
authored
Create GPU reload buffers on main stream (#2131)
* Create GPU reload buffers on main stream Signed-off-by: Selvaraj Anandaraj <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj <[email protected]> * Fixed typo Signed-off-by: Selvaraj Anandaraj <[email protected]> --------- Signed-off-by: Selvaraj Anandaraj <[email protected]> Signed-off-by: Selvaraj Anandaraj <[email protected]> Co-authored-by: Selvaraj Anandaraj <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj <[email protected]> Co-authored-by: Paweł Gadziński <[email protected]>
1 parent e0e3d12 commit 67fcc15

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

transformer_engine/pytorch/cpu_offload.py

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -551,17 +551,23 @@ def bulk_reload_group(self, group_to_reload):
551551
buffer_idx = 0
552552
double_buffer_idx = group_to_reload % 2
553553

554+
main_stream = torch.cuda.current_stream()
555+
554556
with torch.cuda.stream(self.h2d_stream):
555557
# move back tensors
556558
for tensor_label, state in self.tensor_tag_to_state.items():
557559
group_id, _ = tensor_label
558560
if group_id == group_to_reload:
559-
if self.double_buffering:
560-
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
561-
else:
562-
reload_buffer = None
563561

564562
if isinstance(state, tuple):
563+
if self.double_buffering:
564+
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
565+
else:
566+
with torch.cuda.stream(main_stream):
567+
reload_buffer = torch.empty_like(
568+
state[1], device=torch.cuda.current_device()
569+
)
570+
565571
recovered_tensor = SynchronizedGroupOffloadHandler.reload(
566572
state, True, reload_buffer
567573
)
@@ -570,14 +576,18 @@ def bulk_reload_group(self, group_to_reload):
570576
elif isinstance(state, list):
571577
tensor_list = []
572578
for state_tuple in state:
573-
if self.double_buffering:
574-
reload_buffer = self.reload_double_buffer[double_buffer_idx][
575-
buffer_idx
576-
]
577-
else:
578-
reload_buffer = None
579579

580580
if isinstance(state_tuple, tuple):
581+
if self.double_buffering:
582+
reload_buffer = self.reload_double_buffer[double_buffer_idx][
583+
buffer_idx
584+
]
585+
else:
586+
with torch.cuda.stream(main_stream):
587+
reload_buffer = torch.empty_like(
588+
state_tuple[1], device=torch.cuda.current_device()
589+
)
590+
581591
tensor_list.append(
582592
SynchronizedGroupOffloadHandler.reload(
583593
state_tuple,

0 commit comments

Comments
 (0)