diff --git a/nemo/collections/llm/recipes/CONFIGURATION-HIERARCHY.md b/nemo/collections/llm/recipes/CONFIGURATION-HIERARCHY.md
index 80eb0416cc18..9c5dff0e8b2b 100644
--- a/nemo/collections/llm/recipes/CONFIGURATION-HIERARCHY.md
+++ b/nemo/collections/llm/recipes/CONFIGURATION-HIERARCHY.md
@@ -57,6 +57,11 @@
         bucket_size: Optional[int] = None # Maximum number of parameters in each bucket
         average_in_collective: bool = False # If true, compute average in collective directly, as opposed to dividing by the dp_size first and then computing sum in the collective
         fp8_param_gather: bool = False # If true, keep the compute param in fp8 (do not use any other intermediate dtype) and perform the param all-gather in fp8
+        use_custom_fsdp: bool = False # If true, use MCore's custom FSDP implementation. recipe.model.config.gradient_accumulation_fusion must be False when using this
+        data_parallel_sharding_strategy: str = "no_shard" # Sharding strategy when using custom FSDP, choices=['no_shard', 'optim', 'optim_grads', 'optim_grads_params']
+        suggested_communication_unit_size: int = 400_000_000 # When using custom FSDP and batch communication is needed across multiple buckets, this variable guides the communication unit size
+        preserve_fp32_weights: bool = True # If true, preserve fp32 weights in the custom FSDP ParamAndGradBuffer
+        keep_fp8_transpose_cache_when_using_custom_fsdp: bool = False # If true, keep the fp8 transpose cache when using custom FSDP
 ```
diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index ed4e82c7a96a..22dd34e77e02 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -689,6 +689,12 @@ def init_ddp(self):
             )  # We need to do this explicitly since this is a attr pytorch uses
             model_chunk.__class__.__getattr__ = getattr_proxy  # type: ignore

+        # Ensure that if using FSDP, gradient_accumulation_fusion is disabled on the model config.
+        if self.ddp_config.use_custom_fsdp:
+            assert (
+                module.config.gradient_accumulation_fusion == False
+            ), "gradient_accumulation_fusion cannot be used with FSDP"
+
         # param_sync_func is set in nemo.lightning.pytorch.optim.megatron
         no_sync_func, grad_sync_func = extract_ddp_funcs(self.ddp_config, self)
         for module in self:
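
For reference, a minimal usage sketch of the new knobs (not part of this diff). It assumes the NeMo 2.0 recipe API, with `llm.llama3_8b.pretrain_recipe` as an arbitrary example entry point, and that `recipe.trainer.strategy.ddp` already carries a `DistributedDataParallelConfig`, per CONFIGURATION-HIERARCHY.md; exact attribute paths may differ across NeMo/MCore versions.

```python
# Sketch only: enabling MCore's custom FSDP through a NeMo 2.0 recipe.
# llama3_8b is just an illustrative recipe; any recipe whose
# trainer.strategy.ddp is a DistributedDataParallelConfig should look similar.
from nemo.collections import llm

recipe = llm.llama3_8b.pretrain_recipe(num_nodes=1, num_gpus_per_node=8)

ddp = recipe.trainer.strategy.ddp
ddp.use_custom_fsdp = True                                  # switch to MCore's custom FSDP path
ddp.data_parallel_sharding_strategy = "optim_grads_params"  # example choice; default is "no_shard"

# Required pairing: init_ddp() above asserts that gradient_accumulation_fusion
# is disabled whenever use_custom_fsdp is set.
recipe.model.config.gradient_accumulation_fusion = False
```

The strategy names suggest progressively more aggressive sharding (optimizer state, then gradients, then parameters); `optim_grads_params` is used here purely for illustration.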