Skip to content

Commit

Permalink
Fix state dict api en doc (PaddlePaddle#59793)
Browse files Browse the repository at this point in the history
* exclude xpu

* fix save and load doc and not expose api in checkpoint dir

* skip example

* fix doc
  • Loading branch information
pangengzheng committed Dec 15, 2023
1 parent e06bb2f commit 34ff9eb
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 10 deletions.
8 changes: 0 additions & 8 deletions python/paddle/distributed/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .save_state_dict import save_state_dict
from .load_state_dict import load_state_dict

__all__ = [
"save_state_dict",
"load_state_dict",
]
5 changes: 4 additions & 1 deletion python/paddle/distributed/checkpoint/load_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,14 +360,17 @@ def load_state_dict(
) -> None:
"""
Load the state_dict inplace from a checkpoint path.
Args:
state_dict(Dict[str, paddle.Tensor]): The state_dict to load. It will be modified inplace after loading.
path(str): The directory to load checkpoint files.
process_group(paddle.distributed.collective.Group): ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards.
coordinator_rank(int): The rank used to coordinate the checkpoint. Rank0 is used by default.
Example:
.. code-block:: python
>>> # doctest: +SKIP('Load state dict.')
>>> # doctest: +SKIP('run in distributed mode.')
>>> import paddle
>>> import paddle.distributed as dist
>>> ckpt_path = "./checkpoint"
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/distributed/checkpoint/save_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def save_state_dict(
Examples:
.. code-block:: python
>>> # doctest: +SKIP('Save state dict.')
>>> # doctest: +SKIP('run in distributed mode')
>>> import paddle
>>> import paddle.distributed as dist
>>> w1 = paddle.arange(32).reshape([4, 8])
Expand Down

0 comments on commit 34ff9eb

Please sign in to comment.