
Commit d9118af

SkafteNicki authored and lantiga committed
Add support for DeepSpeed's exclude_frozen_parameters (#21060)
* add to deepspeed strategies
* add testing
* changelog
* GLOO_SOCKET_IFNAME

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
(cherry picked from commit 5071a04)
1 parent e72fc19 commit d9118af
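For orientation: in PyTorch a parameter counts as frozen when its requires_grad flag is False, so it never changes during training. In fine-tuning runs where most of the network is frozen, skipping those tensors can shrink DeepSpeed checkpoints considerably. A minimal illustration of what "frozen" means here (not part of the diff):

import torch

backbone = torch.nn.Linear(32, 32)
for param in backbone.parameters():
    param.requires_grad = False  # frozen: skipped when exclude_frozen_parameters=True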

File tree

7 files changed: +94 -6 lines changed

.github/workflows/ci-tests-pytorch.yml

Lines changed: 4 additions & 0 deletions

@@ -83,6 +83,10 @@ jobs:
       - name: basic setup
         run: pip install -q -r .actions/requirements.txt
 
+      - name: Append Env. vars for Linux
+        if: ${{ runner.os == 'Linux' }}
+        run: echo "GLOO_SOCKET_IFNAME=eth0" >> $GITHUB_ENV
+
       - name: Set min. dependencies
         if: ${{ matrix.config.requires == 'oldest' }}
         run: |

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Include `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060))
 
 
 ### Fixed
@@ -57,7 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Removed
 
-- Removed legacy support for `lightning run model`. Use `fabric run` instead ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588))
+- Removed legacy support for `lightning run model`; use `fabric run` instead ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588))
 
 
 ## [2.5.0] - 2024-12-19

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 7 additions & 1 deletion

@@ -99,6 +99,7 @@ def __init__(
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
+        exclude_frozen_parameters: bool = False,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -228,6 +229,8 @@ def __init__(
                 when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
                 per worker.
 
+            exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
+
         """
         if not _DEEPSPEED_AVAILABLE:
             raise ImportError(
@@ -288,6 +291,7 @@ def __init__(
 
         self.remote_device = remote_device
         self.load_full_weights = load_full_weights
+        self.exclude_frozen_parameters = exclude_frozen_parameters
 
         # default FP16 parameters.
         self.loss_scale = loss_scale
@@ -444,7 +448,9 @@ def save_checkpoint(
         # there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
         state = self._convert_stateful_objects_in_state(state, filter={})
         # use deepspeed's internal checkpointing function to handle partitioned weights across processes
-        engine.save_checkpoint(path, client_state=state, tag="checkpoint")
+        engine.save_checkpoint(
+            path, client_state=state, tag="checkpoint", exclude_frozen_parameters=self.exclude_frozen_parameters
+        )
 
     @override
     def load_checkpoint(
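Taken together, the Fabric-side change stores the flag on the strategy and forwards it to DeepSpeed's engine.save_checkpoint. A minimal usage sketch, assuming a CUDA machine with deepspeed installed; the two-layer model, the frozen first layer, and the checkpoint directory name are illustrative, not from the diff:

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(stage=2, exclude_frozen_parameters=True)
fabric = Fabric(accelerator="cuda", devices=1, strategy=strategy)
fabric.launch()

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 2))
for param in model[0].parameters():
    param.requires_grad = False  # frozen backbone: omitted from the checkpoint

optimizer = torch.optim.SGD(model[1].parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

# DeepSpeed checkpoints are directories of shards; with the new flag,
# only trainable tensors are written.
fabric.save("ckpt_dir", {"model": model, "optimizer": optimizer})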

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -8,8 +8,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [unReleased] - 2025-09-DD
 
+- Added `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060))
+
+
 - Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))
 
+
 ### Changed
 
 -

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 10 additions & 1 deletion

@@ -122,6 +122,7 @@ def __init__(
         precision_plugin: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
+        exclude_frozen_parameters: bool = False,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -253,6 +254,8 @@ def __init__(
                 when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
                 per worker.
 
+            exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
+
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
@@ -311,6 +314,7 @@ def __init__(
 
         self.remote_device = remote_device
         self.load_full_weights = load_full_weights
+        self.exclude_frozen_parameters = exclude_frozen_parameters
 
         # default FP16 parameters.
         self.loss_scale = loss_scale
@@ -648,7 +652,12 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op
         # dump states as a checkpoint dictionary object
         _exclude_keys = ["state_dict", "optimizer_states"]
         checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
-        self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint")
+        self.deepspeed_engine.save_checkpoint(
+            filepath,
+            client_state=checkpoint,
+            tag="checkpoint",
+            exclude_frozen_parameters=self.exclude_frozen_parameters,
+        )
 
     @override
     def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
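The Trainer-side strategy exposes the same constructor flag, so every checkpoint written through this strategy omits frozen weights. A hedged sketch of enabling it (the Trainer arguments besides the strategy are illustrative):

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

trainer = Trainer(
    accelerator="gpu",
    devices=1,
    strategy=DeepSpeedStrategy(exclude_frozen_parameters=True),
    precision="16-mixed",
)
# After trainer.fit(model), trainer.save_checkpoint(path) skips frozen weights.

One consequence worth noting: a checkpoint saved this way cannot restore the frozen weights by itself; they have to be re-created from their original source (for example, the pretrained backbone) before resuming.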

tests/tests_fabric/strategies/test_deepspeed.py

Lines changed: 27 additions & 2 deletions

@@ -193,15 +193,19 @@ def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
     model.modules.return_value = [model]
     strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
     # the client_state should not contain any deepspeed engine or deepspeed optimizer
-    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
+    model.save_checkpoint.assert_called_with(
+        tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False
+    )
 
     # Model and optimizer
     optimizer = Mock()
     model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
     model.modules.return_value = [model]
     strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
     # the client_state should not contain any deepspeed engine or deepspeed optimizer
-    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
+    model.save_checkpoint.assert_called_with(
+        tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False
+    )
 
 
 @RunIf(deepspeed=True)
@@ -218,6 +222,27 @@ def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
     strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})
 
 
+@RunIf(deepspeed=True)
+@pytest.mark.parametrize("exclude_frozen_parameters", [True, False])
+def test_deepspeed_save_checkpoint_exclude_frozen_parameters(exclude_frozen_parameters):
+    """Test that the DeepSpeed strategy can save checkpoints with the `exclude_frozen_parameters` argument."""
+    from deepspeed import DeepSpeedEngine
+
+    strategy = DeepSpeedStrategy(exclude_frozen_parameters=exclude_frozen_parameters)
+    assert strategy.exclude_frozen_parameters is exclude_frozen_parameters
+
+    model = Mock(spec=DeepSpeedEngine, optimizer=None)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint(path="test_path", state={"model": model, "extra": "data"})
+
+    model.save_checkpoint.assert_called_with(
+        "test_path",
+        client_state={"extra": "data"},
+        tag="checkpoint",
+        exclude_frozen_parameters=exclude_frozen_parameters,
+    )
+
+
 @RunIf(deepspeed=True)
 def test_deepspeed_load_checkpoint_validate_path(tmp_path):
     """Test that we validate the checkpoint path for a DeepSpeed checkpoint and give suggestions for user error."""

tests/tests_pytorch/strategies/test_deepspeed.py

Lines changed: 40 additions & 0 deletions

@@ -562,6 +562,46 @@ def test_deepspeed_multigpu_single_file(tmp_path):
     trainer.test(model, ckpt_path=checkpoint_path)
 
 
+@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
+def test_deepspeed_strategy_exclude_frozen_parameters_integration(tmp_path):
+    """Test end-to-end integration of exclude_frozen_parameters with actual model training and checkpointing."""
+
+    class TestModelWithFrozenParams(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.frozen_layer = torch.nn.Linear(32, 32)
+
+        def configure_model(self) -> None:
+            super().configure_model()
+            # Freeze the additional layer parameters
+            for param in self.frozen_layer.parameters():
+                param.requires_grad = False
+
+        def forward(self, x):
+            x = self.frozen_layer(x)
+            return super().forward(x)
+
+    model = TestModelWithFrozenParams()
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        strategy=DeepSpeedStrategy(exclude_frozen_parameters=True),
+        accelerator="gpu",
+        devices=1,
+        fast_dev_run=True,
+        precision="16-mixed",
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(model)
+    checkpoint_path = os.path.join(tmp_path, "checkpoint_exclude_frozen.ckpt")
+    trainer.save_checkpoint(checkpoint_path)
+
+    # Verify checkpoint was created
+    assert os.path.exists(checkpoint_path)
+
+
 class ModelParallelClassificationModel(LightningModule):
     def __init__(self, lr: float = 0.01, num_blocks: int = 5):
         super().__init__()
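The integration test only asserts that the checkpoint directory exists. To verify the exclusion by hand, one option is to consolidate the ZeRO shards with Lightning's existing conversion utility and inspect the resulting state dict; a sketch under the assumption that the checkpoint was produced by the test model above (file names are illustrative):

import torch
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# Consolidate the sharded DeepSpeed checkpoint directory into one file.
convert_zero_checkpoint_to_fp32_state_dict("checkpoint_exclude_frozen.ckpt", "consolidated.ckpt")

checkpoint = torch.load("consolidated.ckpt", map_location="cpu")
state_dict = checkpoint["state_dict"]
# The frozen layer's tensors were never written, so its keys should be absent.
assert not any(key.startswith("frozen_layer") for key in state_dict)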
