Fixed docs on HSDP sharding/replication dims
ghstack-source-id: 77f650e8281dae12f2a7ccdb415be88f9abd88cc
Pull Request resolved: #283
awgu committed Apr 29, 2024
1 parent b898545 commit 935b572
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion docs/fsdp.md
@@ -51,7 +51,7 @@ def fully_shard(
- Calling `model.named_parameters()` for a `model` with FSDP2 applied returns unchanged parameter names and `DTensor` sharded parameters. This means that the optimizer and gradient norm clipping see `DTensor`s.
- `fully_shard(module)` performs a dynamic class swap on `module`. E.g., if `type(module) is Transformer`, then FSDP2 constructs a new class `FSDPTransformer` that inherits from a class `FSDP` and `Transformer` and sets `module.__class__` to be `FSDPTransformer`. This allows us to add new methods and override methods via `FSDP` without constructing an `nn.Module` wrapper.
- FSDP1's `sharding_strategy` and `process_group`/`device_mesh` map to FSDP2's `mesh` and `reshard_after_forward`.
-- `mesh` should be 1D for FSDP and 2D for HSDP. For HSDP, we assume sharding on the 0th mesh dim and replication on the 1st mesh dim. If `mesh is None`, then FSDP2 initializes a 1D global mesh over the default process group.
+- `mesh` should be 1D for FSDP and 2D for HSDP. For HSDP, we assume replication on the 0th mesh dim and sharding on the 1st mesh dim. If `mesh is None`, then FSDP2 initializes a 1D global mesh over the default process group.
- `reshard_after_forward=True` or `False` determines whether parameters are resharded (freed) after forward. If `True`, then they are re-all-gathered in backward. This saves memory at the cost of extra communication.
- (Experimental) `reshard_after_forward: int` means that parameters are resharded to a smaller world size after forward (e.g. `reshard_after_forward=8` can mean intra-node) so that the backward all-gather is over a smaller world size.
- | FSDP1 | FSDP2 | DeepSpeed |
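The corrected line states that a 2D HSDP mesh replicates along dim 0 and shards along dim 1. A minimal pure-Python sketch of what that implies for rank grouping follows. It assumes 8 ranks arranged as 2 replicas of 4 shards, laid out row-major; the helpers `build_mesh`, `shard_group`, and `replicate_group` are hypothetical names for illustration and are not part of FSDP2.

```python
def build_mesh(num_replicas: int, shard_size: int) -> list[list[int]]:
    """Lay out ranks row-major in a (num_replicas, shard_size) grid.

    Assumption: dim 0 is the replication dim and dim 1 is the sharding dim,
    matching the corrected doc line.
    """
    return [
        [r * shard_size + s for s in range(shard_size)]
        for r in range(num_replicas)
    ]


def shard_group(mesh: list[list[int]], rank: int) -> list[int]:
    """Ranks along mesh dim 1 (one row): together they hold one full copy."""
    for row in mesh:
        if rank in row:
            return row
    raise ValueError(f"rank {rank} not in mesh")


def replicate_group(mesh: list[list[int]], rank: int) -> list[int]:
    """Ranks along mesh dim 0 (one column): they hold the same shard."""
    for row in mesh:
        if rank in row:
            col = row.index(rank)
            return [r[col] for r in mesh]
    raise ValueError(f"rank {rank} not in mesh")


mesh = build_mesh(num_replicas=2, shard_size=4)
print(mesh)                      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(shard_group(mesh, 5))      # [4, 5, 6, 7]
print(replicate_group(mesh, 5))  # [1, 5]
```

Under these assumptions, parameters are all-gathered within a row and gradients are additionally reduced across a column. With PyTorch itself, a mesh of this shape would typically be built with `init_device_mesh("cuda", (2, 4))` and passed as `mesh`; that call is not exercised here because it requires an initialized process group.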

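The dynamic class swap described in the `fully_shard(module)` bullet can be illustrated in plain Python. This is a sketch of the general technique only: the `FSDP` mixin, the `Transformer` stand-in, and the `swap_class` helper below are hypothetical and do not reproduce FSDP2's actual implementation.

```python
class FSDP:
    """Hypothetical mixin standing in for the `FSDP` class named in the docs."""

    def unshard(self) -> str:
        return f"unshard called on {type(self).__name__}"


class Transformer:
    """Stand-in for a user model (an `nn.Module` in practice)."""

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def forward(self, x: int) -> int:
        return x * self.dim


def swap_class(module: object) -> object:
    """Replace `module.__class__` with a new class inheriting from FSDP and the original."""
    cls = type(module)
    new_cls = type(f"FSDP{cls.__name__}", (FSDP, cls), {})
    module.__class__ = new_cls  # in-place swap: the object identity is unchanged
    return module


model = Transformer(dim=2)
same = swap_class(model)

print(type(model).__name__)  # FSDPTransformer
print(same is model)         # True
print(model.forward(3))      # 6
print(model.unshard())       # unshard called on FSDPTransformer
```

Because the object is modified in place rather than wrapped, existing references to `model` and its attribute names stay valid, which is consistent with the note that `model.named_parameters()` returns unchanged parameter names.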
