From fc3b3c0a332a1671a0a9b047d38d84e937480cfa Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Sun, 25 Jun 2023 15:36:56 +0800 Subject: [PATCH 01/24] exclude xpu --- paddle/phi/kernels/kps/elementwise_add_kernel.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/phi/kernels/kps/elementwise_add_kernel.cu b/paddle/phi/kernels/kps/elementwise_add_kernel.cu index 31d1e4e32cf54b..b3fe46a1cd3100 100644 --- a/paddle/phi/kernels/kps/elementwise_add_kernel.cu +++ b/paddle/phi/kernels/kps/elementwise_add_kernel.cu @@ -72,14 +72,18 @@ void AddKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* out) { +#ifdef PADDLE_WITH_CUDA if (x.dtype() == phi::DataType::FLOAT32 && (y.dtype() == phi::DataType::BFLOAT16 || y.dtype() == phi::DataType::FLOAT16)) { using Type = DataTypeToCppType::type; Float32Bfloat16OrFloat16AddCudaFunctor(dev_ctx, x, y, out); } else { +#endif AddCudaFunctor(dev_ctx, x, y, -1, out); +#ifdef PADDLE_WITH_CUDA } +#endif } template From 738f5d59be1d7c5505269094274cd7254fa1a204 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Thu, 23 Nov 2023 17:22:54 +0800 Subject: [PATCH 02/24] demo of running dygraph distributed save load --- .../distributed/checkpoint/broadcast_test.py | 2 + .../distributed/checkpoint/hang_test.py | 53 +++ .../distributed/checkpoint/load_state_dict.py | 322 ++++++++++++++++++ .../paddle/distributed/checkpoint/metadata.py | 33 ++ .../distributed/checkpoint/output/0.metadata | Bin 0 -> 371 bytes .../distributed/checkpoint/output/0_0.distcp | Bin 0 -> 259 bytes .../distributed/checkpoint/output/1_0.distcp | Bin 0 -> 259 bytes .../distributed/checkpoint/output2/0.metadata | Bin 0 -> 371 bytes .../distributed/checkpoint/output2/0_0.distcp | Bin 0 -> 259 bytes .../distributed/checkpoint/save_state_dict.py | 142 ++++++++ .../distributed/checkpoint/test_hang_test.py | 3 + .../checkpoint/test_load_state_dict.sh | 3 + .../checkpoint/test_save_state_dict.sh | 2 + .../distributed/checkpoint/test_utils.sh | 2 + python/paddle/distributed/checkpoint/utils.py | 79 +++++ 15 files changed, 641 insertions(+) create mode 100644 python/paddle/distributed/checkpoint/broadcast_test.py create mode 100644 python/paddle/distributed/checkpoint/hang_test.py create mode 100644 python/paddle/distributed/checkpoint/load_state_dict.py create mode 100644 python/paddle/distributed/checkpoint/metadata.py create mode 100644 python/paddle/distributed/checkpoint/output/0.metadata create mode 100644 python/paddle/distributed/checkpoint/output/0_0.distcp create mode 100644 python/paddle/distributed/checkpoint/output/1_0.distcp create mode 100644 python/paddle/distributed/checkpoint/output2/0.metadata create mode 100644 python/paddle/distributed/checkpoint/output2/0_0.distcp create mode 100644 python/paddle/distributed/checkpoint/save_state_dict.py create mode 100644 python/paddle/distributed/checkpoint/test_hang_test.py create mode 100644 python/paddle/distributed/checkpoint/test_load_state_dict.sh create mode 100644 python/paddle/distributed/checkpoint/test_save_state_dict.sh create mode 100644 python/paddle/distributed/checkpoint/test_utils.sh create mode 100644 python/paddle/distributed/checkpoint/utils.py diff --git a/python/paddle/distributed/checkpoint/broadcast_test.py b/python/paddle/distributed/checkpoint/broadcast_test.py new file mode 100644 index 00000000000000..de5bd9f7d44912 --- /dev/null +++ b/python/paddle/distributed/checkpoint/broadcast_test.py @@ -0,0 +1,2 @@ +import paddle + diff --git a/python/paddle/distributed/checkpoint/hang_test.py 
b/python/paddle/distributed/checkpoint/hang_test.py new file mode 100644 index 00000000000000..4f7fbd3f7bce22 --- /dev/null +++ b/python/paddle/distributed/checkpoint/hang_test.py @@ -0,0 +1,53 @@ +import pysnooper + +import paddle +from paddle.distributed.communication.group import is_initialized + +@pysnooper.snoop(output=f"snooper{paddle.distributed.get_rank()}.log", depth=1, max_variable_length=200) +def get_read_items(path, state_dict, process_group): + print(f"pure hang test", flush=True) + # for param_name, val in state_dict.items(): + for param_name, val in enumerate(range(2)): + if True or isinstance(val, paddle.Tensor): + print(f"before val:{val}, type:{type(val)}", flush=True) + if True or val.is_dist(): + paddle.distributed.barrier() + # pass + # local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) + # cur_chunk_metadata = ChunkMetadata(local_shape, global_offset) + # assert param_name in param_to_chunkmetadata, f"param_name:{param_name} not found in param_to_chunkmetadata:{param_to_chunkmetadata}." + # for storage_chunk_metadata in param_to_chunkmetadata[param_name]: + for storage_chunk_metadata in range(2): + print(f"rank:{paddle.distributed.get_rank()}, storage_chunk_metadata:{storage_chunk_metadata}", flush=True) + # paddle.distributed.barrier() + print(f"param_name:{param_name}, storage_chunk_metadata:{storage_chunk_metadata}") + if paddle.distributed.get_rank() == 0 or paddle.distributed.get_rank() == 1: + continue + else: + continue + else: + print(f"val:{val}, type:{type(val)}") + pass + else: + pass + return + +def main(): + path = "./output" + ###!!! Init the Disttensor and turn on the pysnooper at the same time will lead to hang !!! + + # import paddle.distributed as dist + # w1 = paddle.arange(8).reshape([4, 2]) + # w2 = paddle.arange(8, 12).reshape([2, 2]) + # mesh = dist.ProcessMesh([0,1,2,3], dim_names=["x"]) + # w1_dist_attr = dist.DistAttr(mesh, sharding_specs=["x", None]) + # sharded_w1 = dist.shard_tensor(w1, dist_attr=w1_dist_attr) + # w2_dist_attr = dist.DistAttr(mesh, sharding_specs=[None, None]) + # sharded_w2 = dist.shard_tensor(w2, dist_attr=w2_dist_attr) + # state_dict = {"w1": sharded_w1, "w2": sharded_w2} + + not is_initialized() and paddle.distributed.init_parallel_env() + get_read_items(path, None, None) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py new file mode 100644 index 00000000000000..6b087a16278f81 --- /dev/null +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -0,0 +1,322 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
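Taken together, these new files implement a resharding-capable checkpoint: `save_state_dict` writes one `<rank>_<unique_id>.distcp` data file per rank plus a shared `<unique_id>.metadata` index, and `load_state_dict` reads the chunks back in place into an existing (possibly differently sharded) state_dict. A minimal end-to-end sketch is below; it mirrors the test helpers in this series, the file name `demo.py` and the concrete shardings are illustrative only, and the `paddle.distributed` re-exports it imports are only added by the later "test save load in dp_mp unittest" commit:

```python
# demo.py -- hypothetical driver, launched the same way as the test_*.sh scripts:
#   python -u -m paddle.distributed.launch --log_dir ./log --devices 0,1 demo.py
import paddle
import paddle.distributed as dist

# These re-exports are added by a later commit in this series; before that,
# the functions live in paddle.distributed.checkpoint.save_state_dict / load_state_dict.
from paddle.distributed import save_state_dict, load_state_dict

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Save: w1 is row-sharded across the two ranks, w2 is replicated.
w1 = paddle.arange(8).reshape([4, 2])
w2 = paddle.arange(8, 12).reshape([2, 2])
sharded_w1 = dist.shard_tensor(w1, dist_attr=dist.DistAttr(mesh, sharding_specs=["x", None]))
sharded_w2 = dist.shard_tensor(w2, dist_attr=dist.DistAttr(mesh, sharding_specs=[None, None]))
save_state_dict({"w1": sharded_w1, "w2": sharded_w2}, "./output")

# Load in place; the target sharding may differ from the saved one, in which
# case the loader re-slices the saved chunks and broadcasts them to their owners.
target = {
    "w1": dist.shard_tensor(paddle.zeros([4, 2], dtype=paddle.int64),
                            dist_attr=dist.DistAttr(mesh, sharding_specs=[None, None])),
    "w2": dist.shard_tensor(paddle.zeros([2, 2], dtype=paddle.int64),
                            dist_attr=dist.DistAttr(mesh, sharding_specs=["x", None])),
}
load_state_dict(target, "./output")
```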
+ +import os +from dataclasses import dataclass +from typing import List, Tuple + +import paddle +from paddle.distributed.communication.group import is_initialized + +from metadata import Metadata, ChunkMetadata, MetadataIndex +from utils import compute_local_shape_and_global_offset + +@dataclass(frozen=True) +class ReadItem: + rank:int + meta_index:MetadataIndex + cur_offset:Tuple[int] + storage_offset:Tuple[int] + lengths:Tuple[int] + + +def get_local_load_files(path, state_dict, process_group): + # step 1, get neccesary files to be read + accessible_files = os.listdir(path) + metadata_files = [file for file in accessible_files if file.endswith(".metadata")] + assert len(metadata_files) > 0, "No metadata file found in the checkpoint directory:{path}." + # The neccesary files to be read + necessary_files = [] + for metadata_file in metadata_files: + metadata = paddle.load(os.path.join(path, metadata_file)) + for metadata_index, file_name in metadata.storage_metadata.items(): + if metadata_index.param in state_dict: + necessary_files.append(file_name) + necessary_files_set = set(necessary_files) + # allgather all accessible files + local_data_files = [file for file in accessible_files if file.endswith(".distcp")] + global_data_files = [] + paddle.distributed.all_gather_object(global_data_files, local_data_files, process_group) + tmp = [] + for files in global_data_files: + tmp += files + global_data_files_set = set(tmp) + print(f"necessary_files_set:{necessary_files_set}, global_data_files_set:{global_data_files_set}") + # check neccesary files in global_data_files + assert global_data_files_set & necessary_files_set == necessary_files_set, \ + f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_files_set:{necessary_files_set}" + # step 2, get mapping between ranks and local files + rank_to_files = {} + file_to_ranks = {} + for rank, local_files in enumerate(global_data_files): + if len(local_files) > 0: + local_files = [f for f in local_files if f in necessary_files_set] + rank_to_files[rank] = local_files + for file in local_files: + if file not in file_to_ranks: + file_to_ranks[file] = [] + file_to_ranks[file].append(rank) + print(f"mapping rank_to_files:{rank_to_files}, file_to_ranks:{file_to_ranks}") + rank_to_read_files = {rank:[] for rank in rank_to_files.keys()} + for file, ranks in file_to_ranks.items(): + if len(ranks) == 1: + rank = ranks[0] + rank_to_read_files[rank].append(file) + rank_to_files[rank].remove(file) + if len(rank_to_files[rank]) == 0: + rank_to_files.pop(rank) + + print(f"start rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}") + # step 3, update the rank_to_read_files + def get_least_read_files_ranks(rank_to_read_files): + nums = [(rank, len(files)) for rank, files in rank_to_read_files.items()] + sorted(nums, key=lambda x: x[1]) + ranks = [rank for rank, num in nums if num == nums[0][0]] + return ranks + def get_read_rank_file(rank_to_files, ranks): + if len(rank_to_files) == 0: + return (None, None) + nums = [(rank, len(files)) for rank, files in rank_to_files.items() if rank in ranks] + sorted(nums, key=lambda x: x[1]) + rank = nums[0][0] + return (rank, rank_to_files[rank][0]) + def update(rank_to_read_files, rank_to_files, rank_file): + rank, file = rank_file + if rank is None and file is None: + return + if rank not in rank_to_read_files: + rank_to_read_files[rank] = [] + rank_to_read_files[rank].append(file) + # update rank_to_files + 
file_to_ranks = {} + for r, files in rank_to_files.items(): + for f in files: + if f not in file_to_ranks: + file_to_ranks[f] = [] + file_to_ranks[f].append(r) + print(f"file_to_ranks:{file_to_ranks}") + if file in file_to_ranks: + for r in file_to_ranks[file]: + rank_to_files[r].remove(file) + if len(rank_to_files[r]) == 0: + rank_to_files.pop(r) + # step 4, get final rank_to_read_files + while len(rank_to_files) > 0: + ranks = get_least_read_files_ranks(rank_to_read_files) + rank_file = get_read_rank_file(rank_to_files, ranks) + update(rank_to_read_files, rank_to_files, rank_file) + print(f"update rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}, ranks:{ranks}, rank_file:{rank_file}") + print(f"rank_to_read_files:{rank_to_read_files}") + cur_rank = paddle.distributed.get_rank() + if cur_rank in rank_to_read_files: + print(f"cur_rank:{cur_rank}, rank_to_read_files[cur_rank]:{rank_to_read_files[cur_rank]}") + return rank_to_read_files[cur_rank] + else: + print(f"rank:{cur_rank} does not need to load checkpoint") + return [] + + +def get_load_infos(path, local_load_files, process_group): + load_info = {} + accessible_files = os.listdir(path) + metadata_files = [file for file in accessible_files if file.endswith(".metadata")] + assert len(metadata_files) > 0, "No metadata file found in the checkpoint directory:{path}." + for metadata_file in metadata_files: + metadata = paddle.load(os.path.join(path, metadata_file)) + for meta_index, file_name in metadata.storage_metadata.items(): + if file_name in local_load_files: + load_info[meta_index] = (paddle.distributed.get_rank(), file_name) + load_info_list = [] + paddle.distributed.all_gather_object(load_info_list, load_info, process_group) + load_infos = {} + for load_info in load_info_list: + for meta_index, (rank, file_name) in load_info.items(): + assert meta_index not in load_infos + load_infos[meta_index] = (rank, file_name) + return load_infos + + +def compute_overlap(cur_chunk_metadata:ChunkMetadata, storage_chunk_metadata:ChunkMetadata): + cur_offsets = [] + storage_offsets = [] + lengths = [] + for cur_len, cur_offset, strorage_len, storage_offset in zip( + cur_chunk_metadata.local_shape, + cur_chunk_metadata.global_offset, + storage_chunk_metadata.local_shape, + storage_chunk_metadata.global_offset + ): + begin_offset = max(cur_offset, storage_offset) + end_offset = min(cur_offset + cur_len, storage_offset + strorage_len) + if begin_offset == cur_offset: + cur_offsets.append(0) + storage_offsets.append(begin_offset - storage_offset) + elif begin_offset == storage_offset: + cur_offsets.append(end_offset - cur_offset) + storage_offsets.append(0) + else: + assert False, "Should not reach here." 
+ lengths.append(end_offset - begin_offset) + assert lengths[-1] > 0, f"Invalid length:{lengths[-1]}, end_offset:{end_offset}, begin_offset:{begin_offset}" + return cur_offsets, storage_offsets, lengths + + +def not_overlap(cur_chunk_metadata:ChunkMetadata, storage_chunk_metadata:ChunkMetadata): + for cur_len, cur_offset, strorage_len, storage_offset in zip( + cur_chunk_metadata.local_shape, + cur_chunk_metadata.global_offset, + storage_chunk_metadata.local_shape, + storage_chunk_metadata.global_offset + ): + if cur_offset >= (storage_offset + strorage_len) or (cur_offset + cur_len) <= storage_offset: + return True + return False + +def get_read_items(path, state_dict, process_group): + accessible_files = os.listdir(path) + metadata_files = [file for file in accessible_files if file.endswith(".metadata")] + assert len(metadata_files) > 0, "No metadata file found in the checkpoint directory:{path}." + param_to_chunkmetadata = {} + for metadata_file in metadata_files: + metadata = paddle.load(os.path.join(path, metadata_file)) + for param_name, chunk_metadata in metadata.state_dict_metadata.items(): + if param_name not in param_to_chunkmetadata: + param_to_chunkmetadata[param_name] = [] + param_to_chunkmetadata[param_name] += chunk_metadata + read_items = [] + print(f"param_to_chunkmetadata:{param_to_chunkmetadata}\n state_dict:{state_dict}") + for param_name, val in state_dict.items(): + if isinstance(val, paddle.Tensor): + if val.is_dist(): + local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) + cur_chunk_metadata = ChunkMetadata(local_shape, global_offset) + assert param_name in param_to_chunkmetadata, f"param_name:{param_name} not found in param_to_chunkmetadata:{param_to_chunkmetadata}." + for storage_chunk_metadata in param_to_chunkmetadata[param_name]: + if not_overlap(cur_chunk_metadata, storage_chunk_metadata): + continue + cur_offsets, storage_offsets, lengths = compute_overlap(cur_chunk_metadata, storage_chunk_metadata) + storage_meta_index = MetadataIndex(param_name, tuple(storage_chunk_metadata.global_offset)) + read_items.append(ReadItem(paddle.distributed.get_rank(), storage_meta_index, tuple(cur_offsets), tuple(storage_offsets), tuple(lengths))) + else: + assert False, f"Only support distributed tensor., val type:{type(val)}" + else: + assert False, f"Only support paddle.Tensor., val type:{type(val)}" + global_read_items = [] + tmp = [] + paddle.distributed.all_gather_object(tmp, read_items, process_group) + for items in tmp: + for item in items: + global_read_items.append(item) + return global_read_items + +def flatten_state_dict(state_dict): + # TODO, {"model": {"w0": xxx}} -> {model.w0: xxx} + return state_dict + +def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, use_dist=True) -> None: + """ + Load the state_dict inplace from a checkpoint path. + Args: + state_dict: The state_dict to load. It will be modified inplace after loading. + path: The directory to load checkpoint files. + process_group: ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. + coordinator_rank: The rank used to coordinate the checkpoint. Rank0 is used by default. + use_dict: Whether to load the state_dict in distributed mode. Set True by default. + Example: + .. code-block:: python + import paddle + ... 
+ """ + if process_group is None: + not is_initialized() and paddle.distributed.init_parallel_env() + process_group = paddle.distributed.new_group(list(range(paddle.distributed.ParallelEnv().nranks)), backend="nccl") + + state_dict = flatten_state_dict(state_dict) + local_load_files = get_local_load_files(path, state_dict, process_group) + # load_infos: {MetaIndex: (rank, file_name)} + load_infos = get_load_infos(path, local_load_files, process_group) + read_items = get_read_items(path, state_dict, process_group) + loaded_state_dict = {} + print(f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}") + # return + for item in read_items: + assert item.meta_index in load_infos, f"item:{item}, load_infos:{load_infos}" + src_rank, file_name = load_infos[item.meta_index] + if src_rank == paddle.distributed.get_rank(): + if file_name not in loaded_state_dict: + # The load state_dict is not distributed tensor but a normal tensor. + loaded_state_dict[file_name] = paddle.load(os.path.join(path, file_name)) + storage_state_dict = loaded_state_dict[file_name] + assert item.meta_index.param in storage_state_dict + storage_local_tensor = storage_state_dict[item.meta_index.param] + storage_offsets = item.storage_offset + storage_lengths = item.lengths + storage_ends = [storage_offset + storage_length for storage_offset, storage_length in zip(storage_offsets, storage_lengths)] + # storage_chunk_tensor = paddle.cast(paddle.slice(storage_local_tensor, list(range(len(storage_lengths))), storage_offsets, storage_ends), paddle.float32) + storage_chunk_tensor = paddle.slice(storage_local_tensor, list(range(len(storage_lengths))), storage_offsets, storage_ends) + print(f"src_ran:{src_rank}, item.rank:{item.rank}, process_group:{process_group}, storage_local_tensor:{storage_local_tensor}, storage_chunk_tensor:{storage_chunk_tensor}") + paddle.distributed.broadcast(storage_chunk_tensor, src=src_rank, group=process_group) + if src_rank == item.rank: + cur_local_tensor = state_dict[item.meta_index.param]._local_value() + cur_offsets = item.cur_offset + cur_lengths = item.lengths + cur_ends = [cur_offset + cur_length for cur_offset, cur_length in zip(cur_offsets, cur_lengths)] + cur_sub_chunk_tensor = paddle.slice(cur_local_tensor, list(range(len(cur_lengths))), cur_offsets, cur_ends) + paddle.assign(storage_chunk_tensor, cur_sub_chunk_tensor) + + elif item.rank == paddle.distributed.get_rank(): + assert item.meta_index.param in state_dict, f"item:{item}, state_dict:{state_dict}" + cur_local_tensor = state_dict[item.meta_index.param]._local_value() + cur_offsets = item.cur_offset + cur_lengths = item.lengths + cur_ends = [cur_offset + cur_length for cur_offset, cur_length in zip(cur_offsets, cur_lengths)] + cur_sub_chunk_tensor = paddle.slice(cur_local_tensor, list(range(len(cur_lengths))), cur_offsets, cur_ends) + print(f"cur_sub_chunk_tensor :{cur_sub_chunk_tensor}, cur_local_tensor:{cur_local_tensor}") + paddle.distributed.broadcast(cur_sub_chunk_tensor, src=src_rank, group=process_group) + print(f"src_rank:{src_rank}, item.rank:{item.rank}, process_group:{process_group}, cur_sub_chunk_tensor:{cur_sub_chunk_tensor}") + else: + dummy_tensor = paddle.zeros(item.lengths, dtype=state_dict[item.meta_index.param].dtype) + print(f"dummy_tensor:{dummy_tensor}") + paddle.distributed.broadcast(dummy_tensor, src=src_rank, group=process_group) + print(f"src_rank:{src_rank}, item.rank:{item.rank}, process_group:{process_group}, dummy_tensor:{dummy_tensor}") + # break + print(f"after 
load, state_dict:{state_dict}") + + +def test_get_local_load_files(): + if paddle.distributed.get_rank() == 0: + path = "./output" + else: + path = "./output2" + # path = "./output" + # build state_dict + import paddle.distributed as dist + w1 = paddle.zeros([4,2], dtype=paddle.int64) + w2 = paddle.zeros([2,2], dtype=paddle.int64) + mesh = dist.ProcessMesh([0,1,2,3], dim_names=["x"]) + w1_dist_attr = dist.DistAttr(mesh, sharding_specs=["x", None]) + sharded_w1 = dist.shard_tensor(w1, dist_attr=w1_dist_attr) + w2_dist_attr = dist.DistAttr(mesh, sharding_specs=[None, None]) + sharded_w2 = dist.shard_tensor(w2, dist_attr=w2_dist_attr) + state_dict = {"w1": sharded_w1, "w2": sharded_w2} + load_state_dict(state_dict, path) + + + + +def test_load_state_dict(): + test_get_local_load_files() + +if __name__ == "__main__": + test_load_state_dict() diff --git a/python/paddle/distributed/checkpoint/metadata.py b/python/paddle/distributed/checkpoint/metadata.py new file mode 100644 index 00000000000000..f46dcf671c95b9 --- /dev/null +++ b/python/paddle/distributed/checkpoint/metadata.py @@ -0,0 +1,33 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple, Dict, Optional +from dataclasses import dataclass + + + +@dataclass +class ChunkMetadata: + local_shape: List[int] + global_offset: List[int] + +@dataclass(frozen=True) +class MetadataIndex: + param: str + global_offset: Tuple[int] + +@dataclass +class Metadata: + state_dict_metadata: Dict[str, List[ChunkMetadata]] = None + storage_metadata: Dict[MetadataIndex, str] = None \ No newline at end of file diff --git a/python/paddle/distributed/checkpoint/output/0.metadata b/python/paddle/distributed/checkpoint/output/0.metadata new file mode 100644 index 0000000000000000000000000000000000000000..09caecdb1b76583760ae3bde8d85cbff4eea9cd3 GIT binary patch literal 371 zcmZ{fu?oUK42Ek_5Tz)%3xb0#VtoZCDL#Ufw2elrws?2wBIx3DoAG_Uw$>tIrjY;3 zpFh|7lWJC8y17RQ5lCR&T$b&F$2Y$4NcV&UiGH*VrT+ml!9KaX;t^TenUgx~eplOZ z5<=_?f)q2o%XQ2wHRh%TJWT{hhnTqD%%+UX+!UL4<3Yi>X zgI{U0lx4uhA;L3%9&Ds=ggxi~P{%VqcUFOJP5tR4+6oz_O#OGVK^~%_X!s?$Y$Urh Ka4Q3)g2@X36@xPX literal 0 HcmV?d00001 diff --git a/python/paddle/distributed/checkpoint/output/0_0.distcp b/python/paddle/distributed/checkpoint/output/0_0.distcp new file mode 100644 index 0000000000000000000000000000000000000000..147bdc313b3372f3629f84f60705d79c84959730 GIT binary patch literal 259 zcmZo*nfikP0&1sd^e~khPU#WNE6pva)Jx7UO4Z9P%_+%DEGkN@oYKP+UzD1hpI2N` zRGM5eW%86BRb3titBtuH|0Aa68761SM literal 0 HcmV?d00001 diff --git a/python/paddle/distributed/checkpoint/output/1_0.distcp b/python/paddle/distributed/checkpoint/output/1_0.distcp new file mode 100644 index 0000000000000000000000000000000000000000..00693745275c1e4b50a7081799f27131283323e2 GIT binary patch literal 259 zcmZo*nfikP0&1sd^e~khPU#WNE6pva)Jx7UO4Z9P%_+%DEGkN@oYKP+UzD1hpI2N` zRGM5eW%86BR;DH$vo c>>x`sI8m(>1RBBtRm=&cxu7&Rk|Cvf0B4^}CIA2c literal 0 HcmV?d00001 diff --git 
a/python/paddle/distributed/checkpoint/output2/0.metadata b/python/paddle/distributed/checkpoint/output2/0.metadata new file mode 100644 index 0000000000000000000000000000000000000000..09caecdb1b76583760ae3bde8d85cbff4eea9cd3 GIT binary patch literal 371 zcmZ{fu?oUK42Ek_5Tz)%3xb0#VtoZCDL#Ufw2elrws?2wBIx3DoAG_Uw$>tIrjY;3 zpFh|7lWJC8y17RQ5lCR&T$b&F$2Y$4NcV&UiGH*VrT+ml!9KaX;t^TenUgx~eplOZ z5<=_?f)q2o%XQ2wHRh%TJWT{hhnTqD%%+UX+!UL4<3Yi>X zgI{U0lx4uhA;L3%9&Ds=ggxi~P{%VqcUFOJP5tR4+6oz_O#OGVK^~%_X!s?$Y$Urh Ka4Q3)g2@X36@xPX literal 0 HcmV?d00001 diff --git a/python/paddle/distributed/checkpoint/output2/0_0.distcp b/python/paddle/distributed/checkpoint/output2/0_0.distcp new file mode 100644 index 0000000000000000000000000000000000000000..147bdc313b3372f3629f84f60705d79c84959730 GIT binary patch literal 259 zcmZo*nfikP0&1sd^e~khPU#WNE6pva)Jx7UO4Z9P%_+%DEGkN@oYKP+UzD1hpI2N` zRGM5eW%86BRb3titBtuH|0Aa68761SM literal 0 HcmV?d00001 diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py new file mode 100644 index 00000000000000..2059cca4a4b1ee --- /dev/null +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -0,0 +1,142 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Dict +import numpy as np + +import paddle +from paddle.distributed.communication.group import is_initialized +from metadata import Metadata, ChunkMetadata, MetadataIndex +from utils import merge_state_dict, dedup_state_dict, compute_local_shape_and_global_offset + +def check_state_dict(state_dict, process_group): + local_keys = list(state_dict.keys()) + gloabl_keys = [] + paddle.distributed.all_gather_object(gloabl_keys, local_keys, process_group) + for keys in gloabl_keys[1:]: + assert keys == gloabl_keys[0], f"keys:{keys} != first_keys: {gloabl_keys[0]}" + +def check_file_name(file_name, process_group): + all_unique_id = [] + unique_id = int(file_name.split(".")[0].split("_")[1]) + paddle.distributed.all_gather_object(all_unique_id, unique_id, process_group) + for id in all_unique_id[1:]: + assert id == all_unique_id[0], f"id:{id} != all_unique_id[0]:{file_name}" + +# def merge_state_dict(global_state_dict): +# assert isinstance(global_state_dict, List), "The global_state_dict should be a list." +# out = {} +# for state_dict in global_state_dict: +# for key, val in state_dict.items(): +# if key in out and val not in out[key]: +# out[key].append(val) +# else: +# out[key] = [val] +# return out + +# def dedup_state_dict(global_state_dict): +# out = {} +# for state_dict in global_state_dict: +# for key, val in state_dict.items(): +# if key in out: +# continue +# out[key] = val +# return out + +def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, use_dist=True) -> None: + """ + Save the state_dict of model to path. + + Args: + state_dict: The state_dict to save. + path: The directory to save state_dict. 
+ process_group: ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. + coordinator_rank: The rank used to coordinate the checkpoint. Rank0 is used by default. + use_dist: Whether to save the state_dict in distributed mode. Set True by default. + + Examples: + .. code-block:: python + + import paddle + ... + + """ + assert isinstance(state_dict, dict), "The state_dict should be a dictionary." + if len(state_dict) > 0: + for val in state_dict.values(): + assert isinstance(val, (paddle.Tensor, paddle.base.framework.EagerParamBase)), "Only support dygraph Tensor now, support static DistributedTensor later" + # + if process_group is None: + not is_initialized() and paddle.distributed.init_parallel_env() + process_group = paddle.distributed.new_group(list(range(paddle.distributed.ParallelEnv().nranks)), backend="nccl") + # calculate (global offset, local shape) of each DTensor + local_state_dict = {} + metadata = Metadata() + unique_id = 0 + file_name = "" + while(True): + file_name = f"{paddle.distributed.get_rank()}_{unique_id}.distcp" + if not os.path.exists(os.path.join(path, file_name)): + break + unique_id += 1 + print(f"file_name:{file_name}") + check_file_name(file_name, process_group) + # the parameter_name and order in state_dict should be the same + check_state_dict(state_dict, process_group) + local_chunk_metadata = {} + local_storage_metadata = {} + for key, val in state_dict.items(): + if isinstance(val, paddle.Tensor): + if val.is_dist(): + local_tensor = val.get_tensor().get_tensor() + local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) + # gather local_shape and global_offset from all ranks of each parameter + local_chunk_metadata[key] = ChunkMetadata(local_shape, global_offset) + local_storage_metadata[MetadataIndex(key, tuple(global_offset))] = file_name + else: + local_tensor = val + local_state_dict[key] = local_tensor + global_chunk_metadata = [] + global_storage_metadata = [] + paddle.distributed.all_gather_object(global_chunk_metadata, local_chunk_metadata, process_group) + paddle.distributed.all_gather_object(global_storage_metadata, local_storage_metadata, process_group) + metadata.state_dict_metadata = merge_state_dict(global_chunk_metadata) + metadata.storage_metadata = dedup_state_dict(global_storage_metadata) + if coordinator_rank == paddle.distributed.get_rank(): + print(f"metadata:{metadata}") + paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata")) + print(f"local_state_dict:{local_state_dict}") + for k,v in local_state_dict.items(): + # the phi::DenseTensor only support convert to np.array + local_state_dict[k] = np.array(v) + print(f"local_state_dict name:{k}, val:{local_state_dict[k]}, type:{type(local_state_dict[k])}") + paddle.save(local_state_dict, os.path.join(path, file_name)) + + + +def test_save_state_dict(): + import paddle.distributed as dist + w1 = paddle.arange(8).reshape([4, 2]) + w2 = paddle.arange(8, 12).reshape([2, 2]) + mesh = dist.ProcessMesh([0,1], dim_names=["x"]) + w1_dist_attr = dist.DistAttr(mesh, sharding_specs=["x", None]) + sharded_w1 = dist.shard_tensor(w1, dist_attr=w1_dist_attr) + w2_dist_attr = dist.DistAttr(mesh, sharding_specs=[None, None]) + sharded_w2 = dist.shard_tensor(w2, dist_attr=w2_dist_attr) + state_dict = {"w1": sharded_w1, "w2": sharded_w2} + save_state_dict(state_dict, "./output") + +if __name__ == "__main__": + test_save_state_dict() diff --git 
a/python/paddle/distributed/checkpoint/test_hang_test.py b/python/paddle/distributed/checkpoint/test_hang_test.py new file mode 100644 index 00000000000000..07556ce8e2afd5 --- /dev/null +++ b/python/paddle/distributed/checkpoint/test_hang_test.py @@ -0,0 +1,3 @@ +rm -rf log/* +rm -f snooper* +python -u -m paddle.distributed.launch --log_dir ./log --devices 0,1,2,3 test.py diff --git a/python/paddle/distributed/checkpoint/test_load_state_dict.sh b/python/paddle/distributed/checkpoint/test_load_state_dict.sh new file mode 100644 index 00000000000000..100e0e66500459 --- /dev/null +++ b/python/paddle/distributed/checkpoint/test_load_state_dict.sh @@ -0,0 +1,3 @@ +rm -rf log/* +rm -f snooper* +python -u -m paddle.distributed.launch --log_dir ./log --devices 1,2,3,4 load_state_dict.py diff --git a/python/paddle/distributed/checkpoint/test_save_state_dict.sh b/python/paddle/distributed/checkpoint/test_save_state_dict.sh new file mode 100644 index 00000000000000..8e4dadc4c2095f --- /dev/null +++ b/python/paddle/distributed/checkpoint/test_save_state_dict.sh @@ -0,0 +1,2 @@ +rm -rf log/* +python -u -m paddle.distributed.launch --log_dir ./log --devices 0,1 save_state_dict.py diff --git a/python/paddle/distributed/checkpoint/test_utils.sh b/python/paddle/distributed/checkpoint/test_utils.sh new file mode 100644 index 00000000000000..03ce3c88325122 --- /dev/null +++ b/python/paddle/distributed/checkpoint/test_utils.sh @@ -0,0 +1,2 @@ +rm -rf log/* +python -u -m paddle.distributed.launch --log_dir ./log --devices 0,1,2,3 utils.py diff --git a/python/paddle/distributed/checkpoint/utils.py b/python/paddle/distributed/checkpoint/utils.py new file mode 100644 index 00000000000000..7482dfca0f48d6 --- /dev/null +++ b/python/paddle/distributed/checkpoint/utils.py @@ -0,0 +1,79 @@ +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import List, Optional + +import numpy as np +import paddle +from paddle.framework import core + +def get_coordinator(mesh:np.array, rank:int): + mesh = paddle.to_tensor(mesh) + rand_coordinator = (mesh == rank).nonzero() + assert rand_coordinator.shape[0] in (0, 1), f"rand_coordinator.shape: {rand_coordinator.shape}" + return rand_coordinator[0].tolist() if rand_coordinator.shape[0] > 0 else None + +def merge_state_dict(global_state_dict): + assert isinstance(global_state_dict, List), "The global_state_dict should be a list." 
+ out = {} + for state_dict in global_state_dict: + for key, val in state_dict.items(): + if key in out and val not in out[key]: + out[key].append(val) + else: + out[key] = [val] + return out + +def dedup_state_dict(global_state_dict): + out = {} + for state_dict in global_state_dict: + for key, val in state_dict.items(): + if key in out: + continue + out[key] = val + return out + +# TODO(pangengzheng): support DeviceMesh and Placement later, device_mesh:Optional[core.ProcessMesh, core.DeviceMesh], placements:Optional[List[int], core.Placement] +def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:core.ProcessMesh, dims_mapping:List[int]): + """ + tensor dist_attr look like: {process_mesh: {shape: [2], process_ids: [0,1], dim_names: [x]}, dims_mapping: [-1,0], batch_dim: 0, dynamic_dims: [], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].} + the tensor dims=2, dims_mapping means the dim0 is replicate, dim1 is shard by dim0 of process_mesh + """ + mesh = np.array(process_mesh.process_ids).reshape(process_mesh.shape) + if paddle.distributed.get_rank() not in mesh: + return ((), ()) + rank_coordinator = get_coordinator(mesh, paddle.distributed.get_rank()) + local_shape = copy.copy(global_shape) + global_offset = [0 for _ in global_shape] + # print(f"rank_coordinator:{rank_coordinator}") + for i, dim in enumerate(dims_mapping): + if dim == -1: + continue + else: + assert global_shape[i] % process_mesh.shape[dim] == 0, f"i:{i}, global_shape[i]:{global_shape[i]}, process_mesh.shape[dim]:{process_mesh.shape[dim]}" + local_shape[i] = global_shape[i] // process_mesh.shape[dim] + chunk_idx = rank_coordinator[dim] + global_offset[i] = chunk_idx * local_shape[i] + + return local_shape, global_offset + +def main_test(): + import paddle.distributed as dist + + tensor = paddle.arange(8).reshape([4, 2]) + global_shape = tensor.shape + mesh = dist.ProcessMesh([[0,1], [2,3]], dim_names=["x", "y"]) + dist_attr = dist.DistAttr(mesh, sharding_specs=["x", "y"]) + sharded_tensor = dist.shard_tensor(tensor, dist_attr=dist_attr) + print(f"get_tensor:{sharded_tensor.get_tensor().get_tensor()}, sharded_tensor.dist_attr:{sharded_tensor.dist_attr}") + local_shape, global_offset = compute_local_shape_and_global_offset(global_shape, sharded_tensor.dist_attr.process_mesh, sharded_tensor.dist_attr.dims_mapping) + print(f"local_shape:{local_shape}, global_offset: {global_offset}") + +if __name__ == "__main__": + main_test() From 7134583baa8c594fc6e4e01b08c287982eac4c8e Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Fri, 24 Nov 2023 18:25:36 +0800 Subject: [PATCH 03/24] support save cross mesh state_dict --- .../distributed/checkpoint/load_state_dict.py | 53 +++++++------- .../paddle/distributed/checkpoint/metadata.py | 4 +- .../distributed/checkpoint/output/0.metadata | Bin 371 -> 0 bytes .../distributed/checkpoint/output/0_0.distcp | Bin 259 -> 0 bytes .../distributed/checkpoint/output/1_0.distcp | Bin 259 -> 0 bytes .../distributed/checkpoint/save_state_dict.py | 69 ++++++++++-------- .../checkpoint/test_save_state_dict.sh | 2 +- python/paddle/distributed/checkpoint/utils.py | 26 +------ 8 files changed, 72 insertions(+), 82 deletions(-) delete mode 100644 python/paddle/distributed/checkpoint/output/0.metadata delete mode 100644 python/paddle/distributed/checkpoint/output/0_0.distcp delete mode 100644 python/paddle/distributed/checkpoint/output/1_0.distcp diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py 
index 6b087a16278f81..9d9dc51eef59ba 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -168,7 +168,7 @@ def compute_overlap(cur_chunk_metadata:ChunkMetadata, storage_chunk_metadata:Chu else: assert False, "Should not reach here." lengths.append(end_offset - begin_offset) - assert lengths[-1] > 0, f"Invalid length:{lengths[-1]}, end_offset:{end_offset}, begin_offset:{begin_offset}" + assert lengths[-1] >= 0, f"Invalid length:{lengths[-1]}, end_offset:{end_offset}, begin_offset:{begin_offset}" return cur_offsets, storage_offsets, lengths @@ -239,8 +239,9 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us ... """ if process_group is None: + # Init the default global process group not is_initialized() and paddle.distributed.init_parallel_env() - process_group = paddle.distributed.new_group(list(range(paddle.distributed.ParallelEnv().nranks)), backend="nccl") + # process_group = paddle.distributed.new_group(list(range(paddle.distributed.ParallelEnv().nranks)), backend="nccl") state_dict = flatten_state_dict(state_dict) local_load_files = get_local_load_files(path, state_dict, process_group) @@ -249,10 +250,12 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us read_items = get_read_items(path, state_dict, process_group) loaded_state_dict = {} print(f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}") - # return for item in read_items: assert item.meta_index in load_infos, f"item:{item}, load_infos:{load_infos}" src_rank, file_name = load_infos[item.meta_index] + storage_chunk_tensor = None + cur_sub_chunk_tensor = None + # The src rank need to load the state_dict. if src_rank == paddle.distributed.get_rank(): if file_name not in loaded_state_dict: # The load state_dict is not distributed tensor but a normal tensor. 
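The hunks on either side of this point restructure the chunk exchange: for each `ReadItem`, the source rank slices the chunk it loaded from its `.distcp` file, the owning rank slices (or allocates) the matching region of its local tensor, and the data moves by a local `paddle.assign` when both roles fall on the same rank, otherwise by a `broadcast` over the group. The slice bounds come from `compute_overlap`; the sketch below re-derives that offset arithmetic for one concrete resharding case. The helper name `overlap` is mine, and it folds the two `if/elif` branches into a single expression that already incorporates the later "fix compute overlap bug" commit:

```python
# Standalone re-derivation of the offset arithmetic behind compute_overlap().
# Example: "w1" has global shape (4, 2); it was saved row-sharded over 2 ranks
# (two (2, 2) chunks at offsets (0, 0) and (2, 0)) and is now loaded
# row-sharded over 4 ranks (four (1, 2) chunks).

def overlap(cur_shape, cur_off, st_shape, st_off):
    cur_offsets, st_offsets, lengths = [], [], []
    for c_len, c_off, s_len, s_off in zip(cur_shape, cur_off, st_shape, st_off):
        begin = max(c_off, s_off)
        end = min(c_off + c_len, s_off + s_len)
        if end <= begin:                    # chunks do not overlap along this dim
            return None
        cur_offsets.append(begin - c_off)   # where to write in the local slice
        st_offsets.append(begin - s_off)    # where to read in the saved chunk
        lengths.append(end - begin)
    return tuple(cur_offsets), tuple(st_offsets), tuple(lengths)

# Loading rank 1 owns rows [1:2); the first saved chunk covers rows [0:2):
# copy saved rows [1:2) into local rows [0:1), two columns wide.
print(overlap(cur_shape=(1, 2), cur_off=(1, 0), st_shape=(2, 2), st_off=(0, 0)))
# ((0, 0), (1, 0), (1, 2))
```

Each `(cur_offsets, storage_offsets, lengths)` triple becomes one `ReadItem`, so both sides of the assign/broadcast slice regions of identical shape.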
@@ -263,43 +266,37 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us storage_offsets = item.storage_offset storage_lengths = item.lengths storage_ends = [storage_offset + storage_length for storage_offset, storage_length in zip(storage_offsets, storage_lengths)] - # storage_chunk_tensor = paddle.cast(paddle.slice(storage_local_tensor, list(range(len(storage_lengths))), storage_offsets, storage_ends), paddle.float32) storage_chunk_tensor = paddle.slice(storage_local_tensor, list(range(len(storage_lengths))), storage_offsets, storage_ends) - print(f"src_ran:{src_rank}, item.rank:{item.rank}, process_group:{process_group}, storage_local_tensor:{storage_local_tensor}, storage_chunk_tensor:{storage_chunk_tensor}") - paddle.distributed.broadcast(storage_chunk_tensor, src=src_rank, group=process_group) - if src_rank == item.rank: - cur_local_tensor = state_dict[item.meta_index.param]._local_value() - cur_offsets = item.cur_offset - cur_lengths = item.lengths - cur_ends = [cur_offset + cur_length for cur_offset, cur_length in zip(cur_offsets, cur_lengths)] - cur_sub_chunk_tensor = paddle.slice(cur_local_tensor, list(range(len(cur_lengths))), cur_offsets, cur_ends) - paddle.assign(storage_chunk_tensor, cur_sub_chunk_tensor) - - elif item.rank == paddle.distributed.get_rank(): + # The read item rank need to be assigned + if item.rank == paddle.distributed.get_rank(): assert item.meta_index.param in state_dict, f"item:{item}, state_dict:{state_dict}" cur_local_tensor = state_dict[item.meta_index.param]._local_value() cur_offsets = item.cur_offset cur_lengths = item.lengths cur_ends = [cur_offset + cur_length for cur_offset, cur_length in zip(cur_offsets, cur_lengths)] cur_sub_chunk_tensor = paddle.slice(cur_local_tensor, list(range(len(cur_lengths))), cur_offsets, cur_ends) - print(f"cur_sub_chunk_tensor :{cur_sub_chunk_tensor}, cur_local_tensor:{cur_local_tensor}") - paddle.distributed.broadcast(cur_sub_chunk_tensor, src=src_rank, group=process_group) - print(f"src_rank:{src_rank}, item.rank:{item.rank}, process_group:{process_group}, cur_sub_chunk_tensor:{cur_sub_chunk_tensor}") else: - dummy_tensor = paddle.zeros(item.lengths, dtype=state_dict[item.meta_index.param].dtype) - print(f"dummy_tensor:{dummy_tensor}") - paddle.distributed.broadcast(dummy_tensor, src=src_rank, group=process_group) - print(f"src_rank:{src_rank}, item.rank:{item.rank}, process_group:{process_group}, dummy_tensor:{dummy_tensor}") - # break + cur_sub_chunk_tensor = paddle.zeros(item.lengths, dtype=state_dict[item.meta_index.param].dtype) + + if src_rank == item.rank: + # assign value locally + paddle.assign(storage_chunk_tensor, cur_sub_chunk_tensor) + else: + # assign value remotely + if src_rank == paddle.distributed.get_rank(): + paddle.distributed.broadcast(storage_chunk_tensor, src=src_rank, group=process_group) + else: + paddle.distributed.broadcast(cur_sub_chunk_tensor, src=src_rank, group=process_group) + print(f"after load, state_dict:{state_dict}") def test_get_local_load_files(): - if paddle.distributed.get_rank() == 0: - path = "./output" - else: - path = "./output2" - # path = "./output" + # if paddle.distributed.get_rank() == 0: + # path = "./output" + # else: + # path = "./output2" + path = "./output" # build state_dict import paddle.distributed as dist w1 = paddle.zeros([4,2], dtype=paddle.int64) diff --git a/python/paddle/distributed/checkpoint/metadata.py b/python/paddle/distributed/checkpoint/metadata.py index f46dcf671c95b9..7456899f30cbcc 100644 --- 
a/python/paddle/distributed/checkpoint/metadata.py +++ b/python/paddle/distributed/checkpoint/metadata.py @@ -19,8 +19,8 @@ @dataclass class ChunkMetadata: - local_shape: List[int] - global_offset: List[int] + local_shape: Tuple[int] + global_offset: Tuple[int] @dataclass(frozen=True) class MetadataIndex: diff --git a/python/paddle/distributed/checkpoint/output/0.metadata b/python/paddle/distributed/checkpoint/output/0.metadata deleted file mode 100644 index 09caecdb1b76583760ae3bde8d85cbff4eea9cd3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 371 zcmZ{fu?oUK42Ek_5Tz)%3xb0#VtoZCDL#Ufw2elrws?2wBIx3DoAG_Uw$>tIrjY;3 zpFh|7lWJC8y17RQ5lCR&T$b&F$2Y$4NcV&UiGH*VrT+ml!9KaX;t^TenUgx~eplOZ z5<=_?f)q2o%XQ2wHRh%TJWT{hhnTqD%%+UX+!UL4<3Yi>X zgI{U0lx4uhA;L3%9&Ds=ggxi~P{%VqcUFOJP5tR4+6oz_O#OGVK^~%_X!s?$Y$Urh Ka4Q3)g2@X36@xPX diff --git a/python/paddle/distributed/checkpoint/output/0_0.distcp b/python/paddle/distributed/checkpoint/output/0_0.distcp deleted file mode 100644 index 147bdc313b3372f3629f84f60705d79c84959730..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 259 zcmZo*nfikP0&1sd^e~khPU#WNE6pva)Jx7UO4Z9P%_+%DEGkN@oYKP+UzD1hpI2N` zRGM5eW%86BRb3titBtuH|0Aa68761SM diff --git a/python/paddle/distributed/checkpoint/output/1_0.distcp b/python/paddle/distributed/checkpoint/output/1_0.distcp deleted file mode 100644 index 00693745275c1e4b50a7081799f27131283323e2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 259 zcmZo*nfikP0&1sd^e~khPU#WNE6pva)Jx7UO4Z9P%_+%DEGkN@oYKP+UzD1hpI2N` zRGM5eW%86BR;DH$vo c>>x`sI8m(>1RBBtRm=&cxu7&Rk|Cvf0B4^}CIA2c diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 2059cca4a4b1ee..58433899de9ca1 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -19,7 +19,7 @@ import paddle from paddle.distributed.communication.group import is_initialized from metadata import Metadata, ChunkMetadata, MetadataIndex -from utils import merge_state_dict, dedup_state_dict, compute_local_shape_and_global_offset +from utils import compute_local_shape_and_global_offset def check_state_dict(state_dict, process_group): local_keys = list(state_dict.keys()) @@ -35,25 +35,27 @@ def check_file_name(file_name, process_group): for id in all_unique_id[1:]: assert id == all_unique_id[0], f"id:{id} != all_unique_id[0]:{file_name}" -# def merge_state_dict(global_state_dict): -# assert isinstance(global_state_dict, List), "The global_state_dict should be a list." -# out = {} -# for state_dict in global_state_dict: -# for key, val in state_dict.items(): -# if key in out and val not in out[key]: -# out[key].append(val) -# else: -# out[key] = [val] -# return out +def merge_state_dict(global_state_dict): + assert isinstance(global_state_dict, List), "The global_state_dict should be a list." 
+ out = {} + for state_dict in global_state_dict: + for key, val in state_dict.items(): + if key in out: + if val in out[key]: + continue + out[key].append(val) + else: + out[key] = [val] + return out -# def dedup_state_dict(global_state_dict): -# out = {} -# for state_dict in global_state_dict: -# for key, val in state_dict.items(): -# if key in out: -# continue -# out[key] = val -# return out +def dedup_state_dict(global_state_dict): + out = {} + for state_dict in global_state_dict: + for key, val in state_dict.items(): + if key in out: + continue + out[key] = val + return out def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, use_dist=True) -> None: """ @@ -79,11 +81,11 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us assert isinstance(val, (paddle.Tensor, paddle.base.framework.EagerParamBase)), "Only support dygraph Tensor now, support static DistributedTensor later" # if process_group is None: + # Init the default global process group not is_initialized() and paddle.distributed.init_parallel_env() - process_group = paddle.distributed.new_group(list(range(paddle.distributed.ParallelEnv().nranks)), backend="nccl") + # TODO(pangengzheng): use global default process group + # process_group = paddle.distributed.new_group(list(range(paddle.distributed.ParallelEnv().nranks)), backend="nccl") # calculate (global offset, local shape) of each DTensor - local_state_dict = {} - metadata = Metadata() unique_id = 0 file_name = "" while(True): @@ -95,16 +97,22 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us check_file_name(file_name, process_group) # the parameter_name and order in state_dict should be the same check_state_dict(state_dict, process_group) + local_state_dict = {} + metadata = Metadata() local_chunk_metadata = {} local_storage_metadata = {} for key, val in state_dict.items(): if isinstance(val, paddle.Tensor): + # Case1: not initialized means this tensor is placed in another mesh which do not contain this rank + if not val._is_initialized(): + continue if val.is_dist(): - local_tensor = val.get_tensor().get_tensor() local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) - # gather local_shape and global_offset from all ranks of each parameter + if not local_shape or not global_offset: + continue local_chunk_metadata[key] = ChunkMetadata(local_shape, global_offset) local_storage_metadata[MetadataIndex(key, tuple(global_offset))] = file_name + local_tensor = val._local_value() else: local_tensor = val local_state_dict[key] = local_tensor @@ -115,13 +123,14 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us metadata.state_dict_metadata = merge_state_dict(global_chunk_metadata) metadata.storage_metadata = dedup_state_dict(global_storage_metadata) if coordinator_rank == paddle.distributed.get_rank(): + print(f"global_chunk_metadata:{global_chunk_metadata}") + print(f"global_storage_metadata:{global_storage_metadata}") print(f"metadata:{metadata}") paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata")) print(f"local_state_dict:{local_state_dict}") - for k,v in local_state_dict.items(): - # the phi::DenseTensor only support convert to np.array - local_state_dict[k] = np.array(v) - print(f"local_state_dict name:{k}, val:{local_state_dict[k]}, type:{type(local_state_dict[k])}") + # for k,v in local_state_dict.items(): + # local_state_dict[k] = np.array(v) + # 
print(f"local_state_dict name:{k}, val:{local_state_dict[k]}, type:{type(local_state_dict[k])}") paddle.save(local_state_dict, os.path.join(path, file_name)) @@ -131,9 +140,11 @@ def test_save_state_dict(): w1 = paddle.arange(8).reshape([4, 2]) w2 = paddle.arange(8, 12).reshape([2, 2]) mesh = dist.ProcessMesh([0,1], dim_names=["x"]) + mesh2 = dist.ProcessMesh([2,3], dim_names=["x"]) w1_dist_attr = dist.DistAttr(mesh, sharding_specs=["x", None]) sharded_w1 = dist.shard_tensor(w1, dist_attr=w1_dist_attr) - w2_dist_attr = dist.DistAttr(mesh, sharding_specs=[None, None]) + # w2_dist_attr = dist.DistAttr(mesh, sharding_specs=[None, None]) + w2_dist_attr = dist.DistAttr(mesh2, sharding_specs=["x", None]) sharded_w2 = dist.shard_tensor(w2, dist_attr=w2_dist_attr) state_dict = {"w1": sharded_w1, "w2": sharded_w2} save_state_dict(state_dict, "./output") diff --git a/python/paddle/distributed/checkpoint/test_save_state_dict.sh b/python/paddle/distributed/checkpoint/test_save_state_dict.sh index 8e4dadc4c2095f..a211bc9299e199 100644 --- a/python/paddle/distributed/checkpoint/test_save_state_dict.sh +++ b/python/paddle/distributed/checkpoint/test_save_state_dict.sh @@ -1,2 +1,2 @@ rm -rf log/* -python -u -m paddle.distributed.launch --log_dir ./log --devices 0,1 save_state_dict.py +python -u -m paddle.distributed.launch --log_dir ./log --devices 0,1,2,3 save_state_dict.py diff --git a/python/paddle/distributed/checkpoint/utils.py b/python/paddle/distributed/checkpoint/utils.py index 7482dfca0f48d6..639862e4dfbe2f 100644 --- a/python/paddle/distributed/checkpoint/utils.py +++ b/python/paddle/distributed/checkpoint/utils.py @@ -6,6 +6,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple import copy from typing import List, Optional @@ -19,33 +20,14 @@ def get_coordinator(mesh:np.array, rank:int): assert rand_coordinator.shape[0] in (0, 1), f"rand_coordinator.shape: {rand_coordinator.shape}" return rand_coordinator[0].tolist() if rand_coordinator.shape[0] > 0 else None -def merge_state_dict(global_state_dict): - assert isinstance(global_state_dict, List), "The global_state_dict should be a list." 
- out = {} - for state_dict in global_state_dict: - for key, val in state_dict.items(): - if key in out and val not in out[key]: - out[key].append(val) - else: - out[key] = [val] - return out - -def dedup_state_dict(global_state_dict): - out = {} - for state_dict in global_state_dict: - for key, val in state_dict.items(): - if key in out: - continue - out[key] = val - return out - # TODO(pangengzheng): support DeviceMesh and Placement later, device_mesh:Optional[core.ProcessMesh, core.DeviceMesh], placements:Optional[List[int], core.Placement] -def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:core.ProcessMesh, dims_mapping:List[int]): +def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:core.ProcessMesh, dims_mapping:List[int]) -> Tuple[Tuple[int], Tuple[int]]: """ tensor dist_attr look like: {process_mesh: {shape: [2], process_ids: [0,1], dim_names: [x]}, dims_mapping: [-1,0], batch_dim: 0, dynamic_dims: [], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].} the tensor dims=2, dims_mapping means the dim0 is replicate, dim1 is shard by dim0 of process_mesh """ mesh = np.array(process_mesh.process_ids).reshape(process_mesh.shape) + # deal with cross mesh case if paddle.distributed.get_rank() not in mesh: return ((), ()) rank_coordinator = get_coordinator(mesh, paddle.distributed.get_rank()) @@ -61,7 +43,7 @@ def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:c chunk_idx = rank_coordinator[dim] global_offset[i] = chunk_idx * local_shape[i] - return local_shape, global_offset + return tuple(local_shape), tuple(global_offset) def main_test(): import paddle.distributed as dist From 9e2094af94c514fffab32281b4aa93e24814ea2a Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Fri, 24 Nov 2023 18:25:56 +0800 Subject: [PATCH 04/24] polish --- .../distributed/checkpoint/output2/0.metadata | Bin 371 -> 0 bytes .../distributed/checkpoint/output2/0_0.distcp | Bin 259 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 python/paddle/distributed/checkpoint/output2/0.metadata delete mode 100644 python/paddle/distributed/checkpoint/output2/0_0.distcp diff --git a/python/paddle/distributed/checkpoint/output2/0.metadata b/python/paddle/distributed/checkpoint/output2/0.metadata deleted file mode 100644 index 09caecdb1b76583760ae3bde8d85cbff4eea9cd3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 371 zcmZ{fu?oUK42Ek_5Tz)%3xb0#VtoZCDL#Ufw2elrws?2wBIx3DoAG_Uw$>tIrjY;3 zpFh|7lWJC8y17RQ5lCR&T$b&F$2Y$4NcV&UiGH*VrT+ml!9KaX;t^TenUgx~eplOZ z5<=_?f)q2o%XQ2wHRh%TJWT{hhnTqD%%+UX+!UL4<3Yi>X zgI{U0lx4uhA;L3%9&Ds=ggxi~P{%VqcUFOJP5tR4+6oz_O#OGVK^~%_X!s?$Y$Urh Ka4Q3)g2@X36@xPX diff --git a/python/paddle/distributed/checkpoint/output2/0_0.distcp b/python/paddle/distributed/checkpoint/output2/0_0.distcp deleted file mode 100644 index 147bdc313b3372f3629f84f60705d79c84959730..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 259 zcmZo*nfikP0&1sd^e~khPU#WNE6pva)Jx7UO4Z9P%_+%DEGkN@oYKP+UzD1hpI2N` zRGM5eW%86BRb3titBtuH|0Aa68761SM From 786a318a20e4dcc6b93f9831840d1f3edc3b79f0 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Tue, 28 Nov 2023 11:19:37 +0800 Subject: [PATCH 05/24] fix compute overlap bug --- python/paddle/distributed/checkpoint/load_state_dict.py | 8 ++++++-- python/paddle/distributed/checkpoint/save_state_dict.py | 9 ++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git 
a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 9d9dc51eef59ba..506163a8df0f62 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -163,7 +163,7 @@ def compute_overlap(cur_chunk_metadata:ChunkMetadata, storage_chunk_metadata:Chu cur_offsets.append(0) storage_offsets.append(begin_offset - storage_offset) elif begin_offset == storage_offset: - cur_offsets.append(end_offset - cur_offset) + cur_offsets.append(begin_offset - cur_offset) storage_offsets.append(0) else: assert False, "Should not reach here." @@ -200,6 +200,8 @@ def get_read_items(path, state_dict, process_group): if isinstance(val, paddle.Tensor): if val.is_dist(): local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) + if not local_shape or not global_offset: + continue cur_chunk_metadata = ChunkMetadata(local_shape, global_offset) assert param_name in param_to_chunkmetadata, f"param_name:{param_name} not found in param_to_chunkmetadata:{param_to_chunkmetadata}." for storage_chunk_metadata in param_to_chunkmetadata[param_name]: @@ -224,6 +226,7 @@ def flatten_state_dict(state_dict): # TODO, {"model": {"w0": xxx}} -> {model.w0: xxx} return state_dict + def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, use_dist=True) -> None: """ Load the state_dict inplace from a checkpoint path. @@ -288,7 +291,8 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us else: paddle.distributed.broadcast(cur_sub_chunk_tensor, src=src_rank, group=process_group) - print(f"after load, state_dict:{state_dict}") + local_state_dict = { k:v._local_value() for k, v in state_dict.items()} + print(f"after load, local_state_dict:{local_state_dict} \n state_dict:{state_dict}") def test_get_local_load_files(): diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 58433899de9ca1..9c7b1e9997dff5 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -79,13 +79,11 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us if len(state_dict) > 0: for val in state_dict.values(): assert isinstance(val, (paddle.Tensor, paddle.base.framework.EagerParamBase)), "Only support dygraph Tensor now, support static DistributedTensor later" - # + if process_group is None: # Init the default global process group not is_initialized() and paddle.distributed.init_parallel_env() - # TODO(pangengzheng): use global default process group - # process_group = paddle.distributed.new_group(list(range(paddle.distributed.ParallelEnv().nranks)), backend="nccl") - # calculate (global offset, local shape) of each DTensor + unique_id = 0 file_name = "" while(True): @@ -128,9 +126,6 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us print(f"metadata:{metadata}") paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata")) print(f"local_state_dict:{local_state_dict}") - # for k,v in local_state_dict.items(): - # local_state_dict[k] = np.array(v) - # print(f"local_state_dict name:{k}, val:{local_state_dict[k]}, type:{type(local_state_dict[k])}") paddle.save(local_state_dict, os.path.join(path, file_name)) From 8f64e81f043bb7a3c51f73a18bed4a4c0650bf74 Mon Sep 17 00:00:00 2001 From: pangengzheng 
Date: Wed, 29 Nov 2023 20:53:03 +0800 Subject: [PATCH 06/24] test save load in dp_mp unittest --- python/paddle/distributed/__init__.py | 5 +++++ .../distributed/checkpoint/load_state_dict.py | 18 ++++++------------ .../distributed/checkpoint/save_state_dict.py | 18 +++++++++--------- python/paddle/distributed/checkpoint/utils.py | 2 +- setup.py | 1 + .../semi_auto_parallel_simple_net_dp_mp.py | 17 +++++++++++++++++ 6 files changed, 39 insertions(+), 22 deletions(-) diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 5dd29c4d74fbb6..7c2ea5d3a20e30 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -103,6 +103,9 @@ from . import rpc # noqa: F401 +from .checkpoint.save_state_dict import save_state_dict +from .checkpoint.load_state_dict import load_state_dict + __all__ = [ "io", "spawn", @@ -157,4 +160,6 @@ "Shard", "Replicate", "Partial", + "save_state_dict", + "load_state_dict", ] diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 506163a8df0f62..e3eb659c60b6a4 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -19,8 +19,8 @@ import paddle from paddle.distributed.communication.group import is_initialized -from metadata import Metadata, ChunkMetadata, MetadataIndex -from utils import compute_local_shape_and_global_offset +from .metadata import Metadata, ChunkMetadata, MetadataIndex +from .utils import compute_local_shape_and_global_offset @dataclass(frozen=True) class ReadItem: @@ -195,7 +195,7 @@ def get_read_items(path, state_dict, process_group): param_to_chunkmetadata[param_name] = [] param_to_chunkmetadata[param_name] += chunk_metadata read_items = [] - print(f"param_to_chunkmetadata:{param_to_chunkmetadata}\n state_dict:{state_dict}") + print(f"param_to_chunkmetadata:{param_to_chunkmetadata}") for param_name, val in state_dict.items(): if isinstance(val, paddle.Tensor): if val.is_dist(): @@ -296,20 +296,14 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us def test_get_local_load_files(): - # if paddle.distributed.get_rank() == 0: - # path = "./output" - # else: - # path = "./output2" path = "./output" # build state_dict import paddle.distributed as dist w1 = paddle.zeros([4,2], dtype=paddle.int64) w2 = paddle.zeros([2,2], dtype=paddle.int64) - mesh = dist.ProcessMesh([0,1,2,3], dim_names=["x"]) - w1_dist_attr = dist.DistAttr(mesh, sharding_specs=["x", None]) - sharded_w1 = dist.shard_tensor(w1, dist_attr=w1_dist_attr) - w2_dist_attr = dist.DistAttr(mesh, sharding_specs=[None, None]) - sharded_w2 = dist.shard_tensor(w2, dist_attr=w2_dist_attr) + mesh = dist.ProcessMesh([0,1,2,3]) + sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) + sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Replicate(), dist.Replicate()]) state_dict = {"w1": sharded_w1, "w2": sharded_w2} load_state_dict(state_dict, path) diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 9c7b1e9997dff5..dba1e9f1e21802 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -18,8 +18,8 @@ import paddle from paddle.distributed.communication.group import is_initialized -from metadata import Metadata, ChunkMetadata, MetadataIndex -from utils import 
compute_local_shape_and_global_offset +from .metadata import Metadata, ChunkMetadata, MetadataIndex +from .utils import compute_local_shape_and_global_offset def check_state_dict(state_dict, process_group): local_keys = list(state_dict.keys()) @@ -80,6 +80,9 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us for val in state_dict.values(): assert isinstance(val, (paddle.Tensor, paddle.base.framework.EagerParamBase)), "Only support dygraph Tensor now, support static DistributedTensor later" + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + if process_group is None: # Init the default global process group not is_initialized() and paddle.distributed.init_parallel_env() @@ -134,13 +137,10 @@ def test_save_state_dict(): import paddle.distributed as dist w1 = paddle.arange(8).reshape([4, 2]) w2 = paddle.arange(8, 12).reshape([2, 2]) - mesh = dist.ProcessMesh([0,1], dim_names=["x"]) - mesh2 = dist.ProcessMesh([2,3], dim_names=["x"]) - w1_dist_attr = dist.DistAttr(mesh, sharding_specs=["x", None]) - sharded_w1 = dist.shard_tensor(w1, dist_attr=w1_dist_attr) - # w2_dist_attr = dist.DistAttr(mesh, sharding_specs=[None, None]) - w2_dist_attr = dist.DistAttr(mesh2, sharding_specs=["x", None]) - sharded_w2 = dist.shard_tensor(w2, dist_attr=w2_dist_attr) + mesh = dist.ProcessMesh([0,1]) + mesh2 = dist.ProcessMesh([2,3]) + sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) + sharded_w2 = dist.shard_tensor(w2, mesh2, [dist.Shard(0), dist.Replicate()]) state_dict = {"w1": sharded_w1, "w2": sharded_w2} save_state_dict(state_dict, "./output") diff --git a/python/paddle/distributed/checkpoint/utils.py b/python/paddle/distributed/checkpoint/utils.py index 639862e4dfbe2f..25b9df74d9cd01 100644 --- a/python/paddle/distributed/checkpoint/utils.py +++ b/python/paddle/distributed/checkpoint/utils.py @@ -20,7 +20,7 @@ def get_coordinator(mesh:np.array, rank:int): assert rand_coordinator.shape[0] in (0, 1), f"rand_coordinator.shape: {rand_coordinator.shape}" return rand_coordinator[0].tolist() if rand_coordinator.shape[0] > 0 else None -# TODO(pangengzheng): support DeviceMesh and Placement later, device_mesh:Optional[core.ProcessMesh, core.DeviceMesh], placements:Optional[List[int], core.Placement] + def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:core.ProcessMesh, dims_mapping:List[int]) -> Tuple[Tuple[int], Tuple[int]]: """ tensor dist_attr look like: {process_mesh: {shape: [2], process_ids: [0,1], dim_names: [x]}, dims_mapping: [-1,0], batch_dim: 0, dynamic_dims: [], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].} diff --git a/setup.py b/setup.py index e0e52c27d5b639..8b9225ef7a6d08 100644 --- a/setup.py +++ b/setup.py @@ -1380,6 +1380,7 @@ def get_setup_parameters(): 'paddle.dataset', 'paddle.reader', 'paddle.distributed', + 'paddle.distributed.checkpoint', 'paddle.distributed.communication', 'paddle.distributed.communication.stream', 'paddle.distributed.metric', diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py index 81b55cf266a08a..3d8974df0784e1 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py @@ -43,6 +43,7 @@ def test_dp_mp_demo_net(self): DemoNet("dp_mp_hybrid_strategy"), self._mesh, self.shard_fn ) + ( self.dp_mp_loss, self.dp_mp_parameters, @@ -55,6 +56,22 @@ 
def test_dp_mp_demo_net(self): self.check_tensor_eq(param, param_base) self.check_tensor_eq(param.grad, param_base.grad) + # save load + ckpt_path = "/ckpt_output/" + state_dict = model.state_dict() + local_state_dict = {} + for k, v in state_dict.items(): + local_state_dict[k] = v._local_value().clone() + paddle.distributed.save_state_dict(state_dict, ckpt_path) + for k, v in state_dict.items(): + v._local_value().add_(paddle.ones_like(v._local_value())) + paddle.distributed.load_state_dict(state_dict, ckpt_path) + for k, v in state_dict.items(): + assert k in local_state_dict, k + self.check_tensor_eq(v._local_value(), local_state_dict[k]) + os.system(f"rm -rf {ckpt_path}") + + def run_test_case(self): self.test_dp_mp_demo_net() From 250b1b7c626fd5ca8746fb103c9597fb071d80ea Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Fri, 1 Dec 2023 15:58:28 +0800 Subject: [PATCH 07/24] fix get local file bug and test --- .../distributed/checkpoint/load_state_dict.py | 6 +- python/setup.py.in | 1 + .../hybrid_strategy/CMakeLists.txt | 6 + .../hybrid_strategy/load_state_dict.py | 109 ++++++++++++++++++ .../hybrid_strategy/save_state_dict.py | 46 ++++++++ .../test_save_load_state_dict.py | 45 ++++++++ .../hybrid_strategy/testslist.csv | 1 + 7 files changed, 211 insertions(+), 3 deletions(-) create mode 100644 test/auto_parallel/hybrid_strategy/load_state_dict.py create mode 100644 test/auto_parallel/hybrid_strategy/save_state_dict.py create mode 100644 test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index e3eb659c60b6a4..15313fa042884f 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -81,14 +81,14 @@ def get_local_load_files(path, state_dict, process_group): # step 3, update the rank_to_read_files def get_least_read_files_ranks(rank_to_read_files): nums = [(rank, len(files)) for rank, files in rank_to_read_files.items()] - sorted(nums, key=lambda x: x[1]) - ranks = [rank for rank, num in nums if num == nums[0][0]] + nums = sorted(nums, key=lambda x: x[1]) + ranks = [rank for rank, num in nums if num == nums[0][1]] return ranks def get_read_rank_file(rank_to_files, ranks): if len(rank_to_files) == 0: return (None, None) nums = [(rank, len(files)) for rank, files in rank_to_files.items() if rank in ranks] - sorted(nums, key=lambda x: x[1]) + nums = sorted(nums, key=lambda x: x[1]) rank = nums[0][0] return (rank, rank_to_files[rank][0]) def update(rank_to_read_files, rank_to_files, rank_file): diff --git a/python/setup.py.in b/python/setup.py.in index 25e1c2ca8df7cc..bb78960840916c 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -371,6 +371,7 @@ packages=['paddle', 'paddle.dataset', 'paddle.reader', 'paddle.distributed', + 'paddle.distributed.checkpoint', 'paddle.distributed.communication', 'paddle.distributed.communication.stream', 'paddle.distributed.metric', diff --git a/test/auto_parallel/hybrid_strategy/CMakeLists.txt b/test/auto_parallel/hybrid_strategy/CMakeLists.txt index 257f716dfa192b..ef1aaf5376445e 100644 --- a/test/auto_parallel/hybrid_strategy/CMakeLists.txt +++ b/test/auto_parallel/hybrid_strategy/CMakeLists.txt @@ -11,4 +11,10 @@ if((WITH_GPU) AND (LINUX)) "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_semi_auto_parallel_hybrid_strategy PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") + 
py_test_modules( + test_save_load_state_dict MODULES + test_save_load_state_dict ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_save_load_state_dict + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") endif() diff --git a/test/auto_parallel/hybrid_strategy/load_state_dict.py b/test/auto_parallel/hybrid_strategy/load_state_dict.py new file mode 100644 index 00000000000000..b0a359c7f0ce35 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/load_state_dict.py @@ -0,0 +1,109 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.distributed import save_state_dict, load_state_dict +from paddle.distributed.checkpoint.utils import get_coordinator, compute_local_shape_and_global_offset +from auto_parallel.hybrid_strategy.save_state_dict import get_global_state_dict, ckpt_path + + +class TestLoadStateDict: + def test_load_state_dict_with_same_cards(self): + global_state_dict = get_global_state_dict() + saved_w1, saved_w2 = list(global_state_dict.values()) + w1 = paddle.zeros_like(saved_w1) + w2 = paddle.zeros_like(saved_w2) + mesh = dist.ProcessMesh([0,1,2,3]) + sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) + sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Replicate(), dist.Replicate()]) + state_dict = dict(zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2])) + load_state_dict(state_dict, ckpt_path()) + # check + cur_rank = paddle.distributed.get_rank() + expect_w1 = saved_w1.split(4, axis=0)[cur_rank] + expect_w2 = sharded_w2 + expect_state_dict = dict(zip(list(global_state_dict.keys()), [expect_w1, expect_w2])) + for k, v in state_dict.items(): + assert k in expect_state_dict, k + print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") + self.check_tensor_eq(v._local_value(), expect_state_dict[k]) + + def test_load_state_dict_with_less_cards(self): + global_state_dict = get_global_state_dict() + saved_w1, saved_w2 = list(global_state_dict.values()) + w1 = paddle.zeros_like(saved_w1) + w2 = paddle.zeros_like(saved_w2) + mesh = dist.ProcessMesh([0,1]) + sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)]) + sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(1)]) + state_dict = dict(zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2])) + load_state_dict(state_dict, ckpt_path()) + # check + cur_rank = paddle.distributed.get_rank() + expect_w1 = saved_w1.split(2, axis=0)[cur_rank] + expect_w2 = saved_w2.split(2, axis=1)[cur_rank] + expect_state_dict = dict(zip(list(global_state_dict.keys()), [expect_w1, expect_w2])) + for k, v in state_dict.items(): + assert k in expect_state_dict, k + print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") + self.check_tensor_eq(v._local_value(), expect_state_dict[k]) + + def test_load_state_dict_with_more_cards(self): + global_state_dict = get_global_state_dict() + saved_w1, saved_w2 = list(global_state_dict.values()) + w1 = 
paddle.zeros_like(saved_w1) + w2 = paddle.zeros_like(saved_w2) + mesh = dist.ProcessMesh([[0,1,2,3], [4,5,6,7]]) + sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(1), dist.Shard(0)]) + sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(0)]) + state_dict = dict(zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2])) + load_state_dict(state_dict, ckpt_path()) + # check + cur_rank = paddle.distributed.get_rank() + local_shape, global_offset = compute_local_shape_and_global_offset(sharded_w1.shape, sharded_w1.dist_attr.process_mesh, sharded_w1.dist_attr.dims_mapping) + end_offset = [offset + length for offset, length in zip(global_offset, local_shape)] + print(f"local_shape:{local_shape}, global_offset:{global_offset}, end_offset:{end_offset}") + expect_w1 = paddle.slice(saved_w1, axes=[0, 1], starts=global_offset, ends=end_offset) + cur_coordinator = get_coordinator(np.array([[0,1,2,3], [4,5,6,7]]), cur_rank) + expect_w2 = saved_w2.split(2, axis=0)[cur_coordinator[0]] + expect_state_dict = dict(zip(list(global_state_dict.keys()), [expect_w1, expect_w2])) + for k, v in state_dict.items(): + assert k in expect_state_dict, k + print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") + self.check_tensor_eq(v._local_value(), expect_state_dict[k]) + + def check_tensor_eq(self, a, b, verbose=True): + np1 = a.astype("float32").numpy() + np2 = b.astype("float32").numpy() + np.testing.assert_equal( + np1, np2, verbose=verbose + ) + + + def run_test_case(self): + device_num = int(os.getenv("device_num")) + if device_num == 2: + self.test_load_state_dict_with_less_cards() + elif device_num == 4: + self.test_load_state_dict_with_same_cards() + elif device_num == 8: + self.test_load_state_dict_with_more_cards() + else: + raise ValueError("device_num should be 2,4 or 8") + +if __name__ == '__main__': + TestLoadStateDict().run_test_case() \ No newline at end of file diff --git a/test/auto_parallel/hybrid_strategy/save_state_dict.py b/test/auto_parallel/hybrid_strategy/save_state_dict.py new file mode 100644 index 00000000000000..6a38a6477779ca --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/save_state_dict.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
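# Expected checkpoint layout produced by this save test (a sketch only; the
# file names follow the "{rank}_{unique_id}.distcp" / "{unique_id}.metadata"
# pattern that the loader's load-balancing examples reference, with all four
# ranks writing a data file and a single metadata file being written):
#
#   import os
#   print(sorted(os.listdir(ckpt_path())))
#   # e.g. ['0.metadata', '0_0.distcp', '1_0.distcp', '2_0.distcp', '3_0.distcp']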
+ +import os + +import paddle +import paddle.distributed as dist +from paddle.distributed import save_state_dict + + +def get_global_state_dict(): + w1 = paddle.arange(32).reshape([4, 8]) + w2 = paddle.arange(32, 36).reshape([2, 2]) + return {"w1":w1, "w2":w2} + +def ckpt_path(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp_ckpt_output") + +class TestSaveStateDict: + def test_save_state_dict(self): + global_state_dict = get_global_state_dict() + keys = list(global_state_dict.keys()) + w1, w2 = list(global_state_dict.values()) + mesh = dist.ProcessMesh([0,1]) + mesh2 = dist.ProcessMesh([2,3]) + sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) + sharded_w2 = dist.shard_tensor(w2, mesh2, [dist.Shard(0), dist.Replicate()]) + state_dict = dict(zip(keys, [sharded_w1, sharded_w2])) + save_state_dict(state_dict, ckpt_path()) + + def run_test_case(self): + self.test_save_state_dict() + +if __name__ == "__main__": + TestSaveStateDict().run_test_case() diff --git a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py new file mode 100644 index 00000000000000..1be29d3dd70328 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import unittest + +import collective.test_communication_api_base as test_base +from auto_parallel.hybrid_strategy.save_state_dict import ckpt_path + +class TestSaveLoadStateDict(test_base.CommunicationTestDistBase): + def setUp(self): + self._default_envs = {} + self._changeable_envs = {"device_num": ['2','4','8']} + + def test_save_load_state_dict(self): + # save with 4 devices + os.system(f"rm -rf {ckpt_path()}") + super().setUp(num_of_devices=4, timeout=120, nnode=1) + self.run_test_case("save_state_dict.py") + # load with 2, 4, 8 devices + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + super().setUp(save_log_dir="./log", num_of_devices=int(envs["device_num"]), timeout=120, nnode=1) + self.run_test_case( + "load_state_dict.py", + user_defined_envs=envs, + ) + os.system(f"rm -rf {ckpt_path()}") + +if __name__ == '__main__': + unittest.main() + diff --git a/test/auto_parallel/hybrid_strategy/testslist.csv b/test/auto_parallel/hybrid_strategy/testslist.csv index 8a9e3fe28e21c2..0820f4611e2a58 100644 --- a/test/auto_parallel/hybrid_strategy/testslist.csv +++ b/test/auto_parallel/hybrid_strategy/testslist.csv @@ -1,2 +1,3 @@ name,os,arch,timeout,run_type,launcher,num_port,run_serial,envs,conditions test_semi_auto_parallel_hybrid_strategy,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_save_load_state_dict.py,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., From bd9348f601e7c82016c7464999a3e95413f6c65b Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Fri, 1 Dec 2023 17:07:17 +0800 Subject: [PATCH 08/24] delete useless files, and rename var --- .../paddle/distributed/checkpoint/__init__.py | 21 ++ .../distributed/checkpoint/load_state_dict.py | 193 +++++++++--------- .../paddle/distributed/checkpoint/metadata.py | 20 +- .../distributed/checkpoint/save_state_dict.py | 45 ++-- python/paddle/distributed/checkpoint/utils.py | 32 +-- 5 files changed, 159 insertions(+), 152 deletions(-) create mode 100644 python/paddle/distributed/checkpoint/__init__.py diff --git a/python/paddle/distributed/checkpoint/__init__.py b/python/paddle/distributed/checkpoint/__init__.py new file mode 100644 index 00000000000000..7de3d719cd6d1c --- /dev/null +++ b/python/paddle/distributed/checkpoint/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
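# A minimal usage sketch of the save_state_dict/load_state_dict API exported
# below (illustration only, not added by this patch). It assumes a 2-card run
# launched with paddle.distributed.launch and reuses the shard_tensor /
# Shard / Replicate calls already exercised by the tests in this series; the
# "./checkpoint" directory is an arbitrary example path.
#
#   import paddle
#   import paddle.distributed as dist
#   from paddle.distributed.checkpoint import load_state_dict, save_state_dict
#
#   w1 = paddle.arange(8).reshape([4, 2])
#   mesh = dist.ProcessMesh([0, 1])
#   sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()])
#   save_state_dict({"w1": sharded_w1}, "./checkpoint")
#
#   # Reload in place into a freshly sharded target tensor.
#   target = dist.shard_tensor(
#       paddle.zeros_like(w1), mesh, [dist.Shard(0), dist.Replicate()]
#   )
#   load_state_dict({"w1": target}, "./checkpoint")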
+ +from .save_state_dict import save_state_dict +from .load_state_dict import load_state_dict + +__all__ = [ + "save_state_dict", + "load_state_dict", +] \ No newline at end of file diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 15313fa042884f..fc518f7635612e 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -18,32 +18,35 @@ import paddle from paddle.distributed.communication.group import is_initialized +from paddle.distributed.fleet.utils.log_util import logger -from .metadata import Metadata, ChunkMetadata, MetadataIndex -from .utils import compute_local_shape_and_global_offset +from .metadata import Metadata, LocalTensorMetadata, LocalTensorIndex +from .utils import compute_local_shape_and_global_offset, flatten_state_dict @dataclass(frozen=True) class ReadItem: + local_tensor_index:LocalTensorIndex rank:int - meta_index:MetadataIndex cur_offset:Tuple[int] storage_offset:Tuple[int] lengths:Tuple[int] -def get_local_load_files(path, state_dict, process_group): +def get_rank_to_files(path, state_dict, process_group): # step 1, get neccesary files to be read accessible_files = os.listdir(path) metadata_files = [file for file in accessible_files if file.endswith(".metadata")] assert len(metadata_files) > 0, "No metadata file found in the checkpoint directory:{path}." # The neccesary files to be read + tensor_id_list = [] necessary_files = [] for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) for metadata_index, file_name in metadata.storage_metadata.items(): - if metadata_index.param in state_dict: + tensor_id_list.append(metadata_index.tensor_id) + if metadata_index.tensor_id in state_dict: necessary_files.append(file_name) - necessary_files_set = set(necessary_files) + necessary_data_files_set = set(necessary_files) # allgather all accessible files local_data_files = [file for file in accessible_files if file.endswith(".distcp")] global_data_files = [] @@ -52,22 +55,40 @@ def get_local_load_files(path, state_dict, process_group): for files in global_data_files: tmp += files global_data_files_set = set(tmp) - print(f"necessary_files_set:{necessary_files_set}, global_data_files_set:{global_data_files_set}") + logger.info(f"necessary_data_files_set:{necessary_data_files_set}, global_data_files_set:{global_data_files_set}") # check neccesary files in global_data_files - assert global_data_files_set & necessary_files_set == necessary_files_set, \ - f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_files_set:{necessary_files_set}" + assert global_data_files_set & necessary_data_files_set == necessary_data_files_set, \ + f"The checkpoint files are not complete. 
Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{necessary_data_files_set}" + missing_keys = set(state_dict.keys()) - set(tensor_id_list) + logger.info(f"missing_keys:{missing_keys}") # step 2, get mapping between ranks and local files rank_to_files = {} - file_to_ranks = {} for rank, local_files in enumerate(global_data_files): if len(local_files) > 0: - local_files = [f for f in local_files if f in necessary_files_set] + local_files = [f for f in local_files if f in necessary_data_files_set] rank_to_files[rank] = local_files - for file in local_files: + logger.info(f"mapping rank_to_files:{rank_to_files}") + +def get_local_load_files(rank_to_files): + """ + Load files in a load-balanced manner. + Example: + Case1: all ranks access the same data files + rank_to_files = {rank0:[0_0.distcp, 1_0.distcp, 2_0.distcp, 3_0.distcp], rank1:[0_0.distcp, 1_0.distcp, 2_0.distcp, 3_0.distcp]} + rank0 return [0_0.distcp, 1_0.distcp], rank1 return [2_0.distcp, 3_0.distcp] + Case2: all ranks access different data files but some overlapped + rank_to_files = {rank0:[0_0.distcp, 1_0.distcp, 2_0.distcp], rank1:[2_0.distcp, 3_0.distcp] + rank0 return [0_0.distcp, 1_0.distcp], rank1 return [2_0.distcp, 3_0.distcp] + Case3: all ranks access different data files and no overlapped + rank_to_files = {rank0:[0_0.distcp, 1_0.distcp], rank1:[2_0.distcp, 3_0.distcp] + rank0 return [0_0.distcp, 1_0.distcp], rank1 return [2_0.distcp, 3_0.distcp] + """ + file_to_ranks = {} + for rank, files in rank_to_files.items(): + for file in files: if file not in file_to_ranks: file_to_ranks[file] = [] file_to_ranks[file].append(rank) - print(f"mapping rank_to_files:{rank_to_files}, file_to_ranks:{file_to_ranks}") rank_to_read_files = {rank:[] for rank in rank_to_files.keys()} for file, ranks in file_to_ranks.items(): if len(ranks) == 1: @@ -77,8 +98,7 @@ def get_local_load_files(path, state_dict, process_group): if len(rank_to_files[rank]) == 0: rank_to_files.pop(rank) - print(f"start rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}") - # step 3, update the rank_to_read_files + logger.info(f"start rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}") def get_least_read_files_ranks(rank_to_read_files): nums = [(rank, len(files)) for rank, files in rank_to_read_files.items()] nums = sorted(nums, key=lambda x: x[1]) @@ -105,25 +125,25 @@ def update(rank_to_read_files, rank_to_files, rank_file): if f not in file_to_ranks: file_to_ranks[f] = [] file_to_ranks[f].append(r) - print(f"file_to_ranks:{file_to_ranks}") + logger.info(f"file_to_ranks:{file_to_ranks}") if file in file_to_ranks: for r in file_to_ranks[file]: rank_to_files[r].remove(file) if len(rank_to_files[r]) == 0: rank_to_files.pop(r) - # step 4, get final rank_to_read_files + while len(rank_to_files) > 0: ranks = get_least_read_files_ranks(rank_to_read_files) rank_file = get_read_rank_file(rank_to_files, ranks) update(rank_to_read_files, rank_to_files, rank_file) - print(f"update rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}, ranks:{ranks}, rank_file:{rank_file}") - print(f"rank_to_read_files:{rank_to_read_files}") + logger.info(f"update rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}, ranks:{ranks}, rank_file:{rank_file}") + logger.info(f"rank_to_read_files:{rank_to_read_files}") cur_rank = paddle.distributed.get_rank() if cur_rank in rank_to_read_files: - print(f"cur_rank:{cur_rank}, 
rank_to_read_files[cur_rank]:{rank_to_read_files[cur_rank]}") + logger.info(f"cur_rank:{cur_rank}, rank_to_read_files[cur_rank]:{rank_to_read_files[cur_rank]}") return rank_to_read_files[cur_rank] else: - print(f"rank:{cur_rank} does not need to load checkpoint") + logger.info(f"rank:{cur_rank} does not need to load checkpoint") return [] @@ -134,28 +154,28 @@ def get_load_infos(path, local_load_files, process_group): assert len(metadata_files) > 0, "No metadata file found in the checkpoint directory:{path}." for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) - for meta_index, file_name in metadata.storage_metadata.items(): + for local_tensor_index, file_name in metadata.storage_metadata.items(): if file_name in local_load_files: - load_info[meta_index] = (paddle.distributed.get_rank(), file_name) + load_info[local_tensor_index] = (paddle.distributed.get_rank(), file_name) load_info_list = [] paddle.distributed.all_gather_object(load_info_list, load_info, process_group) load_infos = {} for load_info in load_info_list: - for meta_index, (rank, file_name) in load_info.items(): - assert meta_index not in load_infos - load_infos[meta_index] = (rank, file_name) + for local_tensor_index, (rank, file_name) in load_info.items(): + assert local_tensor_index not in load_infos + load_infos[local_tensor_index] = (rank, file_name) return load_infos -def compute_overlap(cur_chunk_metadata:ChunkMetadata, storage_chunk_metadata:ChunkMetadata): +def compute_overlap(cur_chunk_metadata:LocalTensorMetadata, storage_local_tensor_metadata:LocalTensorMetadata): cur_offsets = [] storage_offsets = [] lengths = [] for cur_len, cur_offset, strorage_len, storage_offset in zip( cur_chunk_metadata.local_shape, cur_chunk_metadata.global_offset, - storage_chunk_metadata.local_shape, - storage_chunk_metadata.global_offset + storage_local_tensor_metadata.local_shape, + storage_local_tensor_metadata.global_offset ): begin_offset = max(cur_offset, storage_offset) end_offset = min(cur_offset + cur_len, storage_offset + strorage_len) @@ -172,12 +192,12 @@ def compute_overlap(cur_chunk_metadata:ChunkMetadata, storage_chunk_metadata:Chu return cur_offsets, storage_offsets, lengths -def not_overlap(cur_chunk_metadata:ChunkMetadata, storage_chunk_metadata:ChunkMetadata): +def not_overlap(cur_chunk_metadata:LocalTensorMetadata, storage_local_tensor_metadata:LocalTensorMetadata): for cur_len, cur_offset, strorage_len, storage_offset in zip( cur_chunk_metadata.local_shape, cur_chunk_metadata.global_offset, - storage_chunk_metadata.local_shape, - storage_chunk_metadata.global_offset + storage_local_tensor_metadata.local_shape, + storage_local_tensor_metadata.global_offset ): if cur_offset >= (storage_offset + strorage_len) or (cur_offset + cur_len) <= storage_offset: return True @@ -187,29 +207,29 @@ def get_read_items(path, state_dict, process_group): accessible_files = os.listdir(path) metadata_files = [file for file in accessible_files if file.endswith(".metadata")] assert len(metadata_files) > 0, "No metadata file found in the checkpoint directory:{path}." 
- param_to_chunkmetadata = {} + storage_state_dict_metadata = {} for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) - for param_name, chunk_metadata in metadata.state_dict_metadata.items(): - if param_name not in param_to_chunkmetadata: - param_to_chunkmetadata[param_name] = [] - param_to_chunkmetadata[param_name] += chunk_metadata + for tensor_id, local_tensor_metadata in metadata.state_dict_metadata.items(): + if tensor_id not in storage_state_dict_metadata: + storage_state_dict_metadata[tensor_id] = [] + storage_state_dict_metadata[tensor_id] += local_tensor_metadata read_items = [] - print(f"param_to_chunkmetadata:{param_to_chunkmetadata}") - for param_name, val in state_dict.items(): + logger.info(f"storage_state_dict_metadata:{storage_state_dict_metadata}") + for tensor_id, val in state_dict.items(): if isinstance(val, paddle.Tensor): if val.is_dist(): local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) if not local_shape or not global_offset: continue - cur_chunk_metadata = ChunkMetadata(local_shape, global_offset) - assert param_name in param_to_chunkmetadata, f"param_name:{param_name} not found in param_to_chunkmetadata:{param_to_chunkmetadata}." - for storage_chunk_metadata in param_to_chunkmetadata[param_name]: - if not_overlap(cur_chunk_metadata, storage_chunk_metadata): + cur_chunk_metadata = LocalTensorMetadata(global_offset, local_shape) + assert tensor_id in storage_state_dict_metadata, f"tensor_id:{tensor_id} not found in storage_state_dict_metadata:{storage_state_dict_metadata}." + for storage_local_tensor_metadata in storage_state_dict_metadata[tensor_id]: + if not_overlap(cur_chunk_metadata, storage_local_tensor_metadata): continue - cur_offsets, storage_offsets, lengths = compute_overlap(cur_chunk_metadata, storage_chunk_metadata) - storage_meta_index = MetadataIndex(param_name, tuple(storage_chunk_metadata.global_offset)) - read_items.append(ReadItem(paddle.distributed.get_rank(), storage_meta_index, tuple(cur_offsets), tuple(storage_offsets), tuple(lengths))) + cur_offsets, storage_offsets, lengths = compute_overlap(cur_chunk_metadata, storage_local_tensor_metadata) + storage_local_tensor_index = LocalTensorIndex(tensor_id, tuple(storage_local_tensor_metadata.global_offset)) + read_items.append(ReadItem(storage_local_tensor_index, paddle.distributed.get_rank(), tuple(cur_offsets), tuple(storage_offsets), tuple(lengths))) else: assert False, f"Only support distributed tensor., val type:{type(val)}" else: @@ -222,10 +242,6 @@ def get_read_items(path, state_dict, process_group): global_read_items.append(item) return global_read_items -def flatten_state_dict(state_dict): - # TODO, {"model": {"w0": xxx}} -> {model.w0: xxx} - return state_dict - def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, use_dist=True) -> None: """ @@ -235,83 +251,70 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us path: The directory to load checkpoint files. process_group: ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. coordinator_rank: The rank used to coordinate the checkpoint. Rank0 is used by default. - use_dict: Whether to load the state_dict in distributed mode. Set True by default. + use_dist: Whether to load the state_dict in distributed mode. Set True by default. Example: .. code-block:: python import paddle ... 
""" + assert isinstance(state_dict, dict), "The state_dict should be a dictionary." + state_dict = flatten_state_dict(state_dict) + if len(state_dict) > 0: + for val in state_dict.values(): + assert isinstance(val, (paddle.Tensor, paddle.base.framework.EagerParamBase)), "Only support dygraph Tensor now, support static DistributedTensor later" + if process_group is None: # Init the default global process group not is_initialized() and paddle.distributed.init_parallel_env() - # process_group = paddle.distributed.new_group(list(range(paddle.distributed.ParallelEnv().nranks)), backend="nccl") - state_dict = flatten_state_dict(state_dict) - local_load_files = get_local_load_files(path, state_dict, process_group) - # load_infos: {MetaIndex: (rank, file_name)} + rank_to_files = get_rank_to_files(path, state_dict, process_group) + local_load_files = get_local_load_files(rank_to_files) + # load_infos: {LocalTensorIndex: (rank, file_name)}, which local tensor located in which file, and the file is load in which rank. load_infos = get_load_infos(path, local_load_files, process_group) + # read_items: [ReadItem(local_tensor_index, rank, cur_offsets, storage_offsets, lengths)], + # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank. read_items = get_read_items(path, state_dict, process_group) - loaded_state_dict = {} - print(f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}") + storage_file_to_state_dict = {} + logger.info(f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}") for item in read_items: - assert item.meta_index in load_infos, f"item:{item}, load_infos:{load_infos}" - src_rank, file_name = load_infos[item.meta_index] + assert item.local_tensor_index in load_infos, f"item:{item}, load_infos:{load_infos}" + src_rank, file_name = load_infos[item.local_tensor_index] storage_chunk_tensor = None - cur_sub_chunk_tensor = None + cur_chunk_tensor = None # The src rank need to load the state_dict. if src_rank == paddle.distributed.get_rank(): - if file_name not in loaded_state_dict: - # The load state_dict is not distributed tensor but a normal tensor. - loaded_state_dict[file_name] = paddle.load(os.path.join(path, file_name)) - storage_state_dict = loaded_state_dict[file_name] - assert item.meta_index.param in storage_state_dict - storage_local_tensor = storage_state_dict[item.meta_index.param] + if file_name not in storage_file_to_state_dict: + # The value in state_dict is not distributed tensor but a normal tensor. + storage_file_to_state_dict[file_name] = paddle.load(os.path.join(path, file_name)) + storage_state_dict = storage_file_to_state_dict[file_name] + assert item.local_tensor_index.tensor_id in storage_state_dict + storage_local_tensor = storage_state_dict[item.local_tensor_index.tensor_id] storage_offsets = item.storage_offset storage_lengths = item.lengths storage_ends = [storage_offset + storage_length for storage_offset, storage_length in zip(storage_offsets, storage_lengths)] + # The storage_chunk_tensor and storage_local_tensor share the same memory. 
storage_chunk_tensor = paddle.slice(storage_local_tensor, list(range(len(storage_lengths))), storage_offsets, storage_ends) # The read item rank need to be assigned if item.rank == paddle.distributed.get_rank(): - assert item.meta_index.param in state_dict, f"item:{item}, state_dict:{state_dict}" - cur_local_tensor = state_dict[item.meta_index.param]._local_value() + assert item.local_tensor_index.tensor_id in state_dict, f"item:{item}, state_dict:{state_dict}" + cur_local_tensor = state_dict[item.local_tensor_index.tensor_id]._local_value() cur_offsets = item.cur_offset cur_lengths = item.lengths cur_ends = [cur_offset + cur_length for cur_offset, cur_length in zip(cur_offsets, cur_lengths)] - cur_sub_chunk_tensor = paddle.slice(cur_local_tensor, list(range(len(cur_lengths))), cur_offsets, cur_ends) + # The cur_chunk_tensor and cur_local_tensor share the same memory. + cur_chunk_tensor = paddle.slice(cur_local_tensor, list(range(len(cur_lengths))), cur_offsets, cur_ends) else: - cur_sub_chunk_tensor = paddle.zeros(item.lengths, dtype=state_dict[item.meta_index.param].dtype) + cur_chunk_tensor = paddle.zeros(item.lengths, dtype=state_dict[item.local_tensor_index.tensor_id].dtype) if src_rank == item.rank: # assign value locally - paddle.assign(storage_chunk_tensor, cur_sub_chunk_tensor) + paddle.assign(storage_chunk_tensor, cur_chunk_tensor) else: # assign value remotely if src_rank == paddle.distributed.get_rank(): paddle.distributed.broadcast(storage_chunk_tensor, src=src_rank, group=process_group) else: - paddle.distributed.broadcast(cur_sub_chunk_tensor, src=src_rank, group=process_group) + paddle.distributed.broadcast(cur_chunk_tensor, src=src_rank, group=process_group) local_state_dict = { k:v._local_value() for k, v in state_dict.items()} - print(f"after load, local_state_dict:{local_state_dict} \n state_dict:{state_dict}") - - -def test_get_local_load_files(): - path = "./output" - # build state_dict - import paddle.distributed as dist - w1 = paddle.zeros([4,2], dtype=paddle.int64) - w2 = paddle.zeros([2,2], dtype=paddle.int64) - mesh = dist.ProcessMesh([0,1,2,3]) - sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) - sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Replicate(), dist.Replicate()]) - state_dict = {"w1": sharded_w1, "w2": sharded_w2} - load_state_dict(state_dict, path) - - - - -def test_load_state_dict(): - test_get_local_load_files() - -if __name__ == "__main__": - test_load_state_dict() + logger.info(f"after load, local_state_dict:{local_state_dict} \n state_dict:{state_dict}") diff --git a/python/paddle/distributed/checkpoint/metadata.py b/python/paddle/distributed/checkpoint/metadata.py index 7456899f30cbcc..96e238fd8e8de7 100644 --- a/python/paddle/distributed/checkpoint/metadata.py +++ b/python/paddle/distributed/checkpoint/metadata.py @@ -12,22 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Dict, Optional +from typing import List, Tuple, Dict from dataclasses import dataclass @dataclass -class ChunkMetadata: - local_shape: Tuple[int] +class LocalTensorMetadata: + """ + The location of a local tensor in the global tensor. + """ global_offset: Tuple[int] + local_shape: Tuple[int] @dataclass(frozen=True) -class MetadataIndex: - param: str +class LocalTensorIndex: + """ + The identifier of a local tensor. 
+ """ + tensor_id: str global_offset: Tuple[int] @dataclass class Metadata: - state_dict_metadata: Dict[str, List[ChunkMetadata]] = None - storage_metadata: Dict[MetadataIndex, str] = None \ No newline at end of file + state_dict_metadata: Dict[str, List[LocalTensorMetadata]] = None + storage_metadata: Dict[LocalTensorIndex, str] = None \ No newline at end of file diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index dba1e9f1e21802..1acb57ec398e30 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -18,8 +18,9 @@ import paddle from paddle.distributed.communication.group import is_initialized -from .metadata import Metadata, ChunkMetadata, MetadataIndex -from .utils import compute_local_shape_and_global_offset +from paddle.distributed.fleet.utils.log_util import logger +from .metadata import Metadata, LocalTensorMetadata, LocalTensorIndex +from .utils import compute_local_shape_and_global_offset, flatten_state_dict def check_state_dict(state_dict, process_group): local_keys = list(state_dict.keys()) @@ -76,6 +77,7 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us """ assert isinstance(state_dict, dict), "The state_dict should be a dictionary." + state_dict = flatten_state_dict(state_dict) if len(state_dict) > 0: for val in state_dict.values(): assert isinstance(val, (paddle.Tensor, paddle.base.framework.EagerParamBase)), "Only support dygraph Tensor now, support static DistributedTensor later" @@ -87,6 +89,7 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us # Init the default global process group not is_initialized() and paddle.distributed.init_parallel_env() + unique_id = 0 file_name = "" while(True): @@ -94,13 +97,13 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us if not os.path.exists(os.path.join(path, file_name)): break unique_id += 1 - print(f"file_name:{file_name}") + logger.info(f"file_name:{file_name}") check_file_name(file_name, process_group) # the parameter_name and order in state_dict should be the same check_state_dict(state_dict, process_group) local_state_dict = {} metadata = Metadata() - local_chunk_metadata = {} + local_tensor_metadata = {} local_storage_metadata = {} for key, val in state_dict.items(): if isinstance(val, paddle.Tensor): @@ -111,38 +114,22 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) if not local_shape or not global_offset: continue - local_chunk_metadata[key] = ChunkMetadata(local_shape, global_offset) - local_storage_metadata[MetadataIndex(key, tuple(global_offset))] = file_name + local_tensor_metadata[key] = LocalTensorMetadata(global_offset, local_shape) + local_storage_metadata[LocalTensorIndex(key, tuple(global_offset))] = file_name local_tensor = val._local_value() else: local_tensor = val local_state_dict[key] = local_tensor - global_chunk_metadata = [] + global_tensor_metadata = [] global_storage_metadata = [] - paddle.distributed.all_gather_object(global_chunk_metadata, local_chunk_metadata, process_group) + paddle.distributed.all_gather_object(global_tensor_metadata, local_tensor_metadata, process_group) paddle.distributed.all_gather_object(global_storage_metadata, local_storage_metadata, process_group) - 
metadata.state_dict_metadata = merge_state_dict(global_chunk_metadata) + metadata.state_dict_metadata = merge_state_dict(global_tensor_metadata) metadata.storage_metadata = dedup_state_dict(global_storage_metadata) if coordinator_rank == paddle.distributed.get_rank(): - print(f"global_chunk_metadata:{global_chunk_metadata}") - print(f"global_storage_metadata:{global_storage_metadata}") - print(f"metadata:{metadata}") + logger.info(f"global_tensor_metadata:{global_tensor_metadata}") + logger.info(f"global_storage_metadata:{global_storage_metadata}") + logger.info(f"metadata:{metadata}") paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata")) - print(f"local_state_dict:{local_state_dict}") + logger.info(f"local_state_dict:{local_state_dict}") paddle.save(local_state_dict, os.path.join(path, file_name)) - - - -def test_save_state_dict(): - import paddle.distributed as dist - w1 = paddle.arange(8).reshape([4, 2]) - w2 = paddle.arange(8, 12).reshape([2, 2]) - mesh = dist.ProcessMesh([0,1]) - mesh2 = dist.ProcessMesh([2,3]) - sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) - sharded_w2 = dist.shard_tensor(w2, mesh2, [dist.Shard(0), dist.Replicate()]) - state_dict = {"w1": sharded_w1, "w2": sharded_w2} - save_state_dict(state_dict, "./output") - -if __name__ == "__main__": - test_save_state_dict() diff --git a/python/paddle/distributed/checkpoint/utils.py b/python/paddle/distributed/checkpoint/utils.py index 25b9df74d9cd01..aa8ec0d113b0c0 100644 --- a/python/paddle/distributed/checkpoint/utils.py +++ b/python/paddle/distributed/checkpoint/utils.py @@ -1,3 +1,9 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software @@ -8,13 +14,13 @@ from typing import Tuple import copy -from typing import List, Optional +from typing import List, Union import numpy as np import paddle from paddle.framework import core -def get_coordinator(mesh:np.array, rank:int): +def get_coordinator(mesh:Union[np.array, List[List[int]]], rank:int): mesh = paddle.to_tensor(mesh) rand_coordinator = (mesh == rank).nonzero() assert rand_coordinator.shape[0] in (0, 1), f"rand_coordinator.shape: {rand_coordinator.shape}" @@ -22,10 +28,6 @@ def get_coordinator(mesh:np.array, rank:int): def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:core.ProcessMesh, dims_mapping:List[int]) -> Tuple[Tuple[int], Tuple[int]]: - """ - tensor dist_attr look like: {process_mesh: {shape: [2], process_ids: [0,1], dim_names: [x]}, dims_mapping: [-1,0], batch_dim: 0, dynamic_dims: [], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].} - the tensor dims=2, dims_mapping means the dim0 is replicate, dim1 is shard by dim0 of process_mesh - """ mesh = np.array(process_mesh.process_ids).reshape(process_mesh.shape) # deal with cross mesh case if paddle.distributed.get_rank() not in mesh: @@ -33,7 +35,6 @@ def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:c rank_coordinator = get_coordinator(mesh, paddle.distributed.get_rank()) local_shape = copy.copy(global_shape) global_offset = [0 for _ in global_shape] - # print(f"rank_coordinator:{rank_coordinator}") for i, dim in enumerate(dims_mapping): if dim == -1: continue @@ -45,17 +46,6 @@ def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:c return tuple(local_shape), tuple(global_offset) -def main_test(): - import paddle.distributed as dist - - tensor = paddle.arange(8).reshape([4, 2]) - global_shape = tensor.shape - mesh = dist.ProcessMesh([[0,1], [2,3]], dim_names=["x", "y"]) - dist_attr = dist.DistAttr(mesh, sharding_specs=["x", "y"]) - sharded_tensor = dist.shard_tensor(tensor, dist_attr=dist_attr) - print(f"get_tensor:{sharded_tensor.get_tensor().get_tensor()}, sharded_tensor.dist_attr:{sharded_tensor.dist_attr}") - local_shape, global_offset = compute_local_shape_and_global_offset(global_shape, sharded_tensor.dist_attr.process_mesh, sharded_tensor.dist_attr.dims_mapping) - print(f"local_shape:{local_shape}, global_offset: {global_offset}") - -if __name__ == "__main__": - main_test() +def flatten_state_dict(state_dict): + # TODO, {"model": {"w0": xxx}} -> {model.w0: xxx} + return state_dict From ecee68bab26f08b2f0c29cbe54f53bf34c3b905c Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Fri, 1 Dec 2023 17:07:46 +0800 Subject: [PATCH 09/24] polish --- .../distributed/checkpoint/broadcast_test.py | 2 - .../distributed/checkpoint/hang_test.py | 53 ------------------- .../distributed/checkpoint/test_hang_test.py | 3 -- .../checkpoint/test_load_state_dict.sh | 3 -- .../checkpoint/test_save_state_dict.sh | 2 - .../distributed/checkpoint/test_utils.sh | 2 - 6 files changed, 65 deletions(-) delete mode 100644 python/paddle/distributed/checkpoint/broadcast_test.py delete mode 100644 python/paddle/distributed/checkpoint/hang_test.py delete mode 100644 python/paddle/distributed/checkpoint/test_hang_test.py delete mode 100644 python/paddle/distributed/checkpoint/test_load_state_dict.sh delete mode 100644 python/paddle/distributed/checkpoint/test_save_state_dict.sh delete mode 
100644 python/paddle/distributed/checkpoint/test_utils.sh diff --git a/python/paddle/distributed/checkpoint/broadcast_test.py b/python/paddle/distributed/checkpoint/broadcast_test.py deleted file mode 100644 index de5bd9f7d44912..00000000000000 --- a/python/paddle/distributed/checkpoint/broadcast_test.py +++ /dev/null @@ -1,2 +0,0 @@ -import paddle - diff --git a/python/paddle/distributed/checkpoint/hang_test.py b/python/paddle/distributed/checkpoint/hang_test.py deleted file mode 100644 index 4f7fbd3f7bce22..00000000000000 --- a/python/paddle/distributed/checkpoint/hang_test.py +++ /dev/null @@ -1,53 +0,0 @@ -import pysnooper - -import paddle -from paddle.distributed.communication.group import is_initialized - -@pysnooper.snoop(output=f"snooper{paddle.distributed.get_rank()}.log", depth=1, max_variable_length=200) -def get_read_items(path, state_dict, process_group): - print(f"pure hang test", flush=True) - # for param_name, val in state_dict.items(): - for param_name, val in enumerate(range(2)): - if True or isinstance(val, paddle.Tensor): - print(f"before val:{val}, type:{type(val)}", flush=True) - if True or val.is_dist(): - paddle.distributed.barrier() - # pass - # local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) - # cur_chunk_metadata = ChunkMetadata(local_shape, global_offset) - # assert param_name in param_to_chunkmetadata, f"param_name:{param_name} not found in param_to_chunkmetadata:{param_to_chunkmetadata}." - # for storage_chunk_metadata in param_to_chunkmetadata[param_name]: - for storage_chunk_metadata in range(2): - print(f"rank:{paddle.distributed.get_rank()}, storage_chunk_metadata:{storage_chunk_metadata}", flush=True) - # paddle.distributed.barrier() - print(f"param_name:{param_name}, storage_chunk_metadata:{storage_chunk_metadata}") - if paddle.distributed.get_rank() == 0 or paddle.distributed.get_rank() == 1: - continue - else: - continue - else: - print(f"val:{val}, type:{type(val)}") - pass - else: - pass - return - -def main(): - path = "./output" - ###!!! Init the Disttensor and turn on the pysnooper at the same time will lead to hang !!! 
- - # import paddle.distributed as dist - # w1 = paddle.arange(8).reshape([4, 2]) - # w2 = paddle.arange(8, 12).reshape([2, 2]) - # mesh = dist.ProcessMesh([0,1,2,3], dim_names=["x"]) - # w1_dist_attr = dist.DistAttr(mesh, sharding_specs=["x", None]) - # sharded_w1 = dist.shard_tensor(w1, dist_attr=w1_dist_attr) - # w2_dist_attr = dist.DistAttr(mesh, sharding_specs=[None, None]) - # sharded_w2 = dist.shard_tensor(w2, dist_attr=w2_dist_attr) - # state_dict = {"w1": sharded_w1, "w2": sharded_w2} - - not is_initialized() and paddle.distributed.init_parallel_env() - get_read_items(path, None, None) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/python/paddle/distributed/checkpoint/test_hang_test.py b/python/paddle/distributed/checkpoint/test_hang_test.py deleted file mode 100644 index 07556ce8e2afd5..00000000000000 --- a/python/paddle/distributed/checkpoint/test_hang_test.py +++ /dev/null @@ -1,3 +0,0 @@ -rm -rf log/* -rm -f snooper* -python -u -m paddle.distributed.launch --log_dir ./log --devices 0,1,2,3 test.py diff --git a/python/paddle/distributed/checkpoint/test_load_state_dict.sh b/python/paddle/distributed/checkpoint/test_load_state_dict.sh deleted file mode 100644 index 100e0e66500459..00000000000000 --- a/python/paddle/distributed/checkpoint/test_load_state_dict.sh +++ /dev/null @@ -1,3 +0,0 @@ -rm -rf log/* -rm -f snooper* -python -u -m paddle.distributed.launch --log_dir ./log --devices 1,2,3,4 load_state_dict.py diff --git a/python/paddle/distributed/checkpoint/test_save_state_dict.sh b/python/paddle/distributed/checkpoint/test_save_state_dict.sh deleted file mode 100644 index a211bc9299e199..00000000000000 --- a/python/paddle/distributed/checkpoint/test_save_state_dict.sh +++ /dev/null @@ -1,2 +0,0 @@ -rm -rf log/* -python -u -m paddle.distributed.launch --log_dir ./log --devices 0,1,2,3 save_state_dict.py diff --git a/python/paddle/distributed/checkpoint/test_utils.sh b/python/paddle/distributed/checkpoint/test_utils.sh deleted file mode 100644 index 03ce3c88325122..00000000000000 --- a/python/paddle/distributed/checkpoint/test_utils.sh +++ /dev/null @@ -1,2 +0,0 @@ -rm -rf log/* -python -u -m paddle.distributed.launch --log_dir ./log --devices 0,1,2,3 utils.py From a8491b935f9d2e64e59f680f816b5e9f97d9e4ac Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Fri, 1 Dec 2023 17:23:40 +0800 Subject: [PATCH 10/24] format codes --- .../paddle/distributed/checkpoint/__init__.py | 2 +- .../distributed/checkpoint/load_state_dict.py | 266 +++++++++++++----- .../paddle/distributed/checkpoint/metadata.py | 9 +- .../distributed/checkpoint/save_state_dict.py | 71 +++-- python/paddle/distributed/checkpoint/utils.py | 29 +- 5 files changed, 282 insertions(+), 95 deletions(-) diff --git a/python/paddle/distributed/checkpoint/__init__.py b/python/paddle/distributed/checkpoint/__init__.py index 7de3d719cd6d1c..63a317bd0a4b7d 100644 --- a/python/paddle/distributed/checkpoint/__init__.py +++ b/python/paddle/distributed/checkpoint/__init__.py @@ -18,4 +18,4 @@ __all__ = [ "save_state_dict", "load_state_dict", -] \ No newline at end of file +] diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index fc518f7635612e..4dcd60dc88bd11 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -14,29 +14,34 @@ import os from dataclasses import dataclass -from typing import List, Tuple +from typing import Tuple import paddle 
from paddle.distributed.communication.group import is_initialized from paddle.distributed.fleet.utils.log_util import logger -from .metadata import Metadata, LocalTensorMetadata, LocalTensorIndex +from .metadata import LocalTensorIndex, LocalTensorMetadata from .utils import compute_local_shape_and_global_offset, flatten_state_dict + @dataclass(frozen=True) class ReadItem: - local_tensor_index:LocalTensorIndex - rank:int - cur_offset:Tuple[int] - storage_offset:Tuple[int] - lengths:Tuple[int] + local_tensor_index: LocalTensorIndex + rank: int + cur_offset: Tuple[int] + storage_offset: Tuple[int] + lengths: Tuple[int] def get_rank_to_files(path, state_dict, process_group): # step 1, get neccesary files to be read accessible_files = os.listdir(path) - metadata_files = [file for file in accessible_files if file.endswith(".metadata")] - assert len(metadata_files) > 0, "No metadata file found in the checkpoint directory:{path}." + metadata_files = [ + file for file in accessible_files if file.endswith(".metadata") + ] + assert ( + len(metadata_files) > 0 + ), "No metadata file found in the checkpoint directory:{path}." # The neccesary files to be read tensor_id_list = [] necessary_files = [] @@ -48,27 +53,38 @@ def get_rank_to_files(path, state_dict, process_group): necessary_files.append(file_name) necessary_data_files_set = set(necessary_files) # allgather all accessible files - local_data_files = [file for file in accessible_files if file.endswith(".distcp")] + local_data_files = [ + file for file in accessible_files if file.endswith(".distcp") + ] global_data_files = [] - paddle.distributed.all_gather_object(global_data_files, local_data_files, process_group) + paddle.distributed.all_gather_object( + global_data_files, local_data_files, process_group + ) tmp = [] for files in global_data_files: tmp += files global_data_files_set = set(tmp) - logger.info(f"necessary_data_files_set:{necessary_data_files_set}, global_data_files_set:{global_data_files_set}") + logger.info( + f"necessary_data_files_set:{necessary_data_files_set}, global_data_files_set:{global_data_files_set}" + ) # check neccesary files in global_data_files - assert global_data_files_set & necessary_data_files_set == necessary_data_files_set, \ - f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{necessary_data_files_set}" + assert ( + global_data_files_set & necessary_data_files_set + == necessary_data_files_set + ), f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{necessary_data_files_set}" missing_keys = set(state_dict.keys()) - set(tensor_id_list) logger.info(f"missing_keys:{missing_keys}") # step 2, get mapping between ranks and local files rank_to_files = {} for rank, local_files in enumerate(global_data_files): if len(local_files) > 0: - local_files = [f for f in local_files if f in necessary_data_files_set] + local_files = [ + f for f in local_files if f in necessary_data_files_set + ] rank_to_files[rank] = local_files logger.info(f"mapping rank_to_files:{rank_to_files}") + def get_local_load_files(rank_to_files): """ Load files in a load-balanced manner. 
@@ -89,7 +105,7 @@ def get_local_load_files(rank_to_files): if file not in file_to_ranks: file_to_ranks[file] = [] file_to_ranks[file].append(rank) - rank_to_read_files = {rank:[] for rank in rank_to_files.keys()} + rank_to_read_files = {rank: [] for rank in rank_to_files.keys()} for file, ranks in file_to_ranks.items(): if len(ranks) == 1: rank = ranks[0] @@ -97,20 +113,31 @@ def get_local_load_files(rank_to_files): rank_to_files[rank].remove(file) if len(rank_to_files[rank]) == 0: rank_to_files.pop(rank) - - logger.info(f"start rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}") + + logger.info( + f"start rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}" + ) + def get_least_read_files_ranks(rank_to_read_files): - nums = [(rank, len(files)) for rank, files in rank_to_read_files.items()] + nums = [ + (rank, len(files)) for rank, files in rank_to_read_files.items() + ] nums = sorted(nums, key=lambda x: x[1]) ranks = [rank for rank, num in nums if num == nums[0][1]] return ranks + def get_read_rank_file(rank_to_files, ranks): if len(rank_to_files) == 0: return (None, None) - nums = [(rank, len(files)) for rank, files in rank_to_files.items() if rank in ranks] + nums = [ + (rank, len(files)) + for rank, files in rank_to_files.items() + if rank in ranks + ] nums = sorted(nums, key=lambda x: x[1]) rank = nums[0][0] return (rank, rank_to_files[rank][0]) + def update(rank_to_read_files, rank_to_files, rank_file): rank, file = rank_file if rank is None and file is None: @@ -136,11 +163,15 @@ def update(rank_to_read_files, rank_to_files, rank_file): ranks = get_least_read_files_ranks(rank_to_read_files) rank_file = get_read_rank_file(rank_to_files, ranks) update(rank_to_read_files, rank_to_files, rank_file) - logger.info(f"update rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}, ranks:{ranks}, rank_file:{rank_file}") + logger.info( + f"update rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}, ranks:{ranks}, rank_file:{rank_file}" + ) logger.info(f"rank_to_read_files:{rank_to_read_files}") cur_rank = paddle.distributed.get_rank() if cur_rank in rank_to_read_files: - logger.info(f"cur_rank:{cur_rank}, rank_to_read_files[cur_rank]:{rank_to_read_files[cur_rank]}") + logger.info( + f"cur_rank:{cur_rank}, rank_to_read_files[cur_rank]:{rank_to_read_files[cur_rank]}" + ) return rank_to_read_files[cur_rank] else: logger.info(f"rank:{cur_rank} does not need to load checkpoint") @@ -150,15 +181,24 @@ def update(rank_to_read_files, rank_to_files, rank_file): def get_load_infos(path, local_load_files, process_group): load_info = {} accessible_files = os.listdir(path) - metadata_files = [file for file in accessible_files if file.endswith(".metadata")] - assert len(metadata_files) > 0, "No metadata file found in the checkpoint directory:{path}." + metadata_files = [ + file for file in accessible_files if file.endswith(".metadata") + ] + assert ( + len(metadata_files) > 0 + ), "No metadata file found in the checkpoint directory:{path}." 
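The file assignment reformatted above is a greedy balancing pass: files that only one rank can see go to that rank first, and every remaining file is then handed to whichever candidate rank currently reads the fewest files. A simplified standalone sketch of the same balancing idea (plain dicts, illustrative names, not a line-for-line port of the patch):

    def assign_files_to_ranks(rank_to_files):
        """Greedy, load-balanced assignment of checkpoint files to reader ranks."""
        # Invert the mapping: which ranks can see each file.
        file_to_ranks = {}
        for rank, files in rank_to_files.items():
            for f in files:
                file_to_ranks.setdefault(f, []).append(rank)

        assignment = {rank: [] for rank in rank_to_files}
        shared = []
        for f, ranks in file_to_ranks.items():
            if len(ranks) == 1:
                assignment[ranks[0]].append(f)  # only one rank can read this file
            else:
                shared.append((f, ranks))

        # Give each remaining file to the candidate rank that currently reads the fewest files.
        for f, ranks in shared:
            rank = min(ranks, key=lambda r: len(assignment[r]))
            assignment[rank].append(f)
        return assignment

    print(assign_files_to_ranks({0: ["0_0.distcp", "1_0.distcp"], 1: ["1_0.distcp"]}))
    # {0: ['0_0.distcp'], 1: ['1_0.distcp']}
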
for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) for local_tensor_index, file_name in metadata.storage_metadata.items(): if file_name in local_load_files: - load_info[local_tensor_index] = (paddle.distributed.get_rank(), file_name) + load_info[local_tensor_index] = ( + paddle.distributed.get_rank(), + file_name, + ) load_info_list = [] - paddle.distributed.all_gather_object(load_info_list, load_info, process_group) + paddle.distributed.all_gather_object( + load_info_list, load_info, process_group + ) load_infos = {} for load_info in load_info_list: for local_tensor_index, (rank, file_name) in load_info.items(): @@ -167,7 +207,10 @@ def get_load_infos(path, local_load_files, process_group): return load_infos -def compute_overlap(cur_chunk_metadata:LocalTensorMetadata, storage_local_tensor_metadata:LocalTensorMetadata): +def compute_overlap( + cur_chunk_metadata: LocalTensorMetadata, + storage_local_tensor_metadata: LocalTensorMetadata, +): cur_offsets = [] storage_offsets = [] lengths = [] @@ -175,7 +218,7 @@ def compute_overlap(cur_chunk_metadata:LocalTensorMetadata, storage_local_tensor cur_chunk_metadata.local_shape, cur_chunk_metadata.global_offset, storage_local_tensor_metadata.local_shape, - storage_local_tensor_metadata.global_offset + storage_local_tensor_metadata.global_offset, ): begin_offset = max(cur_offset, storage_offset) end_offset = min(cur_offset + cur_len, storage_offset + strorage_len) @@ -186,31 +229,49 @@ def compute_overlap(cur_chunk_metadata:LocalTensorMetadata, storage_local_tensor cur_offsets.append(begin_offset - cur_offset) storage_offsets.append(0) else: - assert False, "Should not reach here." + raise ValueError( + f"Invalid begin_offset:{begin_offset}, cur_offset:{cur_offset}, storage_offset:{storage_offset}" + ) lengths.append(end_offset - begin_offset) - assert lengths[-1] >= 0, f"Invalid length:{lengths[-1]}, end_offset:{end_offset}, begin_offset:{begin_offset}" + assert ( + lengths[-1] >= 0 + ), f"Invalid length:{lengths[-1]}, end_offset:{end_offset}, begin_offset:{begin_offset}" return cur_offsets, storage_offsets, lengths -def not_overlap(cur_chunk_metadata:LocalTensorMetadata, storage_local_tensor_metadata:LocalTensorMetadata): +def not_overlap( + cur_chunk_metadata: LocalTensorMetadata, + storage_local_tensor_metadata: LocalTensorMetadata, +): for cur_len, cur_offset, strorage_len, storage_offset in zip( cur_chunk_metadata.local_shape, cur_chunk_metadata.global_offset, storage_local_tensor_metadata.local_shape, - storage_local_tensor_metadata.global_offset + storage_local_tensor_metadata.global_offset, ): - if cur_offset >= (storage_offset + strorage_len) or (cur_offset + cur_len) <= storage_offset: + if ( + cur_offset >= (storage_offset + strorage_len) + or (cur_offset + cur_len) <= storage_offset + ): return True return False + def get_read_items(path, state_dict, process_group): accessible_files = os.listdir(path) - metadata_files = [file for file in accessible_files if file.endswith(".metadata")] - assert len(metadata_files) > 0, "No metadata file found in the checkpoint directory:{path}." + metadata_files = [ + file for file in accessible_files if file.endswith(".metadata") + ] + assert ( + len(metadata_files) > 0 + ), "No metadata file found in the checkpoint directory:{path}." 
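compute_overlap and not_overlap above work per dimension on half-open index ranges: two chunks intersect only if, in every dimension, max(start) < min(end). A dependency-free sketch with a worked example (illustrative names; the offset arithmetic is a simplified but equivalent form of the branches in compute_overlap):

    def chunks_overlap(shape_a, offset_a, shape_b, offset_b):
        """True if the two local chunks intersect in every dimension."""
        return all(
            max(oa, ob) < min(oa + la, ob + lb)
            for la, oa, lb, ob in zip(shape_a, offset_a, shape_b, offset_b)
        )

    def overlap_region(cur_shape, cur_offset, st_shape, st_offset):
        """Per-chunk start offsets and lengths of the intersection, per dimension."""
        cur_offsets, st_offsets, lengths = [], [], []
        for cl, co, sl, so in zip(cur_shape, cur_offset, st_shape, st_offset):
            begin = max(co, so)
            end = min(co + cl, so + sl)
            cur_offsets.append(begin - co)  # start of the region inside the current chunk
            st_offsets.append(begin - so)   # start of the region inside the stored chunk
            lengths.append(end - begin)
        return cur_offsets, st_offsets, lengths

    # A [0:2, 0:8] shard of a (4, 8) tensor vs. a stored [1:3, 4:8] chunk:
    assert chunks_overlap((2, 8), (0, 0), (2, 4), (1, 4))
    print(overlap_region((2, 8), (0, 0), (2, 4), (1, 4)))
    # ([1, 4], [0, 0], [1, 4])
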
storage_state_dict_metadata = {} for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) - for tensor_id, local_tensor_metadata in metadata.state_dict_metadata.items(): + for ( + tensor_id, + local_tensor_metadata, + ) in metadata.state_dict_metadata.items(): if tensor_id not in storage_state_dict_metadata: storage_state_dict_metadata[tensor_id] = [] storage_state_dict_metadata[tensor_id] += local_tensor_metadata @@ -219,21 +280,53 @@ def get_read_items(path, state_dict, process_group): for tensor_id, val in state_dict.items(): if isinstance(val, paddle.Tensor): if val.is_dist(): - local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) + ( + local_shape, + global_offset, + ) = compute_local_shape_and_global_offset( + val.shape, + val.dist_attr.process_mesh, + val.dist_attr.dims_mapping, + ) if not local_shape or not global_offset: continue - cur_chunk_metadata = LocalTensorMetadata(global_offset, local_shape) - assert tensor_id in storage_state_dict_metadata, f"tensor_id:{tensor_id} not found in storage_state_dict_metadata:{storage_state_dict_metadata}." - for storage_local_tensor_metadata in storage_state_dict_metadata[tensor_id]: - if not_overlap(cur_chunk_metadata, storage_local_tensor_metadata): + cur_chunk_metadata = LocalTensorMetadata( + global_offset, local_shape + ) + assert ( + tensor_id in storage_state_dict_metadata + ), f"tensor_id:{tensor_id} not found in storage_state_dict_metadata:{storage_state_dict_metadata}." + for ( + storage_local_tensor_metadata + ) in storage_state_dict_metadata[tensor_id]: + if not_overlap( + cur_chunk_metadata, storage_local_tensor_metadata + ): continue - cur_offsets, storage_offsets, lengths = compute_overlap(cur_chunk_metadata, storage_local_tensor_metadata) - storage_local_tensor_index = LocalTensorIndex(tensor_id, tuple(storage_local_tensor_metadata.global_offset)) - read_items.append(ReadItem(storage_local_tensor_index, paddle.distributed.get_rank(), tuple(cur_offsets), tuple(storage_offsets), tuple(lengths))) + cur_offsets, storage_offsets, lengths = compute_overlap( + cur_chunk_metadata, storage_local_tensor_metadata + ) + storage_local_tensor_index = LocalTensorIndex( + tensor_id, + tuple(storage_local_tensor_metadata.global_offset), + ) + read_items.append( + ReadItem( + storage_local_tensor_index, + paddle.distributed.get_rank(), + tuple(cur_offsets), + tuple(storage_offsets), + tuple(lengths), + ) + ) else: - assert False, f"Only support distributed tensor., val type:{type(val)}" + raise ValueError( + f"Only support distributed tensor., val type:{type(val)}" + ) else: - assert False, f"Only support paddle.Tensor., val type:{type(val)}" + raise ValueError( + f"Only support paddle.Tensor., val type:{type(val)}" + ) global_read_items = [] tmp = [] paddle.distributed.all_gather_object(tmp, read_items, process_group) @@ -243,7 +336,9 @@ def get_read_items(path, state_dict, process_group): return global_read_items -def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, use_dist=True) -> None: +def load_state_dict( + state_dict, path, process_group=None, coordinator_rank=0, use_dist=True +) -> None: """ Load the state_dict inplace from a checkpoint path. Args: @@ -257,11 +352,15 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us import paddle ... """ - assert isinstance(state_dict, dict), "The state_dict should be a dictionary." 
+ assert isinstance( + state_dict, dict + ), "The state_dict should be a dictionary." state_dict = flatten_state_dict(state_dict) if len(state_dict) > 0: for val in state_dict.values(): - assert isinstance(val, (paddle.Tensor, paddle.base.framework.EagerParamBase)), "Only support dygraph Tensor now, support static DistributedTensor later" + assert isinstance( + val, (paddle.Tensor, paddle.base.framework.EagerParamBase) + ), "Only support dygraph Tensor now, support static DistributedTensor later" if process_group is None: # Init the default global process group @@ -275,9 +374,13 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank. read_items = get_read_items(path, state_dict, process_group) storage_file_to_state_dict = {} - logger.info(f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}") + logger.info( + f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}" + ) for item in read_items: - assert item.local_tensor_index in load_infos, f"item:{item}, load_infos:{load_infos}" + assert ( + item.local_tensor_index in load_infos + ), f"item:{item}, load_infos:{load_infos}" src_rank, file_name = load_infos[item.local_tensor_index] storage_chunk_tensor = None cur_chunk_tensor = None @@ -285,26 +388,55 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us if src_rank == paddle.distributed.get_rank(): if file_name not in storage_file_to_state_dict: # The value in state_dict is not distributed tensor but a normal tensor. - storage_file_to_state_dict[file_name] = paddle.load(os.path.join(path, file_name)) + storage_file_to_state_dict[file_name] = paddle.load( + os.path.join(path, file_name) + ) storage_state_dict = storage_file_to_state_dict[file_name] assert item.local_tensor_index.tensor_id in storage_state_dict - storage_local_tensor = storage_state_dict[item.local_tensor_index.tensor_id] + storage_local_tensor = storage_state_dict[ + item.local_tensor_index.tensor_id + ] storage_offsets = item.storage_offset storage_lengths = item.lengths - storage_ends = [storage_offset + storage_length for storage_offset, storage_length in zip(storage_offsets, storage_lengths)] + storage_ends = [ + storage_offset + storage_length + for storage_offset, storage_length in zip( + storage_offsets, storage_lengths + ) + ] # The storage_chunk_tensor and storage_local_tensor share the same memory. 
- storage_chunk_tensor = paddle.slice(storage_local_tensor, list(range(len(storage_lengths))), storage_offsets, storage_ends) + storage_chunk_tensor = paddle.slice( + storage_local_tensor, + list(range(len(storage_lengths))), + storage_offsets, + storage_ends, + ) # The read item rank need to be assigned if item.rank == paddle.distributed.get_rank(): - assert item.local_tensor_index.tensor_id in state_dict, f"item:{item}, state_dict:{state_dict}" - cur_local_tensor = state_dict[item.local_tensor_index.tensor_id]._local_value() + assert ( + item.local_tensor_index.tensor_id in state_dict + ), f"item:{item}, state_dict:{state_dict}" + cur_local_tensor = state_dict[ + item.local_tensor_index.tensor_id + ]._local_value() cur_offsets = item.cur_offset cur_lengths = item.lengths - cur_ends = [cur_offset + cur_length for cur_offset, cur_length in zip(cur_offsets, cur_lengths)] + cur_ends = [ + cur_offset + cur_length + for cur_offset, cur_length in zip(cur_offsets, cur_lengths) + ] # The cur_chunk_tensor and cur_local_tensor share the same memory. - cur_chunk_tensor = paddle.slice(cur_local_tensor, list(range(len(cur_lengths))), cur_offsets, cur_ends) + cur_chunk_tensor = paddle.slice( + cur_local_tensor, + list(range(len(cur_lengths))), + cur_offsets, + cur_ends, + ) else: - cur_chunk_tensor = paddle.zeros(item.lengths, dtype=state_dict[item.local_tensor_index.tensor_id].dtype) + cur_chunk_tensor = paddle.zeros( + item.lengths, + dtype=state_dict[item.local_tensor_index.tensor_id].dtype, + ) if src_rank == item.rank: # assign value locally @@ -312,9 +444,15 @@ def load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us else: # assign value remotely if src_rank == paddle.distributed.get_rank(): - paddle.distributed.broadcast(storage_chunk_tensor, src=src_rank, group=process_group) + paddle.distributed.broadcast( + storage_chunk_tensor, src=src_rank, group=process_group + ) else: - paddle.distributed.broadcast(cur_chunk_tensor, src=src_rank, group=process_group) + paddle.distributed.broadcast( + cur_chunk_tensor, src=src_rank, group=process_group + ) - local_state_dict = { k:v._local_value() for k, v in state_dict.items()} - logger.info(f"after load, local_state_dict:{local_state_dict} \n state_dict:{state_dict}") + local_state_dict = {k: v._local_value() for k, v in state_dict.items()} + logger.info( + f"after load, local_state_dict:{local_state_dict} \n state_dict:{state_dict}" + ) diff --git a/python/paddle/distributed/checkpoint/metadata.py b/python/paddle/distributed/checkpoint/metadata.py index 96e238fd8e8de7..74a17a0a63ada8 100644 --- a/python/paddle/distributed/checkpoint/metadata.py +++ b/python/paddle/distributed/checkpoint/metadata.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Dict from dataclasses import dataclass - +from typing import Dict, List, Tuple @dataclass @@ -22,18 +21,22 @@ class LocalTensorMetadata: """ The location of a local tensor in the global tensor. """ + global_offset: Tuple[int] local_shape: Tuple[int] + @dataclass(frozen=True) class LocalTensorIndex: """ The identifier of a local tensor. 
""" + tensor_id: str global_offset: Tuple[int] + @dataclass class Metadata: state_dict_metadata: Dict[str, List[LocalTensorMetadata]] = None - storage_metadata: Dict[LocalTensorIndex, str] = None \ No newline at end of file + storage_metadata: Dict[LocalTensorIndex, str] = None diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 1acb57ec398e30..2872c0a3c4daa8 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -13,31 +13,42 @@ # limitations under the License. import os -from typing import List, Dict -import numpy as np +from typing import List import paddle from paddle.distributed.communication.group import is_initialized from paddle.distributed.fleet.utils.log_util import logger -from .metadata import Metadata, LocalTensorMetadata, LocalTensorIndex + +from .metadata import LocalTensorIndex, LocalTensorMetadata, Metadata from .utils import compute_local_shape_and_global_offset, flatten_state_dict + def check_state_dict(state_dict, process_group): local_keys = list(state_dict.keys()) gloabl_keys = [] paddle.distributed.all_gather_object(gloabl_keys, local_keys, process_group) for keys in gloabl_keys[1:]: - assert keys == gloabl_keys[0], f"keys:{keys} != first_keys: {gloabl_keys[0]}" + assert ( + keys == gloabl_keys[0] + ), f"keys:{keys} != first_keys: {gloabl_keys[0]}" + def check_file_name(file_name, process_group): all_unique_id = [] unique_id = int(file_name.split(".")[0].split("_")[1]) - paddle.distributed.all_gather_object(all_unique_id, unique_id, process_group) + paddle.distributed.all_gather_object( + all_unique_id, unique_id, process_group + ) for id in all_unique_id[1:]: - assert id == all_unique_id[0], f"id:{id} != all_unique_id[0]:{file_name}" + assert ( + id == all_unique_id[0] + ), f"id:{id} != all_unique_id[0]:{file_name}" + def merge_state_dict(global_state_dict): - assert isinstance(global_state_dict, List), "The global_state_dict should be a list." + assert isinstance( + global_state_dict, List + ), "The global_state_dict should be a list." out = {} for state_dict in global_state_dict: for key, val in state_dict.items(): @@ -49,6 +60,7 @@ def merge_state_dict(global_state_dict): out[key] = [val] return out + def dedup_state_dict(global_state_dict): out = {} for state_dict in global_state_dict: @@ -58,7 +70,10 @@ def dedup_state_dict(global_state_dict): out[key] = val return out -def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, use_dist=True) -> None: + +def save_state_dict( + state_dict, path, process_group=None, coordinator_rank=0, use_dist=True +) -> None: """ Save the state_dict of model to path. @@ -68,19 +83,23 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us process_group: ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. coordinator_rank: The rank used to coordinate the checkpoint. Rank0 is used by default. use_dist: Whether to save the state_dict in distributed mode. Set True by default. - + Examples: .. code-block:: python - + import paddle ... """ - assert isinstance(state_dict, dict), "The state_dict should be a dictionary." + assert isinstance( + state_dict, dict + ), "The state_dict should be a dictionary." 
state_dict = flatten_state_dict(state_dict) if len(state_dict) > 0: for val in state_dict.values(): - assert isinstance(val, (paddle.Tensor, paddle.base.framework.EagerParamBase)), "Only support dygraph Tensor now, support static DistributedTensor later" + assert isinstance( + val, (paddle.Tensor, paddle.base.framework.EagerParamBase) + ), "Only support dygraph Tensor now, support static DistributedTensor later" if not os.path.exists(path): os.makedirs(path, exist_ok=True) @@ -89,10 +108,9 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us # Init the default global process group not is_initialized() and paddle.distributed.init_parallel_env() - unique_id = 0 file_name = "" - while(True): + while True: file_name = f"{paddle.distributed.get_rank()}_{unique_id}.distcp" if not os.path.exists(os.path.join(path, file_name)): break @@ -111,19 +129,34 @@ def save_state_dict(state_dict, path, process_group=None, coordinator_rank=0, us if not val._is_initialized(): continue if val.is_dist(): - local_shape, global_offset = compute_local_shape_and_global_offset(val.shape, val.dist_attr.process_mesh, val.dist_attr.dims_mapping) + ( + local_shape, + global_offset, + ) = compute_local_shape_and_global_offset( + val.shape, + val.dist_attr.process_mesh, + val.dist_attr.dims_mapping, + ) if not local_shape or not global_offset: continue - local_tensor_metadata[key] = LocalTensorMetadata(global_offset, local_shape) - local_storage_metadata[LocalTensorIndex(key, tuple(global_offset))] = file_name + local_tensor_metadata[key] = LocalTensorMetadata( + global_offset, local_shape + ) + local_storage_metadata[ + LocalTensorIndex(key, tuple(global_offset)) + ] = file_name local_tensor = val._local_value() else: local_tensor = val local_state_dict[key] = local_tensor global_tensor_metadata = [] global_storage_metadata = [] - paddle.distributed.all_gather_object(global_tensor_metadata, local_tensor_metadata, process_group) - paddle.distributed.all_gather_object(global_storage_metadata, local_storage_metadata, process_group) + paddle.distributed.all_gather_object( + global_tensor_metadata, local_tensor_metadata, process_group + ) + paddle.distributed.all_gather_object( + global_storage_metadata, local_storage_metadata, process_group + ) metadata.state_dict_metadata = merge_state_dict(global_tensor_metadata) metadata.storage_metadata = dedup_state_dict(global_storage_metadata) if coordinator_rank == paddle.distributed.get_rank(): diff --git a/python/paddle/distributed/checkpoint/utils.py b/python/paddle/distributed/checkpoint/utils.py index aa8ec0d113b0c0..1262d5d9483517 100644 --- a/python/paddle/distributed/checkpoint/utils.py +++ b/python/paddle/distributed/checkpoint/utils.py @@ -12,22 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Tuple import copy -from typing import List, Union +from typing import List, Tuple, Union import numpy as np + import paddle from paddle.framework import core -def get_coordinator(mesh:Union[np.array, List[List[int]]], rank:int): + +def get_coordinator(mesh: Union[np.array, List[List[int]]], rank: int): mesh = paddle.to_tensor(mesh) rand_coordinator = (mesh == rank).nonzero() - assert rand_coordinator.shape[0] in (0, 1), f"rand_coordinator.shape: {rand_coordinator.shape}" - return rand_coordinator[0].tolist() if rand_coordinator.shape[0] > 0 else None + assert rand_coordinator.shape[0] in ( + 0, + 1, + ), f"rand_coordinator.shape: {rand_coordinator.shape}" + return ( + rand_coordinator[0].tolist() if rand_coordinator.shape[0] > 0 else None + ) -def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:core.ProcessMesh, dims_mapping:List[int]) -> Tuple[Tuple[int], Tuple[int]]: +def compute_local_shape_and_global_offset( + global_shape: List[int], + process_mesh: core.ProcessMesh, + dims_mapping: List[int], +) -> Tuple[Tuple[int], Tuple[int]]: mesh = np.array(process_mesh.process_ids).reshape(process_mesh.shape) # deal with cross mesh case if paddle.distributed.get_rank() not in mesh: @@ -39,13 +49,16 @@ def compute_local_shape_and_global_offset(global_shape:List[int], process_mesh:c if dim == -1: continue else: - assert global_shape[i] % process_mesh.shape[dim] == 0, f"i:{i}, global_shape[i]:{global_shape[i]}, process_mesh.shape[dim]:{process_mesh.shape[dim]}" + assert ( + global_shape[i] % process_mesh.shape[dim] == 0 + ), f"i:{i}, global_shape[i]:{global_shape[i]}, process_mesh.shape[dim]:{process_mesh.shape[dim]}" local_shape[i] = global_shape[i] // process_mesh.shape[dim] chunk_idx = rank_coordinator[dim] global_offset[i] = chunk_idx * local_shape[i] - + return tuple(local_shape), tuple(global_offset) + def flatten_state_dict(state_dict): # TODO, {"model": {"w0": xxx}} -> {model.w0: xxx} return state_dict From 2bf30c583e9590fae8bf017c70587ed93b4d3318 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Mon, 4 Dec 2023 15:53:25 +0800 Subject: [PATCH 11/24] test use_dist --- .../distributed/checkpoint/load_state_dict.py | 131 ++++++++++------- .../distributed/checkpoint/save_state_dict.py | 50 ++++--- .../hybrid_strategy/CMakeLists.txt | 5 +- .../hybrid_strategy/load_state_dict.py | 134 +++++++++++++----- .../hybrid_strategy/save_state_dict.py | 38 +++-- .../semi_auto_parallel_simple_net_dp_mp.py | 2 - .../test_save_load_state_dict.py | 42 +++++- 7 files changed, 277 insertions(+), 125 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 4dcd60dc88bd11..ef5471f48b4150 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -33,8 +33,7 @@ class ReadItem: lengths: Tuple[int] -def get_rank_to_files(path, state_dict, process_group): - # step 1, get neccesary files to be read +def get_rank_to_files(path, state_dict, process_group, use_dist): accessible_files = os.listdir(path) metadata_files = [ file for file in accessible_files if file.endswith(".metadata") @@ -47,9 +46,12 @@ def get_rank_to_files(path, state_dict, process_group): necessary_files = [] for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) - for metadata_index, file_name in metadata.storage_metadata.items(): - tensor_id_list.append(metadata_index.tensor_id) - if 
metadata_index.tensor_id in state_dict: + for local_tensor_index, file_name in metadata.storage_metadata.items(): + assert ( + local_tensor_index not in tensor_id_list + ), f"Duplicate tensor_id:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata." + tensor_id_list.append(local_tensor_index.tensor_id) + if local_tensor_index.tensor_id in state_dict: necessary_files.append(file_name) necessary_data_files_set = set(necessary_files) # allgather all accessible files @@ -57,9 +59,12 @@ def get_rank_to_files(path, state_dict, process_group): file for file in accessible_files if file.endswith(".distcp") ] global_data_files = [] - paddle.distributed.all_gather_object( - global_data_files, local_data_files, process_group - ) + if use_dist: + paddle.distributed.all_gather_object( + global_data_files, local_data_files, process_group + ) + else: + global_data_files.append(local_data_files) tmp = [] for files in global_data_files: tmp += files @@ -74,7 +79,7 @@ def get_rank_to_files(path, state_dict, process_group): ), f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{necessary_data_files_set}" missing_keys = set(state_dict.keys()) - set(tensor_id_list) logger.info(f"missing_keys:{missing_keys}") - # step 2, get mapping between ranks and local files + rank_to_files = {} for rank, local_files in enumerate(global_data_files): if len(local_files) > 0: @@ -83,6 +88,7 @@ def get_rank_to_files(path, state_dict, process_group): ] rank_to_files[rank] = local_files logger.info(f"mapping rank_to_files:{rank_to_files}") + return rank_to_files def get_local_load_files(rank_to_files): @@ -178,7 +184,7 @@ def update(rank_to_read_files, rank_to_files, rank_file): return [] -def get_load_infos(path, local_load_files, process_group): +def get_load_infos(path, local_load_files, process_group, use_dist): load_info = {} accessible_files = os.listdir(path) metadata_files = [ @@ -196,9 +202,12 @@ def get_load_infos(path, local_load_files, process_group): file_name, ) load_info_list = [] - paddle.distributed.all_gather_object( - load_info_list, load_info, process_group - ) + if use_dist: + paddle.distributed.all_gather_object( + load_info_list, load_info, process_group + ) + else: + load_info_list.append(load_info) load_infos = {} for load_info in load_info_list: for local_tensor_index, (rank, file_name) in load_info.items(): @@ -257,7 +266,7 @@ def not_overlap( return False -def get_read_items(path, state_dict, process_group): +def get_read_items(path, state_dict, process_group, use_dist): accessible_files = os.listdir(path) metadata_files = [ file for file in accessible_files if file.endswith(".metadata") @@ -288,40 +297,37 @@ def get_read_items(path, state_dict, process_group): val.dist_attr.process_mesh, val.dist_attr.dims_mapping, ) - if not local_shape or not global_offset: + else: + local_shape = val.shape + global_offset = [0] * len(val.shape) + if not local_shape or not global_offset: + continue + cur_chunk_metadata = LocalTensorMetadata(global_offset, local_shape) + assert ( + tensor_id in storage_state_dict_metadata + ), f"tensor_id:{tensor_id} not found in storage_state_dict_metadata:{storage_state_dict_metadata}." 
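compute_local_shape_and_global_offset, used above to derive each shard's placement, maps a rank's coordinate in the process mesh to the shape and starting offset of its local chunk. The same arithmetic as a dependency-free sketch with a worked example (illustrative names; the real helper takes a ProcessMesh and also handles ranks that are not in the mesh):

    def local_shape_and_offset(global_shape, mesh_shape, rank_coord, dims_mapping):
        """Shard shape and global offset for one rank.

        dims_mapping[i] is the mesh axis that splits tensor dim i, or -1 if that
        dim is replicated. Assumes even divisibility, as the original helper asserts.
        """
        local_shape = list(global_shape)
        offset = [0] * len(global_shape)
        for i, mesh_dim in enumerate(dims_mapping):
            if mesh_dim == -1:
                continue
            assert global_shape[i] % mesh_shape[mesh_dim] == 0
            local_shape[i] = global_shape[i] // mesh_shape[mesh_dim]
            offset[i] = rank_coord[mesh_dim] * local_shape[i]
        return tuple(local_shape), tuple(offset)

    # A (4, 8) tensor on a 2x4 mesh, tensor dim 0 split by mesh axis 1, dim 1 by mesh axis 0:
    print(local_shape_and_offset((4, 8), (2, 4), rank_coord=(1, 2), dims_mapping=[1, 0]))
    # ((1, 4), (2, 4))
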
+ for storage_local_tensor_metadata in storage_state_dict_metadata[ + tensor_id + ]: + if not_overlap( + cur_chunk_metadata, storage_local_tensor_metadata + ): continue - cur_chunk_metadata = LocalTensorMetadata( - global_offset, local_shape + cur_offsets, storage_offsets, lengths = compute_overlap( + cur_chunk_metadata, storage_local_tensor_metadata ) - assert ( - tensor_id in storage_state_dict_metadata - ), f"tensor_id:{tensor_id} not found in storage_state_dict_metadata:{storage_state_dict_metadata}." - for ( - storage_local_tensor_metadata - ) in storage_state_dict_metadata[tensor_id]: - if not_overlap( - cur_chunk_metadata, storage_local_tensor_metadata - ): - continue - cur_offsets, storage_offsets, lengths = compute_overlap( - cur_chunk_metadata, storage_local_tensor_metadata - ) - storage_local_tensor_index = LocalTensorIndex( - tensor_id, - tuple(storage_local_tensor_metadata.global_offset), - ) - read_items.append( - ReadItem( - storage_local_tensor_index, - paddle.distributed.get_rank(), - tuple(cur_offsets), - tuple(storage_offsets), - tuple(lengths), - ) + storage_local_tensor_index = LocalTensorIndex( + tensor_id, + tuple(storage_local_tensor_metadata.global_offset), + ) + read_items.append( + ReadItem( + storage_local_tensor_index, + paddle.distributed.get_rank(), + tuple(cur_offsets), + tuple(storage_offsets), + tuple(lengths), ) - else: - raise ValueError( - f"Only support distributed tensor., val type:{type(val)}" ) else: raise ValueError( @@ -329,7 +335,10 @@ def get_read_items(path, state_dict, process_group): ) global_read_items = [] tmp = [] - paddle.distributed.all_gather_object(tmp, read_items, process_group) + if use_dist: + paddle.distributed.all_gather_object(tmp, read_items, process_group) + else: + tmp.append(read_items) for items in tmp: for item in items: global_read_items.append(item) @@ -352,6 +361,12 @@ def load_state_dict( import paddle ... """ + if not use_dist and ( + paddle.distributed.get_world_size() > 1 or coordinator_rank != 0 + ): + raise ValueError( + f"use_dist is False, please set coordinator_rank to 0 and paddle.distributed.get_world_size() to 1, world_size:{paddle.distributed.get_world_size()}, coordinator_rank:{coordinator_rank}" + ) assert isinstance( state_dict, dict ), "The state_dict should be a dictionary." @@ -359,20 +374,20 @@ def load_state_dict( if len(state_dict) > 0: for val in state_dict.values(): assert isinstance( - val, (paddle.Tensor, paddle.base.framework.EagerParamBase) + val, paddle.Tensor ), "Only support dygraph Tensor now, support static DistributedTensor later" - if process_group is None: + if use_dist and process_group is None: # Init the default global process group not is_initialized() and paddle.distributed.init_parallel_env() - rank_to_files = get_rank_to_files(path, state_dict, process_group) + rank_to_files = get_rank_to_files(path, state_dict, process_group, use_dist) local_load_files = get_local_load_files(rank_to_files) # load_infos: {LocalTensorIndex: (rank, file_name)}, which local tensor located in which file, and the file is load in which rank. - load_infos = get_load_infos(path, local_load_files, process_group) + load_infos = get_load_infos(path, local_load_files, process_group, use_dist) # read_items: [ReadItem(local_tensor_index, rank, cur_offsets, storage_offsets, lengths)], # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank. 
- read_items = get_read_items(path, state_dict, process_group) + read_items = get_read_items(path, state_dict, process_group, use_dist) storage_file_to_state_dict = {} logger.info( f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}" @@ -416,9 +431,11 @@ def load_state_dict( assert ( item.local_tensor_index.tensor_id in state_dict ), f"item:{item}, state_dict:{state_dict}" - cur_local_tensor = state_dict[ - item.local_tensor_index.tensor_id - ]._local_value() + cur_local_tensor = ( + state_dict[item.local_tensor_index.tensor_id]._local_value() + if use_dist + else state_dict[item.local_tensor_index.tensor_id] + ) cur_offsets = item.cur_offset cur_lengths = item.lengths cur_ends = [ @@ -452,7 +469,11 @@ def load_state_dict( cur_chunk_tensor, src=src_rank, group=process_group ) - local_state_dict = {k: v._local_value() for k, v in state_dict.items()} + local_state_dict = ( + {k: v._local_value() for k, v in state_dict.items()} + if use_dist + else state_dict + ) logger.info( f"after load, local_state_dict:{local_state_dict} \n state_dict:{state_dict}" ) diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 2872c0a3c4daa8..9c46a10ac5c890 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -91,6 +91,12 @@ def save_state_dict( ... """ + if not use_dist and ( + paddle.distributed.get_world_size() > 1 or coordinator_rank != 0 + ): + raise ValueError( + f"use_dist is False, please set coordinator_rank to 0 and paddle.distributed.get_world_size() to 1, world_size:{paddle.distributed.get_world_size()}, coordinator_rank:{coordinator_rank}" + ) assert isinstance( state_dict, dict ), "The state_dict should be a dictionary." 
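This patch threads a use_dist flag through every collective call: when it is off, each all_gather_object is replaced by wrapping the local object in a one-element list. One way to picture that repeated pattern as a single helper (an illustrative refactor, not something the patch itself introduces):

    import paddle

    def gather_object(local_obj, process_group, use_dist):
        """All-gather across ranks, or a single-process no-op when use_dist is False."""
        gathered = []
        if use_dist:
            paddle.distributed.all_gather_object(gathered, local_obj, process_group)
        else:
            gathered.append(local_obj)
        return gathered

save_state_dict and load_state_dict above inline this branch at each call site instead of factoring it out.
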
@@ -98,13 +104,13 @@ def save_state_dict( if len(state_dict) > 0: for val in state_dict.values(): assert isinstance( - val, (paddle.Tensor, paddle.base.framework.EagerParamBase) + val, paddle.Tensor ), "Only support dygraph Tensor now, support static DistributedTensor later" if not os.path.exists(path): os.makedirs(path, exist_ok=True) - if process_group is None: + if use_dist and process_group is None: # Init the default global process group not is_initialized() and paddle.distributed.init_parallel_env() @@ -116,11 +122,12 @@ def save_state_dict( break unique_id += 1 logger.info(f"file_name:{file_name}") - check_file_name(file_name, process_group) - # the parameter_name and order in state_dict should be the same - check_state_dict(state_dict, process_group) - local_state_dict = {} + if use_dist: + check_file_name(file_name, process_group) + # the parameter_name and order in state_dict should be the same + check_state_dict(state_dict, process_group) metadata = Metadata() + local_state_dict = {} local_tensor_metadata = {} local_storage_metadata = {} for key, val in state_dict.items(): @@ -139,24 +146,31 @@ def save_state_dict( ) if not local_shape or not global_offset: continue - local_tensor_metadata[key] = LocalTensorMetadata( - global_offset, local_shape - ) - local_storage_metadata[ - LocalTensorIndex(key, tuple(global_offset)) - ] = file_name local_tensor = val._local_value() else: + global_offset = [0] * len(val.shape) + local_shape = val.shape local_tensor = val local_state_dict[key] = local_tensor + local_tensor_metadata[key] = LocalTensorMetadata( + global_offset, local_shape + ) + local_storage_metadata[ + LocalTensorIndex(key, tuple(global_offset)) + ] = file_name global_tensor_metadata = [] global_storage_metadata = [] - paddle.distributed.all_gather_object( - global_tensor_metadata, local_tensor_metadata, process_group - ) - paddle.distributed.all_gather_object( - global_storage_metadata, local_storage_metadata, process_group - ) + if use_dist: + paddle.distributed.all_gather_object( + global_tensor_metadata, local_tensor_metadata, process_group + ) + paddle.distributed.all_gather_object( + global_storage_metadata, local_storage_metadata, process_group + ) + else: + global_tensor_metadata.append(local_tensor_metadata) + global_storage_metadata.append(local_storage_metadata) + metadata.state_dict_metadata = merge_state_dict(global_tensor_metadata) metadata.storage_metadata = dedup_state_dict(global_storage_metadata) if coordinator_rank == paddle.distributed.get_rank(): diff --git a/test/auto_parallel/hybrid_strategy/CMakeLists.txt b/test/auto_parallel/hybrid_strategy/CMakeLists.txt index ef1aaf5376445e..1dbd27467102b9 100644 --- a/test/auto_parallel/hybrid_strategy/CMakeLists.txt +++ b/test/auto_parallel/hybrid_strategy/CMakeLists.txt @@ -12,9 +12,8 @@ if((WITH_GPU) AND (LINUX)) set_tests_properties(test_semi_auto_parallel_hybrid_strategy PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") py_test_modules( - test_save_load_state_dict MODULES - test_save_load_state_dict ENVS - "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") + test_save_load_state_dict MODULES test_save_load_state_dict ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_save_load_state_dict PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") endif() diff --git a/test/auto_parallel/hybrid_strategy/load_state_dict.py b/test/auto_parallel/hybrid_strategy/load_state_dict.py index b0a359c7f0ce35..0f5d2f65e89ac6 100644 --- 
a/test/auto_parallel/hybrid_strategy/load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/load_state_dict.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,75 +26,129 @@ # limitations under the License. import os + import numpy as np +from auto_parallel.hybrid_strategy.save_state_dict import ( + ckpt_path, + get_global_state_dict, +) import paddle import paddle.distributed as dist -from paddle.distributed import save_state_dict, load_state_dict -from paddle.distributed.checkpoint.utils import get_coordinator, compute_local_shape_and_global_offset -from auto_parallel.hybrid_strategy.save_state_dict import get_global_state_dict, ckpt_path +from paddle.distributed import load_state_dict +from paddle.distributed.checkpoint.utils import ( + compute_local_shape_and_global_offset, + get_coordinator, +) class TestLoadStateDict: - def test_load_state_dict_with_same_cards(self): + def test_load_state_dict_with_one_device(self): + global_state_dict = get_global_state_dict() + saved_w1, saved_w2 = list(global_state_dict.values()) + w1 = paddle.zeros_like(saved_w1) + w2 = paddle.zeros_like(saved_w2) + state_dict = dict(zip(list(global_state_dict.keys()), [w1, w2])) + load_state_dict(state_dict, ckpt_path(), use_dist=False) + # check + expect_w1 = saved_w1 + expect_w2 = saved_w2 + expect_state_dict = dict( + zip(list(global_state_dict.keys()), [expect_w1, expect_w2]) + ) + for k, v in state_dict.items(): + assert k in expect_state_dict, k + print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") + self.check_tensor_eq(v, expect_state_dict[k]) + + def test_load_state_dict_with_four_devices(self): global_state_dict = get_global_state_dict() saved_w1, saved_w2 = list(global_state_dict.values()) w1 = paddle.zeros_like(saved_w1) w2 = paddle.zeros_like(saved_w2) - mesh = dist.ProcessMesh([0,1,2,3]) - sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) - sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Replicate(), dist.Replicate()]) - state_dict = dict(zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2])) + mesh = dist.ProcessMesh([0, 1, 2, 3]) + sharded_w1 = dist.shard_tensor( + w1, mesh, [dist.Shard(0), dist.Replicate()] + ) + sharded_w2 = dist.shard_tensor( + w2, mesh, [dist.Replicate(), dist.Replicate()] + ) + state_dict = dict( + zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2]) + ) load_state_dict(state_dict, ckpt_path()) # check cur_rank = paddle.distributed.get_rank() expect_w1 = saved_w1.split(4, axis=0)[cur_rank] expect_w2 = sharded_w2 - expect_state_dict = dict(zip(list(global_state_dict.keys()), [expect_w1, expect_w2])) + expect_state_dict = dict( + zip(list(global_state_dict.keys()), [expect_w1, expect_w2]) + ) for k, v in state_dict.items(): assert k in expect_state_dict, k print(f"k:{k}, v:{v}, 
expect_state_dict[k]:{expect_state_dict[k]}") self.check_tensor_eq(v._local_value(), expect_state_dict[k]) - - def test_load_state_dict_with_less_cards(self): + + def test_load_state_dict_with_two_devices(self): global_state_dict = get_global_state_dict() saved_w1, saved_w2 = list(global_state_dict.values()) w1 = paddle.zeros_like(saved_w1) w2 = paddle.zeros_like(saved_w2) - mesh = dist.ProcessMesh([0,1]) + mesh = dist.ProcessMesh([0, 1]) sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)]) sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(1)]) - state_dict = dict(zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2])) + state_dict = dict( + zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2]) + ) load_state_dict(state_dict, ckpt_path()) # check cur_rank = paddle.distributed.get_rank() expect_w1 = saved_w1.split(2, axis=0)[cur_rank] expect_w2 = saved_w2.split(2, axis=1)[cur_rank] - expect_state_dict = dict(zip(list(global_state_dict.keys()), [expect_w1, expect_w2])) + expect_state_dict = dict( + zip(list(global_state_dict.keys()), [expect_w1, expect_w2]) + ) for k, v in state_dict.items(): assert k in expect_state_dict, k print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") self.check_tensor_eq(v._local_value(), expect_state_dict[k]) - def test_load_state_dict_with_more_cards(self): + def test_load_state_dict_with_eight_devices(self): global_state_dict = get_global_state_dict() saved_w1, saved_w2 = list(global_state_dict.values()) w1 = paddle.zeros_like(saved_w1) w2 = paddle.zeros_like(saved_w2) - mesh = dist.ProcessMesh([[0,1,2,3], [4,5,6,7]]) + mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(1), dist.Shard(0)]) sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(0)]) - state_dict = dict(zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2])) + state_dict = dict( + zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2]) + ) load_state_dict(state_dict, ckpt_path()) # check cur_rank = paddle.distributed.get_rank() - local_shape, global_offset = compute_local_shape_and_global_offset(sharded_w1.shape, sharded_w1.dist_attr.process_mesh, sharded_w1.dist_attr.dims_mapping) - end_offset = [offset + length for offset, length in zip(global_offset, local_shape)] - print(f"local_shape:{local_shape}, global_offset:{global_offset}, end_offset:{end_offset}") - expect_w1 = paddle.slice(saved_w1, axes=[0, 1], starts=global_offset, ends=end_offset) - cur_coordinator = get_coordinator(np.array([[0,1,2,3], [4,5,6,7]]), cur_rank) + local_shape, global_offset = compute_local_shape_and_global_offset( + sharded_w1.shape, + sharded_w1.dist_attr.process_mesh, + sharded_w1.dist_attr.dims_mapping, + ) + end_offset = [ + offset + length + for offset, length in zip(global_offset, local_shape) + ] + print( + f"local_shape:{local_shape}, global_offset:{global_offset}, end_offset:{end_offset}" + ) + expect_w1 = paddle.slice( + saved_w1, axes=[0, 1], starts=global_offset, ends=end_offset + ) + cur_coordinator = get_coordinator( + np.array([[0, 1, 2, 3], [4, 5, 6, 7]]), cur_rank + ) expect_w2 = saved_w2.split(2, axis=0)[cur_coordinator[0]] - expect_state_dict = dict(zip(list(global_state_dict.keys()), [expect_w1, expect_w2])) + expect_state_dict = dict( + zip(list(global_state_dict.keys()), [expect_w1, expect_w2]) + ) for k, v in state_dict.items(): assert k in expect_state_dict, k print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") @@ -89,21 +157,21 @@ def 
test_load_state_dict_with_more_cards(self): def check_tensor_eq(self, a, b, verbose=True): np1 = a.astype("float32").numpy() np2 = b.astype("float32").numpy() - np.testing.assert_equal( - np1, np2, verbose=verbose - ) - - + np.testing.assert_equal(np1, np2, verbose=verbose) + def run_test_case(self): device_num = int(os.getenv("device_num")) - if device_num == 2: - self.test_load_state_dict_with_less_cards() + if device_num == 1: + self.test_load_state_dict_with_one_device() + elif device_num == 2: + self.test_load_state_dict_with_two_devices() elif device_num == 4: - self.test_load_state_dict_with_same_cards() + self.test_load_state_dict_with_four_devices() elif device_num == 8: - self.test_load_state_dict_with_more_cards() + self.test_load_state_dict_with_eight_devices() else: raise ValueError("device_num should be 2,4 or 8") + if __name__ == '__main__': - TestLoadStateDict().run_test_case() \ No newline at end of file + TestLoadStateDict().run_test_case() diff --git a/test/auto_parallel/hybrid_strategy/save_state_dict.py b/test/auto_parallel/hybrid_strategy/save_state_dict.py index 6a38a6477779ca..09ab621f9366b7 100644 --- a/test/auto_parallel/hybrid_strategy/save_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/save_state_dict.py @@ -22,25 +22,45 @@ def get_global_state_dict(): w1 = paddle.arange(32).reshape([4, 8]) w2 = paddle.arange(32, 36).reshape([2, 2]) - return {"w1":w1, "w2":w2} + return {"w1": w1, "w2": w2} + def ckpt_path(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp_ckpt_output") + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), "tmp_ckpt_output" + ) + class TestSaveStateDict: - def test_save_state_dict(self): + def test_save_state_dict_with_one_device(self): + global_state_dict = get_global_state_dict() + keys = list(global_state_dict.keys()) + w1, w2 = list(global_state_dict.values()) + state_dict = dict(zip(keys, [w1, w2])) + save_state_dict(state_dict, ckpt_path(), use_dist=False) + + def test_save_state_dict_with_four_devices(self): global_state_dict = get_global_state_dict() keys = list(global_state_dict.keys()) w1, w2 = list(global_state_dict.values()) - mesh = dist.ProcessMesh([0,1]) - mesh2 = dist.ProcessMesh([2,3]) - sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) - sharded_w2 = dist.shard_tensor(w2, mesh2, [dist.Shard(0), dist.Replicate()]) + mesh = dist.ProcessMesh([0, 1]) + mesh2 = dist.ProcessMesh([2, 3]) + sharded_w1 = dist.shard_tensor( + w1, mesh, [dist.Shard(0), dist.Replicate()] + ) + sharded_w2 = dist.shard_tensor( + w2, mesh2, [dist.Shard(0), dist.Replicate()] + ) state_dict = dict(zip(keys, [sharded_w1, sharded_w2])) save_state_dict(state_dict, ckpt_path()) - + def run_test_case(self): - self.test_save_state_dict() + device_num = int(os.getenv("device_num")) + if device_num == 1: + self.test_save_state_dict_with_one_device() + elif device_num == 4: + self.test_save_state_dict_with_four_devices() + if __name__ == "__main__": TestSaveStateDict().run_test_case() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py index 3d8974df0784e1..0eb96a5877e01d 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py @@ -43,7 +43,6 @@ def test_dp_mp_demo_net(self): DemoNet("dp_mp_hybrid_strategy"), self._mesh, self.shard_fn ) - ( self.dp_mp_loss, self.dp_mp_parameters, @@ 
-71,7 +70,6 @@ def test_dp_mp_demo_net(self): self.check_tensor_eq(v._local_value(), local_state_dict[k]) os.system(f"rm -rf {ckpt_path}") - def run_test_case(self): self.test_dp_mp_demo_net() diff --git a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py index 1be29d3dd70328..32a2b6b4dc15a4 100644 --- a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py @@ -18,28 +18,60 @@ import collective.test_communication_api_base as test_base from auto_parallel.hybrid_strategy.save_state_dict import ckpt_path + class TestSaveLoadStateDict(test_base.CommunicationTestDistBase): def setUp(self): self._default_envs = {} - self._changeable_envs = {"device_num": ['2','4','8']} + self._changeable_envs = {"device_num": ["1", "2", "4", "8"]} def test_save_load_state_dict(self): + # save with 1 device + os.system(f"rm -rf {ckpt_path()}") + super().setUp(num_of_devices=1, timeout=120, nnode=1) + self.run_test_case( + "save_state_dict.py", user_defined_envs={"device_num": "1"} + ) + + # load with 1, 2, 4, 8 devices + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + super().setUp( + save_log_dir="./log", + num_of_devices=int(envs["device_num"]), + timeout=120, + nnode=1, + ) + self.run_test_case( + "load_state_dict.py", + user_defined_envs=envs, + ) + os.system(f"rm -rf {ckpt_path()}") + # save with 4 devices os.system(f"rm -rf {ckpt_path()}") super().setUp(num_of_devices=4, timeout=120, nnode=1) - self.run_test_case("save_state_dict.py") - # load with 2, 4, 8 devices + self.run_test_case( + "save_state_dict.py", user_defined_envs={"device_num": "4"} + ) + # load with 1, 2, 4, 8 devices envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs ) for envs in envs_list: - super().setUp(save_log_dir="./log", num_of_devices=int(envs["device_num"]), timeout=120, nnode=1) + super().setUp( + save_log_dir="./log", + num_of_devices=int(envs["device_num"]), + timeout=120, + nnode=1, + ) self.run_test_case( "load_state_dict.py", user_defined_envs=envs, ) os.system(f"rm -rf {ckpt_path()}") + if __name__ == '__main__': unittest.main() - From 160552cbcf80dec6b1de9af58ce38753876672a2 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Mon, 4 Dec 2023 17:35:15 +0800 Subject: [PATCH 12/24] fix test --- .../semi_auto_parallel_simple_net_dp_mp.py | 8 ++++---- .../semi_auto_parallel_simple_net_dp_mp_pp.py | 18 ++++++++++++++++++ .../test_save_load_state_dict.py | 2 -- .../test_semi_auto_parallel_hybrid_strategy.py | 4 ++++ 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py index 0eb96a5877e01d..b4ba198c65b3c1 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py @@ -14,6 +14,7 @@ import os +from auto_parallel.hybrid_strategy.save_state_dict import ckpt_path from auto_parallel.semi_auto_parallel_simple_net import ( DemoNet, TestSimpleNetForSemiAutoParallel, @@ -56,19 +57,18 @@ def test_dp_mp_demo_net(self): self.check_tensor_eq(param.grad, param_base.grad) # save load - ckpt_path = "/ckpt_output/" state_dict = model.state_dict() local_state_dict = {} for k, v in state_dict.items(): local_state_dict[k] = 
v._local_value().clone() - paddle.distributed.save_state_dict(state_dict, ckpt_path) + paddle.distributed.save_state_dict(state_dict, ckpt_path()) for k, v in state_dict.items(): v._local_value().add_(paddle.ones_like(v._local_value())) - paddle.distributed.load_state_dict(state_dict, ckpt_path) + paddle.distributed.load_state_dict(state_dict, ckpt_path()) for k, v in state_dict.items(): assert k in local_state_dict, k self.check_tensor_eq(v._local_value(), local_state_dict[k]) - os.system(f"rm -rf {ckpt_path}") + os.system(f"rm -rf {ckpt_path()}") def run_test_case(self): self.test_dp_mp_demo_net() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py index ecac26ee46d86d..2634f78c7b30ab 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py @@ -14,6 +14,7 @@ import os +from auto_parallel.hybrid_strategy.save_state_dict import ckpt_path from auto_parallel.semi_auto_parallel_simple_net import ( DemoNet, TestSimpleNetForSemiAutoParallel, @@ -103,6 +104,23 @@ def test_dp_mp_pp_demo_net(self): self.dp_mp_pp_parameters[3], self.base_parameters[3] ) + # save load + state_dict = model.state_dict() + local_state_dict = {} + for k, v in state_dict.items(): + local_state_dict[k] = ( + v._local_value().clone() if v._is_initialized() else None + ) + paddle.distributed.save_state_dict(state_dict, ckpt_path()) + for k, v in state_dict.items(): + v._local_value().add_(paddle.ones_like(v._local_value())) + paddle.distributed.load_state_dict(state_dict, ckpt_path()) + for k, v in state_dict.items(): + assert k in local_state_dict, k + if v._is_initialized(): + self.check_tensor_eq(v._local_value(), local_state_dict[k]) + os.system(f"rm -rf {ckpt_path()}") + def run_test_case(self): self.test_dp_mp_pp_demo_net() diff --git a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py index 32a2b6b4dc15a4..79611b28fc1e72 100644 --- a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py @@ -38,7 +38,6 @@ def test_save_load_state_dict(self): ) for envs in envs_list: super().setUp( - save_log_dir="./log", num_of_devices=int(envs["device_num"]), timeout=120, nnode=1, @@ -61,7 +60,6 @@ def test_save_load_state_dict(self): ) for envs in envs_list: super().setUp( - save_log_dir="./log", num_of_devices=int(envs["device_num"]), timeout=120, nnode=1, diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py index 4da980bb466cfa..defaf6227f03ea 100644 --- a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py +++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
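The dp_mp and dp_mp_pp tests above follow the same round-trip pattern: snapshot the local shard values, save, perturb in place, reload, and compare. Condensed into one illustrative helper (assumes a distributed model whose parameters expose _local_value(), as in these tests; the pp variant additionally skips shards that are not initialized on the current rank):

    import numpy as np

    import paddle

    def check_save_load_roundtrip(model, ckpt_dir):
        state_dict = model.state_dict()
        baseline = {k: v._local_value().clone() for k, v in state_dict.items()}

        paddle.distributed.save_state_dict(state_dict, ckpt_dir)
        for v in state_dict.values():
            # Perturb the local shard so a successful reload is observable.
            v._local_value().add_(paddle.ones_like(v._local_value()))
        paddle.distributed.load_state_dict(state_dict, ckpt_dir)

        for k, v in state_dict.items():
            np.testing.assert_equal(
                v._local_value().astype("float32").numpy(),
                baseline[k].astype("float32").numpy(),
            )
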
+import os import unittest import collective.test_communication_api_base as test_base +from auto_parallel.hybrid_strategy.save_state_dict import ckpt_path class TestSemiAutoParallelDPMPStrategy(test_base.CommunicationTestDistBase): @@ -25,6 +27,7 @@ def setUp(self): "seed": "2023", } self._changeable_envs = {"backend": ["gpu"]} + os.system(f"rm -rf {ckpt_path()}") def test_simple_net_bybrid_strategy(self): envs_list = test_base.gen_product_envs_list( @@ -49,6 +52,7 @@ def setUp(self): "seed": "2023", } self._changeable_envs = {"backend": ["gpu"]} + os.system(f"rm -rf {ckpt_path()}") def test_simple_net_bybrid_strategy(self): envs_list = test_base.gen_product_envs_list( From baf2b745f6e740cf796735cbd3e01ad3f1a39b62 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Mon, 4 Dec 2023 19:03:41 +0800 Subject: [PATCH 13/24] info to debug --- .../distributed/checkpoint/load_state_dict.py | 65 ++++++++++--------- .../distributed/checkpoint/save_state_dict.py | 8 +-- 2 files changed, 36 insertions(+), 37 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index ef5471f48b4150..82d2a9543a3d0f 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import os from dataclasses import dataclass from typing import Tuple @@ -69,7 +70,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): for files in global_data_files: tmp += files global_data_files_set = set(tmp) - logger.info( + logger.debug( f"necessary_data_files_set:{necessary_data_files_set}, global_data_files_set:{global_data_files_set}" ) # check neccesary files in global_data_files @@ -78,7 +79,10 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): == necessary_data_files_set ), f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{necessary_data_files_set}" missing_keys = set(state_dict.keys()) - set(tensor_id_list) - logger.info(f"missing_keys:{missing_keys}") + if len(missing_keys) > 0: + logger.warning( + f"Missing keys:{missing_keys}, check whether the checkpoint is complete." 
+ ) rank_to_files = {} for rank, local_files in enumerate(global_data_files): @@ -87,7 +91,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): f for f in local_files if f in necessary_data_files_set ] rank_to_files[rank] = local_files - logger.info(f"mapping rank_to_files:{rank_to_files}") + logger.debug(f"mapping rank_to_files:{rank_to_files}") return rank_to_files @@ -111,17 +115,18 @@ def get_local_load_files(rank_to_files): if file not in file_to_ranks: file_to_ranks[file] = [] file_to_ranks[file].append(rank) - rank_to_read_files = {rank: [] for rank in rank_to_files.keys()} + rank_to_not_read_files = copy.copy(rank_to_files) + rank_to_read_files = {rank: [] for rank in rank_to_not_read_files.keys()} for file, ranks in file_to_ranks.items(): if len(ranks) == 1: rank = ranks[0] rank_to_read_files[rank].append(file) - rank_to_files[rank].remove(file) - if len(rank_to_files[rank]) == 0: - rank_to_files.pop(rank) + rank_to_not_read_files[rank].remove(file) + if len(rank_to_not_read_files[rank]) == 0: + rank_to_not_read_files.pop(rank) - logger.info( - f"start rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}" + logger.debug( + f"rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}" ) def get_least_read_files_ranks(rank_to_read_files): @@ -132,28 +137,28 @@ def get_least_read_files_ranks(rank_to_read_files): ranks = [rank for rank, num in nums if num == nums[0][1]] return ranks - def get_read_rank_file(rank_to_files, ranks): - if len(rank_to_files) == 0: + def get_read_rank_file(rank_to_not_read_files, ranks): + if len(rank_to_not_read_files) == 0: return (None, None) nums = [ (rank, len(files)) - for rank, files in rank_to_files.items() + for rank, files in rank_to_not_read_files.items() if rank in ranks ] nums = sorted(nums, key=lambda x: x[1]) rank = nums[0][0] - return (rank, rank_to_files[rank][0]) + return (rank, rank_to_not_read_files[rank][0]) - def update(rank_to_read_files, rank_to_files, rank_file): + def update(rank_to_read_files, rank_to_not_read_files, rank_file): rank, file = rank_file if rank is None and file is None: return if rank not in rank_to_read_files: rank_to_read_files[rank] = [] rank_to_read_files[rank].append(file) - # update rank_to_files + # update rank_to_not_read_files file_to_ranks = {} - for r, files in rank_to_files.items(): + for r, files in rank_to_not_read_files.items(): for f in files: if f not in file_to_ranks: file_to_ranks[f] = [] @@ -161,26 +166,22 @@ def update(rank_to_read_files, rank_to_files, rank_file): logger.info(f"file_to_ranks:{file_to_ranks}") if file in file_to_ranks: for r in file_to_ranks[file]: - rank_to_files[r].remove(file) - if len(rank_to_files[r]) == 0: - rank_to_files.pop(r) + rank_to_not_read_files[r].remove(file) + if len(rank_to_not_read_files[r]) == 0: + rank_to_not_read_files.pop(r) - while len(rank_to_files) > 0: + while len(rank_to_not_read_files) > 0: ranks = get_least_read_files_ranks(rank_to_read_files) - rank_file = get_read_rank_file(rank_to_files, ranks) - update(rank_to_read_files, rank_to_files, rank_file) - logger.info( - f"update rank_to_read_files:{rank_to_read_files}, rank_to_files:{rank_to_files}, ranks:{ranks}, rank_file:{rank_file}" + rank_file = get_read_rank_file(rank_to_not_read_files, ranks) + update(rank_to_read_files, rank_to_not_read_files, rank_file) + logger.debug( + f"update rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}, ranks:{ranks}, rank_file:{rank_file}" ) - 
logger.info(f"rank_to_read_files:{rank_to_read_files}") cur_rank = paddle.distributed.get_rank() if cur_rank in rank_to_read_files: - logger.info( - f"cur_rank:{cur_rank}, rank_to_read_files[cur_rank]:{rank_to_read_files[cur_rank]}" - ) return rank_to_read_files[cur_rank] else: - logger.info(f"rank:{cur_rank} does not need to load checkpoint") + logger.warning(f"rank:{cur_rank} does not need to load checkpoint") return [] @@ -285,7 +286,7 @@ def get_read_items(path, state_dict, process_group, use_dist): storage_state_dict_metadata[tensor_id] = [] storage_state_dict_metadata[tensor_id] += local_tensor_metadata read_items = [] - logger.info(f"storage_state_dict_metadata:{storage_state_dict_metadata}") + logger.debug(f"storage_state_dict_metadata:{storage_state_dict_metadata}") for tensor_id, val in state_dict.items(): if isinstance(val, paddle.Tensor): if val.is_dist(): @@ -389,7 +390,7 @@ def load_state_dict( # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank. read_items = get_read_items(path, state_dict, process_group, use_dist) storage_file_to_state_dict = {} - logger.info( + logger.debug( f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}" ) for item in read_items: @@ -474,6 +475,6 @@ def load_state_dict( if use_dist else state_dict ) - logger.info( + logger.debug( f"after load, local_state_dict:{local_state_dict} \n state_dict:{state_dict}" ) diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 9c46a10ac5c890..1f121727c3bf7e 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -121,7 +121,7 @@ def save_state_dict( if not os.path.exists(os.path.join(path, file_name)): break unique_id += 1 - logger.info(f"file_name:{file_name}") + logger.debug(f"file_name:{file_name}") if use_dist: check_file_name(file_name, process_group) # the parameter_name and order in state_dict should be the same @@ -174,9 +174,7 @@ def save_state_dict( metadata.state_dict_metadata = merge_state_dict(global_tensor_metadata) metadata.storage_metadata = dedup_state_dict(global_storage_metadata) if coordinator_rank == paddle.distributed.get_rank(): - logger.info(f"global_tensor_metadata:{global_tensor_metadata}") - logger.info(f"global_storage_metadata:{global_storage_metadata}") - logger.info(f"metadata:{metadata}") + logger.debug(f"metadata:{metadata}") paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata")) - logger.info(f"local_state_dict:{local_state_dict}") + logger.debug(f"local_state_dict:{local_state_dict}") paddle.save(local_state_dict, os.path.join(path, file_name)) From 968d611b1faa6b92fb0c1e327de99ca84716fb0b Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Tue, 5 Dec 2023 11:00:09 +0800 Subject: [PATCH 14/24] fix test --- test/auto_parallel/hybrid_strategy/testslist.csv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/auto_parallel/hybrid_strategy/testslist.csv b/test/auto_parallel/hybrid_strategy/testslist.csv index 0820f4611e2a58..250c0b18e8ff9b 100644 --- a/test/auto_parallel/hybrid_strategy/testslist.csv +++ b/test/auto_parallel/hybrid_strategy/testslist.csv @@ -1,3 +1,3 @@ name,os,arch,timeout,run_type,launcher,num_port,run_serial,envs,conditions test_semi_auto_parallel_hybrid_strategy,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., 
-test_save_load_state_dict.py,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_save_load_state_dict,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., From 170fd81cd068e2dc6a7e0b5ac8694eed1dbd5e93 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Tue, 5 Dec 2023 11:07:01 +0800 Subject: [PATCH 15/24] fix --- test/auto_parallel/hybrid_strategy/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/auto_parallel/hybrid_strategy/CMakeLists.txt b/test/auto_parallel/hybrid_strategy/CMakeLists.txt index 1dbd27467102b9..a1759193941f20 100644 --- a/test/auto_parallel/hybrid_strategy/CMakeLists.txt +++ b/test/auto_parallel/hybrid_strategy/CMakeLists.txt @@ -11,6 +11,8 @@ if((WITH_GPU) AND (LINUX)) "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_semi_auto_parallel_hybrid_strategy PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") +endif() +if((WITH_GPU) AND (LINUX)) py_test_modules( test_save_load_state_dict MODULES test_save_load_state_dict ENVS "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") From e0d0690810713973261238592105b798c53075b6 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Tue, 5 Dec 2023 16:57:51 +0800 Subject: [PATCH 16/24] fix coverage ci --- paddle/scripts/paddle_build.sh | 6 +++--- tools/check_file_diff_approvals.sh | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index dea66f3f487f15..9a8efacf008679 100644 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -814,7 +814,7 @@ set -x fi if [ -a "$PADDLE_ROOT/added_ut" ];then added_uts=^$(awk BEGIN{RS=EOF}'{gsub(/\n/,"$|^");print}' $PADDLE_ROOT/added_ut)$ - ctest -R "(${added_uts})" -LE "RUN_TYPE=DIST|RUN_TYPE=EXCLUSIVE" --output-on-failure --repeat-until-fail 3 --timeout 15;added_ut_error=$? + ctest -R "(${added_uts})" -LE "RUN_TYPE=DIST|RUN_TYPE=EXCLUSIVE|RUN_TYPE=HYBRID" --output-on-failure --repeat-until-fail 3 --timeout 15;added_ut_error=$? ctest -R "(${added_uts})" -L "RUN_TYPE=DIST|RUN_TYPE=EXCLUSIVE" --output-on-failure --repeat-until-fail 3 --timeout 15;added_ut_error_1=$? if [ "$added_ut_error" != 0 ] && [ "$added_ut_error_1" != 0 ];then echo "========================================" @@ -1545,7 +1545,7 @@ set -x fi if [ -a "$PADDLE_ROOT/added_ut" ];then added_uts=^$(awk BEGIN{RS=EOF}'{gsub(/\n/,"$|^");print}' $PADDLE_ROOT/added_ut)$ - env CUDA_VISIBLE_DEVICES=0 ctest -R "(${added_uts})" -LE "RUN_TYPE=DIST|RUN_TYPE=EXCLUSIVE" --output-on-failure --repeat-until-fail 3 --timeout 15;added_ut_error=$? + env CUDA_VISIBLE_DEVICES=0 ctest -R "(${added_uts})" -LE "RUN_TYPE=DIST|RUN_TYPE=EXCLUSIVE|RUN_TYPE=HYBRID" --output-on-failure --repeat-until-fail 3 --timeout 15;added_ut_error=$? ctest -R "(${added_uts})" -L "RUN_TYPE=DIST|RUN_TYPE=EXCLUSIVE" --output-on-failure --repeat-until-fail 3 --timeout 15;added_ut_error_1=$? if [ "$added_ut_error" != 0 ] && [ "$added_ut_error_1" != 0 ];then echo "========================================" @@ -2544,7 +2544,7 @@ set -x fi if [ -a "$PADDLE_ROOT/added_ut" ];then added_uts=^$(awk BEGIN{RS=EOF}'{gsub(/\n/,"$|^");print}' $PADDLE_ROOT/added_ut)$ - env CUDA_VISIBLE_DEVICES=0 ctest -R "(${added_uts})" -LE "RUN_TYPE=DIST|RUN_TYPE=EXCLUSIVE" --output-on-failure --repeat-until-fail 3 --timeout 15;added_ut_error=$? 
+ env CUDA_VISIBLE_DEVICES=0 ctest -R "(${added_uts})" -LE "RUN_TYPE=DIST|RUN_TYPE=EXCLUSIVE|RUN_TYPE=HYBRID" --output-on-failure --repeat-until-fail 3 --timeout 15;added_ut_error=$? ctest -R "(${added_uts})" -L "RUN_TYPE=DIST|RUN_TYPE=EXCLUSIVE" --output-on-failure --repeat-until-fail 3 --timeout 15;added_ut_error_1=$? if [ "$added_ut_error" != 0 ] && [ "$added_ut_error_1" != 0 ];then echo "========================================" diff --git a/tools/check_file_diff_approvals.sh b/tools/check_file_diff_approvals.sh index 563075e1bcf189..bd768fd9c686e7 100644 --- a/tools/check_file_diff_approvals.sh +++ b/tools/check_file_diff_approvals.sh @@ -553,13 +553,13 @@ RUNTYPE_FILE_CHANGED=`git diff --name-only --diff-filter=AM upstream/$BRANCH|gre if [ "${RUNTYPE_FILE_CHANGED}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then for CMAKELISTS_FILE in ${RUNTYPE_FILE_CHANGED}; do - RUNTYPE_ADD=`git diff -U0 upstream/$BRANCH ${PADDLE_ROOT}/${CMAKELISTS_FILE} |grep "^+" |grep -E "SERIAL|RUN_TYPE=EXCLUSIVE|RUN_TYPE=DIST|RUN_TYPE=NIGHTLY|RUN_TYPE=EXCLUSIVE:NIGHTLY|RUN_TYPE=DIST:NIGHTLY|PROPERTIES[[:space:]]+TIMEOUT" || true` + RUNTYPE_ADD=`git diff -U0 upstream/$BRANCH ${PADDLE_ROOT}/${CMAKELISTS_FILE} |grep "^+" |grep -E "SERIAL|RUN_TYPE=EXCLUSIVE|RUN_TYPE=DIST|RUN_TYPE=HYBRID|RUN_TYPE=NIGHTLY|RUN_TYPE=EXCLUSIVE:NIGHTLY|RUN_TYPE=DIST:NIGHTLY|PROPERTIES[[:space:]]+TIMEOUT" || true` if [[ ${RUNTYPE_ADD} != "" ]];then RUNTYPE_ADD_LINES="${RUNTYPE_ADD_LINES}\n${CMAKELISTS_FILE}\n${RUNTYPE_ADD}\n" fi done if [[ ${RUNTYPE_ADD_LINES} != "" ]];then - echo_line="You must have one QA (XieYunshen(Recommend) or chalsliu) approval for setting parameter RUN_TYPE as EXCLUSIVE, DIST, NIGHTLY, EXCLUSIVE:NIGHTLY or DISTNIGHTLY, or setting parameter SERIAL, or setting TIMEOUT properties.\nThe corresponding lines are as follows:\n${RUNTYPE_ADD_LINES}\nFor more information, please refer to:https://github.com/PaddlePaddle/Paddle/wiki/PaddlePaddle-Unit-test-specification" + echo_line="You must have one QA (XieYunshen(Recommend) or chalsliu) approval for setting parameter RUN_TYPE as EXCLUSIVE, DIST, HYBRID, NIGHTLY, EXCLUSIVE:NIGHTLY or DISTNIGHTLY, or setting parameter SERIAL, or setting TIMEOUT properties.\nThe corresponding lines are as follows:\n${RUNTYPE_ADD_LINES}\nFor more information, please refer to:https://github.com/PaddlePaddle/Paddle/wiki/PaddlePaddle-Unit-test-specification" check_approval 1 XieYunshen chalsliu fi fi From 18298b9c18a622e910c869b688b849874da64b07 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Tue, 5 Dec 2023 17:33:59 +0800 Subject: [PATCH 17/24] fix docstring codes --- .../distributed/checkpoint/load_state_dict.py | 21 +++++++++++++++++-- .../distributed/checkpoint/save_state_dict.py | 11 +++++++--- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 82d2a9543a3d0f..7099375ed135de 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -359,8 +359,25 @@ def load_state_dict( use_dist: Whether to load the state_dict in distributed mode. Set True by default. Example: .. code-block:: python - import paddle - ... 
+ >>> # doctest: +REQUIRES(env: DISTRIBUTED) + >>> import paddle + >>> import paddle.distributed as dist + >>> ckpt_path = "./checkpoint" + >>> w1 = paddle.arange(32).reshape([4, 8]) + >>> mesh = dist.ProcessMesh([0, 1]) + >>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)]) + >>> state_dict = {"w1": sharded_w1} + >>> dist.save_state_dict(state_dict, ckpt_path) + >>> w1_to_load = paddle.zeros_like(w1) + >>> sharded_w1_to_load = dist.shard_tensor(w1, mesh, [dist.Replicate()]) + >>> state_dict_to_load = {"w1": sharded_w1_to_load} + >>> dist.load_state_dict(state_dict_to_load, ckpt_path) + >>> print(f"state_dict_to_load:{state_dict_to_load}") + state_dict_to_load:{'w1': Tensor(shape=[4, 8], dtype=int64, place=Place(gpu:0), stop_gradient=True, dist_attr={process_mesh: {shape: [2], process_ids: [0,1], dim_names: [d0]}, dims_mappings: [-1,-1], batch_dim: 0, dynamic_dims: [0,0], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].}, GlobalDenseTensor= + [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31]])} """ if not use_dist and ( paddle.distributed.get_world_size() > 1 or coordinator_rank != 0 diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 1f121727c3bf7e..2bb640bc955848 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -86,9 +86,14 @@ def save_state_dict( Examples: .. code-block:: python - - import paddle - ... + >>> # doctest: +REQUIRES(env: DISTRIBUTED) + >>> import paddle + >>> import paddle.distributed as dist + >>> w1 = paddle.arange(32).reshape([4, 8]) + >>> mesh = dist.ProcessMesh([0, 1]) + >>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) + >>> state_dict = {"w1": sharded_w1} + >>> dist.save_state_dict(state_dict, "./checkpoint") """ if not use_dist and ( From 1dcd0a7faf3c341e7f556931e9c04c0fe9cb63db Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Wed, 6 Dec 2023 11:24:56 +0800 Subject: [PATCH 18/24] rename and codestyle --- .../distributed/checkpoint/load_state_dict.py | 39 ++++++++++--------- .../paddle/distributed/checkpoint/metadata.py | 2 +- .../distributed/checkpoint/save_state_dict.py | 25 ++++++------ .../hybrid_strategy/load_state_dict.py | 7 ---- 4 files changed, 35 insertions(+), 38 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 7099375ed135de..0c6a4f6c0ae424 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -50,9 +50,9 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): for local_tensor_index, file_name in metadata.storage_metadata.items(): assert ( local_tensor_index not in tensor_id_list - ), f"Duplicate tensor_id:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata." - tensor_id_list.append(local_tensor_index.tensor_id) - if local_tensor_index.tensor_id in state_dict: + ), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata." 
+ tensor_id_list.append(local_tensor_index.tensor_key) + if local_tensor_index.tensor_key in state_dict: necessary_files.append(file_name) necessary_data_files_set = set(necessary_files) # allgather all accessible files @@ -279,15 +279,15 @@ def get_read_items(path, state_dict, process_group, use_dist): for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) for ( - tensor_id, + tensor_key, local_tensor_metadata, ) in metadata.state_dict_metadata.items(): - if tensor_id not in storage_state_dict_metadata: - storage_state_dict_metadata[tensor_id] = [] - storage_state_dict_metadata[tensor_id] += local_tensor_metadata + if tensor_key not in storage_state_dict_metadata: + storage_state_dict_metadata[tensor_key] = [] + storage_state_dict_metadata[tensor_key] += local_tensor_metadata read_items = [] logger.debug(f"storage_state_dict_metadata:{storage_state_dict_metadata}") - for tensor_id, val in state_dict.items(): + for tensor_key, val in state_dict.items(): if isinstance(val, paddle.Tensor): if val.is_dist(): ( @@ -305,10 +305,10 @@ def get_read_items(path, state_dict, process_group, use_dist): continue cur_chunk_metadata = LocalTensorMetadata(global_offset, local_shape) assert ( - tensor_id in storage_state_dict_metadata - ), f"tensor_id:{tensor_id} not found in storage_state_dict_metadata:{storage_state_dict_metadata}." + tensor_key in storage_state_dict_metadata + ), f"tensor_key:{tensor_key} not found in storage_state_dict_metadata:{storage_state_dict_metadata}." for storage_local_tensor_metadata in storage_state_dict_metadata[ - tensor_id + tensor_key ]: if not_overlap( cur_chunk_metadata, storage_local_tensor_metadata @@ -318,7 +318,7 @@ def get_read_items(path, state_dict, process_group, use_dist): cur_chunk_metadata, storage_local_tensor_metadata ) storage_local_tensor_index = LocalTensorIndex( - tensor_id, + tensor_key, tuple(storage_local_tensor_metadata.global_offset), ) read_items.append( @@ -359,7 +359,7 @@ def load_state_dict( use_dist: Whether to load the state_dict in distributed mode. Set True by default. Example: .. 
code-block:: python - >>> # doctest: +REQUIRES(env: DISTRIBUTED) + >>> # doctest: +SKIP('Load state dict.') >>> import paddle >>> import paddle.distributed as dist >>> ckpt_path = "./checkpoint" @@ -378,6 +378,7 @@ def load_state_dict( [8 , 9 , 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31]])} + >>> # doctest: -SKIP """ if not use_dist and ( paddle.distributed.get_world_size() > 1 or coordinator_rank != 0 @@ -425,9 +426,9 @@ def load_state_dict( os.path.join(path, file_name) ) storage_state_dict = storage_file_to_state_dict[file_name] - assert item.local_tensor_index.tensor_id in storage_state_dict + assert item.local_tensor_index.tensor_key in storage_state_dict storage_local_tensor = storage_state_dict[ - item.local_tensor_index.tensor_id + item.local_tensor_index.tensor_key ] storage_offsets = item.storage_offset storage_lengths = item.lengths @@ -447,12 +448,12 @@ def load_state_dict( # The read item rank need to be assigned if item.rank == paddle.distributed.get_rank(): assert ( - item.local_tensor_index.tensor_id in state_dict + item.local_tensor_index.tensor_key in state_dict ), f"item:{item}, state_dict:{state_dict}" cur_local_tensor = ( - state_dict[item.local_tensor_index.tensor_id]._local_value() + state_dict[item.local_tensor_index.tensor_key]._local_value() if use_dist - else state_dict[item.local_tensor_index.tensor_id] + else state_dict[item.local_tensor_index.tensor_key] ) cur_offsets = item.cur_offset cur_lengths = item.lengths @@ -470,7 +471,7 @@ def load_state_dict( else: cur_chunk_tensor = paddle.zeros( item.lengths, - dtype=state_dict[item.local_tensor_index.tensor_id].dtype, + dtype=state_dict[item.local_tensor_index.tensor_key].dtype, ) if src_rank == item.rank: diff --git a/python/paddle/distributed/checkpoint/metadata.py b/python/paddle/distributed/checkpoint/metadata.py index 74a17a0a63ada8..fb9a39a559f554 100644 --- a/python/paddle/distributed/checkpoint/metadata.py +++ b/python/paddle/distributed/checkpoint/metadata.py @@ -32,7 +32,7 @@ class LocalTensorIndex: The identifier of a local tensor. """ - tensor_id: str + tensor_key: str global_offset: Tuple[int] diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 2bb640bc955848..8ab8b0cf0aceb7 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -45,7 +45,7 @@ def check_file_name(file_name, process_group): ), f"id:{id} != all_unique_id[0]:{file_name}" -def merge_state_dict(global_state_dict): +def merge_state_dict_metadata(global_state_dict): assert isinstance( global_state_dict, List ), "The global_state_dict should be a list." @@ -61,7 +61,7 @@ def merge_state_dict(global_state_dict): return out -def dedup_state_dict(global_state_dict): +def dedup_storage_metadata(global_state_dict): out = {} for state_dict in global_state_dict: for key, val in state_dict.items(): @@ -81,12 +81,12 @@ def save_state_dict( state_dict: The state_dict to save. path: The directory to save state_dict. process_group: ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. - coordinator_rank: The rank used to coordinate the checkpoint. Rank0 is used by default. + coordinator_rank: The rank used to save non distributed values. Rank0 is used by default. use_dist: Whether to save the state_dict in distributed mode. Set True by default. Examples: .. 
code-block:: python - >>> # doctest: +REQUIRES(env: DISTRIBUTED) + >>> # doctest: +SKIP('Save state dict.') >>> import paddle >>> import paddle.distributed as dist >>> w1 = paddle.arange(32).reshape([4, 8]) @@ -94,6 +94,7 @@ def save_state_dict( >>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) >>> state_dict = {"w1": sharded_w1} >>> dist.save_state_dict(state_dict, "./checkpoint") + >>> # doctest: -SKIP """ if not use_dist and ( @@ -133,7 +134,7 @@ def save_state_dict( check_state_dict(state_dict, process_group) metadata = Metadata() local_state_dict = {} - local_tensor_metadata = {} + local_state_dict_metadata = {} local_storage_metadata = {} for key, val in state_dict.items(): if isinstance(val, paddle.Tensor): @@ -157,27 +158,29 @@ def save_state_dict( local_shape = val.shape local_tensor = val local_state_dict[key] = local_tensor - local_tensor_metadata[key] = LocalTensorMetadata( + local_state_dict_metadata[key] = LocalTensorMetadata( global_offset, local_shape ) local_storage_metadata[ LocalTensorIndex(key, tuple(global_offset)) ] = file_name - global_tensor_metadata = [] + global_state_dict_metadata = [] global_storage_metadata = [] if use_dist: paddle.distributed.all_gather_object( - global_tensor_metadata, local_tensor_metadata, process_group + global_state_dict_metadata, local_state_dict_metadata, process_group ) paddle.distributed.all_gather_object( global_storage_metadata, local_storage_metadata, process_group ) else: - global_tensor_metadata.append(local_tensor_metadata) + global_state_dict_metadata.append(local_state_dict_metadata) global_storage_metadata.append(local_storage_metadata) - metadata.state_dict_metadata = merge_state_dict(global_tensor_metadata) - metadata.storage_metadata = dedup_state_dict(global_storage_metadata) + metadata.state_dict_metadata = merge_state_dict_metadata( + global_state_dict_metadata + ) + metadata.storage_metadata = dedup_storage_metadata(global_storage_metadata) if coordinator_rank == paddle.distributed.get_rank(): logger.debug(f"metadata:{metadata}") paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata")) diff --git a/test/auto_parallel/hybrid_strategy/load_state_dict.py b/test/auto_parallel/hybrid_strategy/load_state_dict.py index 0f5d2f65e89ac6..be9fa35eee33bf 100644 --- a/test/auto_parallel/hybrid_strategy/load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/load_state_dict.py @@ -58,7 +58,6 @@ def test_load_state_dict_with_one_device(self): ) for k, v in state_dict.items(): assert k in expect_state_dict, k - print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") self.check_tensor_eq(v, expect_state_dict[k]) def test_load_state_dict_with_four_devices(self): @@ -86,7 +85,6 @@ def test_load_state_dict_with_four_devices(self): ) for k, v in state_dict.items(): assert k in expect_state_dict, k - print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") self.check_tensor_eq(v._local_value(), expect_state_dict[k]) def test_load_state_dict_with_two_devices(self): @@ -110,7 +108,6 @@ def test_load_state_dict_with_two_devices(self): ) for k, v in state_dict.items(): assert k in expect_state_dict, k - print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") self.check_tensor_eq(v._local_value(), expect_state_dict[k]) def test_load_state_dict_with_eight_devices(self): @@ -136,9 +133,6 @@ def test_load_state_dict_with_eight_devices(self): offset + length for offset, length in zip(global_offset, local_shape) ] - print( - f"local_shape:{local_shape}, 
global_offset:{global_offset}, end_offset:{end_offset}" - ) expect_w1 = paddle.slice( saved_w1, axes=[0, 1], starts=global_offset, ends=end_offset ) @@ -151,7 +145,6 @@ def test_load_state_dict_with_eight_devices(self): ) for k, v in state_dict.items(): assert k in expect_state_dict, k - print(f"k:{k}, v:{v}, expect_state_dict[k]:{expect_state_dict[k]}") self.check_tensor_eq(v._local_value(), expect_state_dict[k]) def check_tensor_eq(self, a, b, verbose=True): From c7284007dec9e17b4c556f9eac483d4aac4683b0 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Wed, 6 Dec 2023 12:20:55 +0800 Subject: [PATCH 19/24] get rid of use_dist argument --- .../distributed/checkpoint/load_state_dict.py | 15 +++++---------- .../distributed/checkpoint/save_state_dict.py | 15 +++++---------- .../hybrid_strategy/load_state_dict.py | 2 +- .../hybrid_strategy/save_state_dict.py | 2 +- 4 files changed, 12 insertions(+), 22 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 0c6a4f6c0ae424..3e2ad5b64e5465 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -347,7 +347,7 @@ def get_read_items(path, state_dict, process_group, use_dist): def load_state_dict( - state_dict, path, process_group=None, coordinator_rank=0, use_dist=True + state_dict, path, process_group=None, coordinator_rank=0 ) -> None: """ Load the state_dict inplace from a checkpoint path. @@ -356,7 +356,6 @@ def load_state_dict( path: The directory to load checkpoint files. process_group: ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. coordinator_rank: The rank used to coordinate the checkpoint. Rank0 is used by default. - use_dist: Whether to load the state_dict in distributed mode. Set True by default. Example: .. code-block:: python >>> # doctest: +SKIP('Load state dict.') @@ -380,12 +379,6 @@ def load_state_dict( [24, 25, 26, 27, 28, 29, 30, 31]])} >>> # doctest: -SKIP """ - if not use_dist and ( - paddle.distributed.get_world_size() > 1 or coordinator_rank != 0 - ): - raise ValueError( - f"use_dist is False, please set coordinator_rank to 0 and paddle.distributed.get_world_size() to 1, world_size:{paddle.distributed.get_world_size()}, coordinator_rank:{coordinator_rank}" - ) assert isinstance( state_dict, dict ), "The state_dict should be a dictionary." 
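# Illustrative sketch only (not taken from the diffs): how the renamed metadata types fit
# together. The module path and the positional field order are assumptions inferred from the
# calls above, i.e. LocalTensorMetadata(global_offset, local_shape) and
# LocalTensorIndex(tensor_key, global_offset).
from paddle.distributed.checkpoint.metadata import (
    LocalTensorIndex,
    LocalTensorMetadata,
)

# A [2, 8] shard of "w1" whose first row sits at row 2 of the global tensor:
shard_metadata = LocalTensorMetadata((2, 0), (2, 8))
shard_index = LocalTensorIndex("w1", (2, 0))
# save_state_dict collects one such metadata entry per shard under metadata.state_dict_metadata
# and maps the index to the data file holding that shard under metadata.storage_metadata;
# load_state_dict intersects a requested shard with these records to build its read items.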
@@ -396,9 +389,11 @@ def load_state_dict( val, paddle.Tensor ), "Only support dygraph Tensor now, support static DistributedTensor later" - if use_dist and process_group is None: + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist and process_group is None and not is_initialized(): # Init the default global process group - not is_initialized() and paddle.distributed.init_parallel_env() + paddle.distributed.init_parallel_env() rank_to_files = get_rank_to_files(path, state_dict, process_group, use_dist) local_load_files = get_local_load_files(rank_to_files) diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 8ab8b0cf0aceb7..6545dde5ba810d 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -72,7 +72,7 @@ def dedup_storage_metadata(global_state_dict): def save_state_dict( - state_dict, path, process_group=None, coordinator_rank=0, use_dist=True + state_dict, path, process_group=None, coordinator_rank=0 ) -> None: """ Save the state_dict of model to path. @@ -82,7 +82,6 @@ def save_state_dict( path: The directory to save state_dict. process_group: ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. coordinator_rank: The rank used to save non distributed values. Rank0 is used by default. - use_dist: Whether to save the state_dict in distributed mode. Set True by default. Examples: .. code-block:: python @@ -97,12 +96,6 @@ def save_state_dict( >>> # doctest: -SKIP """ - if not use_dist and ( - paddle.distributed.get_world_size() > 1 or coordinator_rank != 0 - ): - raise ValueError( - f"use_dist is False, please set coordinator_rank to 0 and paddle.distributed.get_world_size() to 1, world_size:{paddle.distributed.get_world_size()}, coordinator_rank:{coordinator_rank}" - ) assert isinstance( state_dict, dict ), "The state_dict should be a dictionary." 
@@ -116,9 +109,11 @@ def save_state_dict( if not os.path.exists(path): os.makedirs(path, exist_ok=True) - if use_dist and process_group is None: + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist and process_group is None and not is_initialized(): # Init the default global process group - not is_initialized() and paddle.distributed.init_parallel_env() + paddle.distributed.init_parallel_env() unique_id = 0 file_name = "" diff --git a/test/auto_parallel/hybrid_strategy/load_state_dict.py b/test/auto_parallel/hybrid_strategy/load_state_dict.py index be9fa35eee33bf..944f11f99a9216 100644 --- a/test/auto_parallel/hybrid_strategy/load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/load_state_dict.py @@ -49,7 +49,7 @@ def test_load_state_dict_with_one_device(self): w1 = paddle.zeros_like(saved_w1) w2 = paddle.zeros_like(saved_w2) state_dict = dict(zip(list(global_state_dict.keys()), [w1, w2])) - load_state_dict(state_dict, ckpt_path(), use_dist=False) + load_state_dict(state_dict, ckpt_path()) # check expect_w1 = saved_w1 expect_w2 = saved_w2 diff --git a/test/auto_parallel/hybrid_strategy/save_state_dict.py b/test/auto_parallel/hybrid_strategy/save_state_dict.py index 09ab621f9366b7..7d750c08f4665f 100644 --- a/test/auto_parallel/hybrid_strategy/save_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/save_state_dict.py @@ -37,7 +37,7 @@ def test_save_state_dict_with_one_device(self): keys = list(global_state_dict.keys()) w1, w2 = list(global_state_dict.values()) state_dict = dict(zip(keys, [w1, w2])) - save_state_dict(state_dict, ckpt_path(), use_dist=False) + save_state_dict(state_dict, ckpt_path()) def test_save_state_dict_with_four_devices(self): global_state_dict = get_global_state_dict() From a3125c094ee7442803792e5d75c1d67fa42c3b8b Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Wed, 6 Dec 2023 12:49:02 +0800 Subject: [PATCH 20/24] fix copyright --- python/paddle/distributed/checkpoint/__init__.py | 2 +- .../distributed/checkpoint/load_state_dict.py | 2 +- python/paddle/distributed/checkpoint/metadata.py | 2 +- .../distributed/checkpoint/save_state_dict.py | 2 +- python/paddle/distributed/checkpoint/utils.py | 2 +- .../hybrid_strategy/load_state_dict.py | 13 ------------- .../hybrid_strategy/save_state_dict.py | 2 +- .../hybrid_strategy/test_save_load_state_dict.py | 2 +- 8 files changed, 7 insertions(+), 20 deletions(-) diff --git a/python/paddle/distributed/checkpoint/__init__.py b/python/paddle/distributed/checkpoint/__init__.py index 63a317bd0a4b7d..da89b737adfb8d 100644 --- a/python/paddle/distributed/checkpoint/__init__.py +++ b/python/paddle/distributed/checkpoint/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 3e2ad5b64e5465..dce468eda2507f 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
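With the `use_dist` flag gone, the calling convention is the same on one card and on many:
distributed mode is inferred from `paddle.distributed.get_world_size()`, and the default
process group is initialized lazily when needed. A minimal sketch of the resulting usage
(illustrative only, mirroring the one-device test above):

    import paddle
    import paddle.distributed as dist

    w1 = paddle.arange(32).reshape([4, 8])
    dist.save_state_dict({"w1": w1}, "./checkpoint")

    w1_to_load = paddle.zeros_like(w1)
    dist.load_state_dict({"w1": w1_to_load}, "./checkpoint")
    # w1_to_load is filled in place; under a multi-card launch the same two calls
    # coordinate across ranks without any extra argument.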
diff --git a/python/paddle/distributed/checkpoint/metadata.py b/python/paddle/distributed/checkpoint/metadata.py index fb9a39a559f554..4eb5d559a9c0c4 100644 --- a/python/paddle/distributed/checkpoint/metadata.py +++ b/python/paddle/distributed/checkpoint/metadata.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 6545dde5ba810d..f88b9bf0b9d1fd 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/paddle/distributed/checkpoint/utils.py b/python/paddle/distributed/checkpoint/utils.py index 1262d5d9483517..32b95198d135ae 100644 --- a/python/paddle/distributed/checkpoint/utils.py +++ b/python/paddle/distributed/checkpoint/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/test/auto_parallel/hybrid_strategy/load_state_dict.py b/test/auto_parallel/hybrid_strategy/load_state_dict.py index 944f11f99a9216..93b088e7c33324 100644 --- a/test/auto_parallel/hybrid_strategy/load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/load_state_dict.py @@ -12,19 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import os import numpy as np diff --git a/test/auto_parallel/hybrid_strategy/save_state_dict.py b/test/auto_parallel/hybrid_strategy/save_state_dict.py index 7d750c08f4665f..e908d233cc77d7 100644 --- a/test/auto_parallel/hybrid_strategy/save_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/save_state_dict.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py index 79611b28fc1e72..7122ad18fdbc3d 100644 --- a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 0543d1f97d32c1e92601d8e3ed99229fad0ac795 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Wed, 6 Dec 2023 14:52:01 +0800 Subject: [PATCH 21/24] polish doc --- .../distributed/checkpoint/load_state_dict.py | 23 +++++++++++-------- .../distributed/checkpoint/save_state_dict.py | 15 +++++++----- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index dce468eda2507f..38930ef5066bd3 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -15,7 +15,7 @@ import copy import os from dataclasses import dataclass -from typing import Tuple +from typing import Dict, Tuple import paddle from paddle.distributed.communication.group import is_initialized @@ -43,15 +43,15 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): len(metadata_files) > 0 ), "No metadata file found in the checkpoint directory:{path}." # The neccesary files to be read - tensor_id_list = [] + tensor_key_list = [] necessary_files = [] for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) for local_tensor_index, file_name in metadata.storage_metadata.items(): assert ( - local_tensor_index not in tensor_id_list + local_tensor_index not in tensor_key_list ), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata." - tensor_id_list.append(local_tensor_index.tensor_key) + tensor_key_list.append(local_tensor_index.tensor_key) if local_tensor_index.tensor_key in state_dict: necessary_files.append(file_name) necessary_data_files_set = set(necessary_files) @@ -78,7 +78,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): global_data_files_set & necessary_data_files_set == necessary_data_files_set ), f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{necessary_data_files_set}" - missing_keys = set(state_dict.keys()) - set(tensor_id_list) + missing_keys = set(state_dict.keys()) - set(tensor_key_list) if len(missing_keys) > 0: logger.warning( f"Missing keys:{missing_keys}, check whether the checkpoint is complete." @@ -347,15 +347,18 @@ def get_read_items(path, state_dict, process_group, use_dist): def load_state_dict( - state_dict, path, process_group=None, coordinator_rank=0 + state_dict: Dict[str, paddle.Tensor], + path: str, + process_group: paddle.distributed.collective.Group = None, + coordinator_rank: int = 0, ) -> None: """ Load the state_dict inplace from a checkpoint path. Args: - state_dict: The state_dict to load. It will be modified inplace after loading. - path: The directory to load checkpoint files. - process_group: ProcessGroup to be used for cross-rank synchronization. 
Use the default process group which contains all cards. - coordinator_rank: The rank used to coordinate the checkpoint. Rank0 is used by default. + state_dict(Dict[str, paddle.Tensor]): The state_dict to load. It will be modified inplace after loading. + path(str): The directory to load checkpoint files. + process_group(paddle.distributed.collective.Group): ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. + coordinator_rank(int): The rank used to coordinate the checkpoint. Rank0 is used by default. Example: .. code-block:: python >>> # doctest: +SKIP('Load state dict.') diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index f88b9bf0b9d1fd..f61e22ccd29d52 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -13,7 +13,7 @@ # limitations under the License. import os -from typing import List +from typing import Dict, List import paddle from paddle.distributed.communication.group import is_initialized @@ -72,16 +72,19 @@ def dedup_storage_metadata(global_state_dict): def save_state_dict( - state_dict, path, process_group=None, coordinator_rank=0 + state_dict: Dict[str, paddle.Tensor], + path: str, + process_group: paddle.distributed.collective.Group = None, + coordinator_rank: int = 0, ) -> None: """ Save the state_dict of model to path. Args: - state_dict: The state_dict to save. - path: The directory to save state_dict. - process_group: ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. - coordinator_rank: The rank used to save non distributed values. Rank0 is used by default. + state_dict(Dict[str, paddle.Tensor]): The state_dict to save. + path(str): The directory to save state_dict. + process_group(paddle.distributed.collective.Group): ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. + coordinator_rank(int): The rank used to save non distributed values. Rank0 is used by default. Examples: .. code-block:: python From e4c72cdad3f6af9dd23417ef27b375db8f6801a4 Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Wed, 6 Dec 2023 19:04:19 +0800 Subject: [PATCH 22/24] polish --- .../paddle/distributed/checkpoint/load_state_dict.py | 10 +++++----- .../paddle/distributed/checkpoint/save_state_dict.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 38930ef5066bd3..c256c60ce5dd73 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -15,7 +15,7 @@ import copy import os from dataclasses import dataclass -from typing import Dict, Tuple +from typing import Tuple import paddle from paddle.distributed.communication.group import is_initialized @@ -347,10 +347,10 @@ def get_read_items(path, state_dict, process_group, use_dist): def load_state_dict( - state_dict: Dict[str, paddle.Tensor], - path: str, - process_group: paddle.distributed.collective.Group = None, - coordinator_rank: int = 0, + state_dict, + path, + process_group=None, + coordinator_rank=0, ) -> None: """ Load the state_dict inplace from a checkpoint path. 
diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index f61e22ccd29d52..919f09e6ee474b 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -13,7 +13,7 @@ # limitations under the License. import os -from typing import Dict, List +from typing import List import paddle from paddle.distributed.communication.group import is_initialized @@ -72,10 +72,10 @@ def dedup_storage_metadata(global_state_dict): def save_state_dict( - state_dict: Dict[str, paddle.Tensor], + state_dict, path: str, - process_group: paddle.distributed.collective.Group = None, - coordinator_rank: int = 0, + process_group=None, + coordinator_rank=0, ) -> None: """ Save the state_dict of model to path. From 0561180f4c4a755644de65f2cf421de52d78615a Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Wed, 6 Dec 2023 19:13:14 +0800 Subject: [PATCH 23/24] polish --- python/paddle/distributed/checkpoint/save_state_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 919f09e6ee474b..4b7f3665d86da2 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -73,7 +73,7 @@ def dedup_storage_metadata(global_state_dict): def save_state_dict( state_dict, - path: str, + path, process_group=None, coordinator_rank=0, ) -> None: From 4df7f76b50f5640f471870c7bd2c111be67eddff Mon Sep 17 00:00:00 2001 From: pangengzheng Date: Wed, 6 Dec 2023 20:26:47 +0800 Subject: [PATCH 24/24] use tmp file path --- .../distributed/checkpoint/load_state_dict.py | 4 ++-- .../hybrid_strategy/load_state_dict.py | 14 ++++++----- .../hybrid_strategy/save_state_dict.py | 13 ++++------- .../semi_auto_parallel_simple_net_dp_mp.py | 7 +++--- .../semi_auto_parallel_simple_net_dp_mp_pp.py | 7 +++--- .../test_save_load_state_dict.py | 23 +++++++++++-------- ...test_semi_auto_parallel_hybrid_strategy.py | 11 +++++---- 7 files changed, 41 insertions(+), 38 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index c256c60ce5dd73..153c6764d70d60 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -41,7 +41,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): ] assert ( len(metadata_files) > 0 - ), "No metadata file found in the checkpoint directory:{path}." + ), f"No metadata file found in the checkpoint directory:{path}." 
# The neccesary files to be read tensor_key_list = [] necessary_files = [] @@ -163,7 +163,7 @@ def update(rank_to_read_files, rank_to_not_read_files, rank_file): if f not in file_to_ranks: file_to_ranks[f] = [] file_to_ranks[f].append(r) - logger.info(f"file_to_ranks:{file_to_ranks}") + logger.debug(f"file_to_ranks:{file_to_ranks}") if file in file_to_ranks: for r in file_to_ranks[file]: rank_to_not_read_files[r].remove(file) diff --git a/test/auto_parallel/hybrid_strategy/load_state_dict.py b/test/auto_parallel/hybrid_strategy/load_state_dict.py index 93b088e7c33324..c500853324e713 100644 --- a/test/auto_parallel/hybrid_strategy/load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/load_state_dict.py @@ -16,7 +16,6 @@ import numpy as np from auto_parallel.hybrid_strategy.save_state_dict import ( - ckpt_path, get_global_state_dict, ) @@ -30,13 +29,16 @@ class TestLoadStateDict: + def __init__(self): + self._ckpt_path = os.getenv("ckpt_path") + def test_load_state_dict_with_one_device(self): global_state_dict = get_global_state_dict() saved_w1, saved_w2 = list(global_state_dict.values()) w1 = paddle.zeros_like(saved_w1) w2 = paddle.zeros_like(saved_w2) state_dict = dict(zip(list(global_state_dict.keys()), [w1, w2])) - load_state_dict(state_dict, ckpt_path()) + load_state_dict(state_dict, self._ckpt_path) # check expect_w1 = saved_w1 expect_w2 = saved_w2 @@ -62,7 +64,7 @@ def test_load_state_dict_with_four_devices(self): state_dict = dict( zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2]) ) - load_state_dict(state_dict, ckpt_path()) + load_state_dict(state_dict, self._ckpt_path) # check cur_rank = paddle.distributed.get_rank() expect_w1 = saved_w1.split(4, axis=0)[cur_rank] @@ -85,7 +87,7 @@ def test_load_state_dict_with_two_devices(self): state_dict = dict( zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2]) ) - load_state_dict(state_dict, ckpt_path()) + load_state_dict(state_dict, self._ckpt_path) # check cur_rank = paddle.distributed.get_rank() expect_w1 = saved_w1.split(2, axis=0)[cur_rank] @@ -108,7 +110,7 @@ def test_load_state_dict_with_eight_devices(self): state_dict = dict( zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2]) ) - load_state_dict(state_dict, ckpt_path()) + load_state_dict(state_dict, self._ckpt_path) # check cur_rank = paddle.distributed.get_rank() local_shape, global_offset = compute_local_shape_and_global_offset( @@ -150,7 +152,7 @@ def run_test_case(self): elif device_num == 8: self.test_load_state_dict_with_eight_devices() else: - raise ValueError("device_num should be 2,4 or 8") + raise ValueError("device_num should be 1, 2, 4 or 8") if __name__ == '__main__': diff --git a/test/auto_parallel/hybrid_strategy/save_state_dict.py b/test/auto_parallel/hybrid_strategy/save_state_dict.py index e908d233cc77d7..0fd2f5d7049dbf 100644 --- a/test/auto_parallel/hybrid_strategy/save_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/save_state_dict.py @@ -25,19 +25,16 @@ def get_global_state_dict(): return {"w1": w1, "w2": w2} -def ckpt_path(): - return os.path.join( - os.path.dirname(os.path.abspath(__file__)), "tmp_ckpt_output" - ) - - class TestSaveStateDict: + def __init__(self): + self._ckpt_path = os.getenv("ckpt_path") + def test_save_state_dict_with_one_device(self): global_state_dict = get_global_state_dict() keys = list(global_state_dict.keys()) w1, w2 = list(global_state_dict.values()) state_dict = dict(zip(keys, [w1, w2])) - save_state_dict(state_dict, ckpt_path()) + save_state_dict(state_dict, self._ckpt_path) def 
test_save_state_dict_with_four_devices(self): global_state_dict = get_global_state_dict() @@ -52,7 +49,7 @@ def test_save_state_dict_with_four_devices(self): w2, mesh2, [dist.Shard(0), dist.Replicate()] ) state_dict = dict(zip(keys, [sharded_w1, sharded_w2])) - save_state_dict(state_dict, ckpt_path()) + save_state_dict(state_dict, self._ckpt_path) def run_test_case(self): device_num = int(os.getenv("device_num")) diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py index b4ba198c65b3c1..a3c2938a7370f9 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py @@ -14,7 +14,6 @@ import os -from auto_parallel.hybrid_strategy.save_state_dict import ckpt_path from auto_parallel.semi_auto_parallel_simple_net import ( DemoNet, TestSimpleNetForSemiAutoParallel, @@ -31,6 +30,7 @@ def __init__(self): self._dtype = os.getenv("dtype") self._backend = os.getenv("backend") self._seed = eval(os.getenv("seed")) + self._ckpt_path = os.getenv("ckpt_path") self._mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) paddle.set_device(self._backend) @@ -61,14 +61,13 @@ def test_dp_mp_demo_net(self): local_state_dict = {} for k, v in state_dict.items(): local_state_dict[k] = v._local_value().clone() - paddle.distributed.save_state_dict(state_dict, ckpt_path()) + paddle.distributed.save_state_dict(state_dict, self._ckpt_path) for k, v in state_dict.items(): v._local_value().add_(paddle.ones_like(v._local_value())) - paddle.distributed.load_state_dict(state_dict, ckpt_path()) + paddle.distributed.load_state_dict(state_dict, self._ckpt_path) for k, v in state_dict.items(): assert k in local_state_dict, k self.check_tensor_eq(v._local_value(), local_state_dict[k]) - os.system(f"rm -rf {ckpt_path()}") def run_test_case(self): self.test_dp_mp_demo_net() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py index 2634f78c7b30ab..ddbc66e080b2b7 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py @@ -14,7 +14,6 @@ import os -from auto_parallel.hybrid_strategy.save_state_dict import ckpt_path from auto_parallel.semi_auto_parallel_simple_net import ( DemoNet, TestSimpleNetForSemiAutoParallel, @@ -32,6 +31,7 @@ def __init__(self): self._dtype = os.getenv("dtype") self._backend = os.getenv("backend") self._seed = eval(os.getenv("seed")) + self._ckpt_path = os.getenv("ckpt_path") self._mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) self._pp_mesh0 = dist.ProcessMesh( [[0, 1], [2, 3]], dim_names=["x", "y"] @@ -111,15 +111,14 @@ def test_dp_mp_pp_demo_net(self): local_state_dict[k] = ( v._local_value().clone() if v._is_initialized() else None ) - paddle.distributed.save_state_dict(state_dict, ckpt_path()) + paddle.distributed.save_state_dict(state_dict, self._ckpt_path) for k, v in state_dict.items(): v._local_value().add_(paddle.ones_like(v._local_value())) - paddle.distributed.load_state_dict(state_dict, ckpt_path()) + paddle.distributed.load_state_dict(state_dict, self._ckpt_path) for k, v in state_dict.items(): assert k in local_state_dict, k if v._is_initialized(): self.check_tensor_eq(v._local_value(), local_state_dict[k]) - os.system(f"rm -rf 
{ckpt_path()}") def run_test_case(self): self.test_dp_mp_pp_demo_net() diff --git a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py index 7122ad18fdbc3d..a0b64a374d6274 100644 --- a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +import tempfile import unittest import collective.test_communication_api_base as test_base -from auto_parallel.hybrid_strategy.save_state_dict import ckpt_path class TestSaveLoadStateDict(test_base.CommunicationTestDistBase): @@ -26,10 +25,11 @@ def setUp(self): def test_save_load_state_dict(self): # save with 1 device - os.system(f"rm -rf {ckpt_path()}") + ckpt_path = tempfile.TemporaryDirectory() super().setUp(num_of_devices=1, timeout=120, nnode=1) self.run_test_case( - "save_state_dict.py", user_defined_envs={"device_num": "1"} + "save_state_dict.py", + user_defined_envs={"device_num": "1", "ckpt_path": ckpt_path.name}, ) # load with 1, 2, 4, 8 devices @@ -37,38 +37,41 @@ def test_save_load_state_dict(self): self._default_envs, self._changeable_envs ) for envs in envs_list: + envs["ckpt_path"] = ckpt_path.name super().setUp( num_of_devices=int(envs["device_num"]), - timeout=120, + timeout=180, nnode=1, ) self.run_test_case( "load_state_dict.py", user_defined_envs=envs, ) - os.system(f"rm -rf {ckpt_path()}") + ckpt_path.cleanup() # save with 4 devices - os.system(f"rm -rf {ckpt_path()}") + ckpt_path = tempfile.TemporaryDirectory() super().setUp(num_of_devices=4, timeout=120, nnode=1) self.run_test_case( - "save_state_dict.py", user_defined_envs={"device_num": "4"} + "save_state_dict.py", + user_defined_envs={"device_num": "4", "ckpt_path": ckpt_path.name}, ) # load with 1, 2, 4, 8 devices envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs ) for envs in envs_list: + envs["ckpt_path"] = ckpt_path.name super().setUp( num_of_devices=int(envs["device_num"]), - timeout=120, + timeout=180, nnode=1, ) self.run_test_case( "load_state_dict.py", user_defined_envs=envs, ) - os.system(f"rm -rf {ckpt_path()}") + ckpt_path.cleanup() if __name__ == '__main__': diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py index defaf6227f03ea..21da5d0a694425 100644 --- a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py +++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os +import tempfile import unittest import collective.test_communication_api_base as test_base -from auto_parallel.hybrid_strategy.save_state_dict import ckpt_path class TestSemiAutoParallelDPMPStrategy(test_base.CommunicationTestDistBase): @@ -27,17 +26,19 @@ def setUp(self): "seed": "2023", } self._changeable_envs = {"backend": ["gpu"]} - os.system(f"rm -rf {ckpt_path()}") def test_simple_net_bybrid_strategy(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs ) for envs in envs_list: + ckpt_path = tempfile.TemporaryDirectory() + envs["ckpt_path"] = ckpt_path.name self.run_test_case( "semi_auto_parallel_simple_net_dp_mp.py", user_defined_envs=envs, ) + ckpt_path.cleanup() class TestSemiAutoParallelHybridStrategy(test_base.CommunicationTestDistBase): @@ -52,17 +53,19 @@ def setUp(self): "seed": "2023", } self._changeable_envs = {"backend": ["gpu"]} - os.system(f"rm -rf {ckpt_path()}") def test_simple_net_bybrid_strategy(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs ) for envs in envs_list: + ckpt_path = tempfile.TemporaryDirectory() + envs["ckpt_path"] = ckpt_path.name self.run_test_case( "semi_auto_parallel_simple_net_dp_mp_pp.py", user_defined_envs=envs, ) + ckpt_path.cleanup() if __name__ == "__main__":
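Taken together, the series leaves the tests with a simple pattern: the single-process test
driver creates one `tempfile.TemporaryDirectory()` and exports its name to every rank through
the `ckpt_path` environment variable, and each rank runs the save/load round trip against that
shared path. A sketch of that per-rank round trip (illustrative only; the tensor shape, mesh,
and placements follow the docstring examples above):

    import os

    import paddle
    import paddle.distributed as dist

    ckpt_path = os.getenv("ckpt_path")  # set by the launcher from a TemporaryDirectory
    mesh = dist.ProcessMesh([0, 1])
    w1 = paddle.arange(32).reshape([4, 8])
    sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
    dist.save_state_dict({"w1": sharded_w1}, ckpt_path)

    w1_to_load = dist.shard_tensor(paddle.zeros_like(w1), mesh, [dist.Replicate()])
    state_dict_to_load = {"w1": w1_to_load}
    dist.load_state_dict(state_dict_to_load, ckpt_path)
    # after loading, every rank holds a replicated copy of the values saved above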