diff --git a/xtuner/v1/config/fsdp.py b/xtuner/v1/config/fsdp.py
index c7341799a..99d0d1202 100644
--- a/xtuner/v1/config/fsdp.py
+++ b/xtuner/v1/config/fsdp.py
@@ -22,6 +22,7 @@ class FSDPConfig(BaseModel):
     # TODO: (caoweihan) Convert `torch.dtype` to `Annotated` for compatibility with cyclopts
     param_dtype: Annotated[torch.dtype, Parameter(help="Data type for model parameters")] = torch.bfloat16
     reduce_dtype: Annotated[torch.dtype, Parameter(help="Data type for reduction operations")] = torch.bfloat16
+    lm_head_fp32: Annotated[bool, Parameter(help="Use float32 for language model head")] = False
     torch_compile: Annotated[bool, Parameter(help="Enable model compilation for faster inference")] = False
     compile_targets: Annotated[
         Optional[Tuple[str, ...]],
diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py
index 70839d241..e5a4e26de 100644
--- a/xtuner/v1/model/dense/dense.py
+++ b/xtuner/v1/model/dense/dense.py
@@ -186,6 +186,10 @@ def fully_shard(
         mp_policy = MixedPrecisionPolicy(
             param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
         )
+        if self.fsdp_config.lm_head_fp32:
+            lm_head_mp_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
+        else:
+            lm_head_mp_policy = mp_policy
         num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

         generator = torch.Generator()
@@ -236,7 +240,7 @@ def fully_shard(
         fully_shard(
             self.lm_head,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
-            mp_policy=mp_policy,
+            mp_policy=lm_head_mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py
index c5703b5fb..ef8a6f43e 100644
--- a/xtuner/v1/model/moe/moe.py
+++ b/xtuner/v1/model/moe/moe.py
@@ -653,6 +653,10 @@ def fully_shard(
         mp_policy = MixedPrecisionPolicy(
             param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
         )
+        if self.fsdp_config.lm_head_fp32:
+            lm_head_mp_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
+        else:
+            lm_head_mp_policy = mp_policy
         num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

         for layer_idx, layer in tqdm(self.layers.items(), desc="[FSDP Sharding]"):
@@ -698,7 +702,7 @@ def fully_shard(
         fully_shard(
             self.lm_head,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
-            mp_policy=mp_policy,
+            mp_policy=lm_head_mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
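
Usage note: a minimal sketch of how the new flag would be driven from the config side. This is an illustration, not part of the patch; the import path follows the file location above, and the commented-out model wiring is hypothetical.

```python
import torch

from xtuner.v1.config.fsdp import FSDPConfig  # module path taken from this diff

# Keep bf16 mixed precision for the bulk of the model, but have fully_shard()
# wrap self.lm_head with a float32 MixedPrecisionPolicy so the head's params
# and its gradient reduction stay in fp32.
fsdp_config = FSDPConfig(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    lm_head_fp32=True,
)

# Hypothetical wiring: any of the models patched above consumes this config
# in its fully_shard(), e.g.
# model = Dense(model_config, fsdp_config=fsdp_config)
# model.fully_shard()
```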