1 change: 1 addition & 0 deletions xtuner/v1/config/fsdp.py
@@ -22,6 +22,7 @@ class FSDPConfig(BaseModel):
     # TODO: (caoweihan) Convert `torch.dtype` to `Annotated` for compatibility with cyclopts
     param_dtype: Annotated[torch.dtype, Parameter(help="Data type for model parameters")] = torch.bfloat16
     reduce_dtype: Annotated[torch.dtype, Parameter(help="Data type for reduction operations")] = torch.bfloat16
+    lm_head_fp32: Annotated[bool, Parameter(help="Use float32 for language model head")] = False
     torch_compile: Annotated[bool, Parameter(help="Enable model compilation for faster inference")] = False
     compile_targets: Annotated[
         Optional[Tuple[str, ...]],
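The hunk above adds a single boolean, `lm_head_fp32`, to `FSDPConfig`. As a hedged usage sketch (the import path is assumed from the file location in this diff, and the remaining fields keep the defaults shown above):

```python
# Sketch: turning on the new flag. Assumes xtuner/v1/config/fsdp.py is importable
# as xtuner.v1.config.fsdp and that FSDPConfig accepts torch.dtype values as shown.
import torch
from xtuner.v1.config.fsdp import FSDPConfig

fsdp_config = FSDPConfig(
    param_dtype=torch.bfloat16,   # transformer blocks keep bf16 parameters
    reduce_dtype=torch.bfloat16,  # gradient reduction also in bf16
    lm_head_fp32=True,            # new flag: build a float32 policy for lm_head
)
assert fsdp_config.lm_head_fp32
```

With the flag left at its default (`False`), `lm_head` is sharded with the same bf16 policy as the rest of the model, which matches the behavior before this change.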
6 changes: 5 additions & 1 deletion xtuner/v1/model/dense/dense.py
@@ -186,6 +186,10 @@ def fully_shard(
         mp_policy = MixedPrecisionPolicy(
             param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
         )
+        if self.fsdp_config.lm_head_fp32:
+            lm_head_mp_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
+        else:
+            lm_head_mp_policy = mp_policy
         num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

         generator = torch.Generator()
@@ -236,7 +240,7 @@ def fully_shard(
         fully_shard(
             self.lm_head,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
-            mp_policy=mp_policy,
+            mp_policy=lm_head_mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
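In both model files the change follows the same pattern: one `MixedPrecisionPolicy` for the transformer layers, and a second, optionally float32, policy that is passed only to `fully_shard(self.lm_head, ...)`. The standalone sketch below shows that pattern with PyTorch's composable `fully_shard` API outside of xtuner; `ToyLM`, its sizes, and the launch assumptions are illustrative only.

```python
# Hedged sketch of per-module mixed precision with FSDP2's fully_shard:
# bf16 for the transformer blocks, optionally fp32 for lm_head only.
# Assumes PyTorch >= 2.6 and a torchrun launch (WORLD_SIZE set, CUDA available).
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


class ToyLM(nn.Module):
    def __init__(self, hidden: int = 256, vocab: int = 1000, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_layers))
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = torch.relu(layer(x))
        return self.lm_head(x)


def shard(model: ToyLM, lm_head_fp32: bool = True) -> ToyLM:
    mesh = init_device_mesh("cuda", (int(os.environ.get("WORLD_SIZE", "1")),))
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
    # Mirror the diff: lm_head gets its own policy only when fp32 is requested.
    lm_head_mp_policy = (
        MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
        if lm_head_fp32
        else mp_policy
    )
    for layer in model.layers:
        fully_shard(layer, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model.lm_head, mesh=mesh, mp_policy=lm_head_mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)  # root handles any remaining params
    return model
```

The xtuner diff uses the same three ingredients: `mp_policy` for the decoder layers, `lm_head_mp_policy` selected by `fsdp_config.lm_head_fp32`, and `fully_shard(self.lm_head, ..., mp_policy=lm_head_mp_policy, ...)`.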
6 changes: 5 additions & 1 deletion xtuner/v1/model/moe/moe.py
@@ -653,6 +653,10 @@ def fully_shard(
         mp_policy = MixedPrecisionPolicy(
             param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
         )
+        if self.fsdp_config.lm_head_fp32:
Collaborator (inline comment on the `if self.fsdp_config.lm_head_fp32:` line above): "Declaring it when using it", i.e. consider declaring `lm_head_mp_policy` at the point where it is used; a sketch of that follow-up appears after the moe.py hunks below.
+            lm_head_mp_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
+        else:
+            lm_head_mp_policy = mp_policy
         num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

         for layer_idx, layer in tqdm(self.layers.items(), desc="[FSDP Sharding]"):
@@ -698,7 +702,7 @@ def fully_shard(
         fully_shard(
             self.lm_head,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
-            mp_policy=mp_policy,
+            mp_policy=lm_head_mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
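The collaborator note attached to the `if self.fsdp_config.lm_head_fp32:` line reads as a request to build `lm_head_mp_policy` right where it is consumed rather than near the top of `fully_shard`. The sketch below is one possible reading of that comment, not a change contained in this PR; it is a method-body excerpt reusing the identifiers from this diff (so it uses `self` and is not standalone code).

```python
# Possible follow-up to the review comment: construct the lm_head policy
# immediately before the fully_shard call that uses it. Method-body excerpt,
# mirroring the identifiers already present in this diff.
lm_head_mp_policy = (
    MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
    if self.fsdp_config.lm_head_fp32
    else mp_policy
)
fully_shard(
    self.lm_head,
    mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
    mp_policy=lm_head_mp_policy,
    reshard_after_forward=self.fsdp_config.reshard_after_forward,
    offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
)
```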