Skip to content

Commit dd0f5fc

Browse files
committed
deep copy state dict for checkpoint
Summary: deep copy the state dict for sending checkpoint because if the replica moves to the next step, the state dict can change before the checkpoint is sent
1 parent fef4abc commit dd0f5fc

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

torchft/checkpointing/http_transport.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Generator, List, Optional, TypeVar, cast
1717

1818
import torch
19+
from torch.distributed.tensor import DTensor, distribute_tensor
1920
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
2021

2122
from torchft.checkpointing._rwlock import RWLock
@@ -278,7 +279,13 @@ def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
278279
else:
279280
out.append(v.cpu())
280281
else:
281-
out.append(v)
282+
if isinstance(v, DTensor):
283+
clone = distribute_tensor(
284+
v.to_local().clone(), v.device_mesh, v.placements
285+
)
286+
else:
287+
clone = v.clone()
288+
out.append(clone)
282289
else:
283290
out.append(v)
284291
return out

0 commit comments

Comments
 (0)