diff --git a/torchsnapshot/io_preparer.py b/torchsnapshot/io_preparer.py index f837701..e4d9e1f 100644 --- a/torchsnapshot/io_preparer.py +++ b/torchsnapshot/io_preparer.py @@ -97,14 +97,7 @@ def prepare_write( return entry, [] storage_path = get_storage_path(obj, logical_path, rank, replicated) - if isinstance(obj, ShardedTensor): - return ShardedTensorIOPreparer.prepare_write( - storage_path=storage_path, - obj=obj, - is_async_snapshot=is_async_snapshot, - _tensor_prepare_func=_tensor_prepare_func, - ) - elif isinstance(obj, torch.Tensor): + if isinstance(obj, torch.Tensor): if obj.nelement() * obj.element_size() > get_max_chunk_size_bytes(): chunking_instruction = ChunkedTensorIOPreparer.chunk_tensor(obj) entry, obj_write_req = ChunkedTensorIOPreparer.prepare_write(