Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/source/clis.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ TRL provides several ready-to-use Accelerate configs to simplify common training

| Name | Description |
| --- | --- |
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
| `zero1` | DeepSpeed ZeRO Stage 1 |
| `zero2` | DeepSpeed ZeRO Stage 2 |
Expand Down
28 changes: 0 additions & 28 deletions examples/accelerate_configs/fsdp1.yaml

This file was deleted.

28 changes: 0 additions & 28 deletions trl/accelerate_configs/fsdp1.yaml

This file was deleted.

46 changes: 3 additions & 43 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,34 +863,7 @@ def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = No
name = name.replace(prefix, "")
return name

def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
    """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.

    Recurses into children first, then materializes only this wrapper's shard group via
    ``summon_full_params(recurse=False)`` so at most one FSDP unit's full parameters are resident at a time.

    Args:
        module: Module to traverse; only ``FSDP``-wrapped nodes trigger a parameter sync.
        prefix: Dotted name path accumulated from the root; empty at the top-level call.
        visited: Set of fully-qualified parameter names already synced; created on the first call
            and shared across the recursion to avoid re-sending parameters that a parent FSDP
            wrapper also reports via ``named_parameters()``.
    """
    # For FSDP1, we need to recurse into children and also use summon_full_params
    if visited is None:
        visited = set()
    for child_name, child_module in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        self._sync_fsdp1_params_to_vllm(
            child_module, prefix=child_prefix, visited=visited
        )  # recurse into the child

    if isinstance(module, FSDP):
        # writeback=False: we only read weights; recurse=False: children were handled above.
        with FSDP.summon_full_params(module, recurse=False, writeback=False):
            for param_name, param in module.named_parameters():
                full_name = f"{prefix}.{param_name}" if prefix else param_name
                # Strip FSDP's wrapper prefix so names match what vLLM expects.
                full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])

                if full_name in visited:
                    continue  # skip FSDP subtrees already traversed
                visited.add(full_name)

                if self.vllm_mode == "server" and self.accelerator.is_main_process:
                    # Server mode: only the main process pushes weights to the remote vLLM server.
                    self.vllm_client.update_named_param(full_name, param.data)
                elif self.vllm_mode == "colocate":
                    # Colocate mode: load weights directly into the in-process vLLM engine.
                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                    llm_model.load_weights([(full_name, param.data)])

def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
def _sync_fsdp_params_to_vllm(self, module: nn.Module):
# For FSDP2, module.state_dict() already covers all parameters, so no need for recursion
for name, param in module.state_dict().items():
if param.is_cpu:
Expand Down Expand Up @@ -925,15 +898,7 @@ def _move_model_to_vllm(self):
# Update vLLM weights while parameters are gathered
if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
# Update vLLM weights while parameters are gathered
# For PEFT with FSDP we need to use the memory efficient post-order traversal
fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
if fsdp_version == 1:
self._sync_fsdp1_params_to_vllm(
self.model
) # use memory-efficient post-order traversal for FSDP
elif fsdp_version == 2:
self._sync_fsdp2_params_to_vllm(self.model)
self._sync_fsdp_params_to_vllm(self.model)
else:
# DeepSpeed ZeRO-3 with PEFT
for name, param in self.model.named_parameters():
Expand All @@ -957,12 +922,7 @@ def _move_model_to_vllm(self):
else:
# For non-PEFT models, simply gather (if needed) and update each parameter individually.
if self.is_fsdp_enabled:
fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
if fsdp_version == 1:
self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP
elif fsdp_version == 2:
self._sync_fsdp2_params_to_vllm(self.model)
self._sync_fsdp_params_to_vllm(self.model)
else:
for name, param in self.model.named_parameters():
name = self._fix_param_name_to_vllm(name)
Expand Down
45 changes: 3 additions & 42 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,14 +884,7 @@ def _move_model_to_vllm(self):
# Update vLLM weights while parameters are gathered
if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
# Update vLLM weights while parameters are gathered
# For PEFT with FSDP we need to use the memory efficient post-order traversal
fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
if fsdp_version == 1:
# use memory-efficient post-order traversal for FSDP
self._sync_fsdp1_params_to_vllm(self.model)
elif fsdp_version == 2:
self._sync_fsdp2_params_to_vllm(self.model)
self._sync_fsdp_params_to_vllm(self.model)
else:
# DeepSpeed ZeRO-3 with PEFT
for name, param in self.model.named_parameters():
Expand All @@ -915,12 +908,7 @@ def _move_model_to_vllm(self):
else:
# For non-PEFT models, simply gather (if needed) and update each parameter individually.
if self.is_fsdp_enabled:
fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
if fsdp_version == 1:
self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP
elif fsdp_version == 2:
self._sync_fsdp2_params_to_vllm(self.model)
self._sync_fsdp_params_to_vllm(self.model)
else:
for name, param in self.model.named_parameters():
name = self._fix_param_name_to_vllm(name)
Expand All @@ -937,34 +925,7 @@ def _move_model_to_vllm(self):
elif self.vllm_mode == "colocate":
self.llm.reset_prefix_cache()

def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
    """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.

    Children are processed before their parent, and each FSDP wrapper is summoned with
    ``recurse=False`` so only one shard group is materialized in memory at a time. The shared
    ``visited`` set prevents parameters reported by nested wrappers from being sent twice.
    """
    if visited is None:
        visited = set()

    # Post-order: descend into every child before handling this module itself.
    for name, submodule in module.named_children():
        qualified = f"{prefix}.{name}" if prefix else name
        self._sync_fsdp1_params_to_vllm(submodule, prefix=qualified, visited=visited)

    if not isinstance(module, FSDP):
        return

    # Read-only gather of this wrapper's full parameters (children already handled above).
    with FSDP.summon_full_params(module, recurse=False, writeback=False):
        for local_name, param in module.named_parameters():
            full_name = f"{prefix}.{local_name}" if prefix else local_name
            full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])

            if full_name in visited:
                continue  # already synced via an enclosing FSDP subtree
            visited.add(full_name)

            if self.vllm_mode == "server" and self.accelerator.is_main_process:
                self.vllm_client.update_named_param(full_name, param.data)
            elif self.vllm_mode == "colocate":
                engine_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                engine_model.load_weights([(full_name, param.data)])

def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
def _sync_fsdp_params_to_vllm(self, module: nn.Module):
# For FSDP2, module.state_dict() already covers all parameters, so no need for recursion
for name, param in module.state_dict().items():
if param.is_cpu:
Expand Down
46 changes: 3 additions & 43 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,34 +859,7 @@ def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = No
name = name.replace(prefix, "")
return name

def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
    """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.

    Args:
        module: Module to traverse; only ``FSDP``-wrapped nodes trigger a parameter sync.
        prefix: Dotted name path accumulated from the root; empty at the top-level call.
        visited: Set of fully-qualified parameter names already synced, shared across the
            recursion so parameters surfaced by nested wrappers are not sent twice.
    """
    # For FSDP1, we need to recurse into children and also use summon_full_params
    if visited is None:
        visited = set()
    for child_name, child_module in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        self._sync_fsdp1_params_to_vllm(
            child_module, prefix=child_prefix, visited=visited
        )  # recurse into the child

    if isinstance(module, FSDP):
        # recurse=False keeps only this wrapper's parameters resident; writeback=False — read-only.
        with FSDP.summon_full_params(module, recurse=False, writeback=False):
            for param_name, param in module.named_parameters():
                full_name = f"{prefix}.{param_name}" if prefix else param_name
                # Normalize away FSDP's wrapper prefix so names match vLLM's expectations.
                full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])

                if full_name in visited:
                    continue  # skip FSDP subtrees already traversed
                visited.add(full_name)

                if self.vllm_mode == "server" and self.accelerator.is_main_process:
                    # Remote vLLM server: only the main process sends updates.
                    self.vllm_client.update_named_param(full_name, param.data)
                elif self.vllm_mode == "colocate":
                    # In-process vLLM engine: load the weight directly.
                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                    llm_model.load_weights([(full_name, param.data)])

def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
def _sync_fsdp_params_to_vllm(self, module: nn.Module):
# For FSDP2, module.state_dict() already covers all parameters, so no need for recursion
for name, param in module.state_dict().items():
if param.is_cpu:
Expand Down Expand Up @@ -921,15 +894,7 @@ def _move_model_to_vllm(self):
# Update vLLM weights while parameters are gathered
if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
# Update vLLM weights while parameters are gathered
# For PEFT with FSDP we need to use the memory efficient post-order traversal
fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
if fsdp_version == 1:
self._sync_fsdp1_params_to_vllm(
self.model
) # use memory-efficient post-order traversal for FSDP
elif fsdp_version == 2:
self._sync_fsdp2_params_to_vllm(self.model)
self._sync_fsdp_params_to_vllm(self.model)
else:
# DeepSpeed ZeRO-3 with PEFT
for name, param in self.model.named_parameters():
Expand All @@ -953,12 +918,7 @@ def _move_model_to_vllm(self):
else:
# For non-PEFT models, simply gather (if needed) and update each parameter individually.
if self.is_fsdp_enabled:
fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
if fsdp_version == 1:
self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP
elif fsdp_version == 2:
self._sync_fsdp2_params_to_vllm(self.model)
self._sync_fsdp_params_to_vllm(self.model)
else:
for name, param in self.model.named_parameters():
name = self._fix_param_name_to_vllm(name)
Expand Down
Loading