From 2fc2fb7939f190ae0f1a2274103704a43f98db57 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 15:09:17 -0500 Subject: [PATCH 01/21] Update fsdp.py --- src/lightning/pytorch/plugins/precision/fsdp.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index c41199adb480e..57c5bf9c9aba1 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -76,9 +76,6 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca @override def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ - # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP. - # To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference - # to the root module raise MisconfigurationException( f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`" ) From c36f40cb0b066b3ebd22c76afbdd196f3edc53b5 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 15:52:46 -0500 Subject: [PATCH 02/21] Support gradient norm clipping for FSDP --- src/lightning/pytorch/core/module.py | 4 +++- src/lightning/pytorch/plugins/precision/amp.py | 6 +++++- src/lightning/pytorch/plugins/precision/fsdp.py | 9 ++++----- src/lightning/pytorch/plugins/precision/precision.py | 5 +++-- tests/tests_pytorch/plugins/precision/test_amp.py | 8 +++++--- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index b8624daac3fa3..d1b0cca4feeae 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1207,7 +1207,9 @@ def clip_gradients( ) gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm) - self.trainer.precision_plugin.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm) + self.trainer.precision_plugin.clip_gradients( + self.trainer.model, optimizer, gradient_clip_val, gradient_clip_algorithm + ) def configure_gradient_clipping( self, diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 75e792af46b90..6746b5dcd2585 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from torch.nn import Module from torch.optim import LBFGS, Optimizer from typing_extensions import override @@ -100,6 +101,7 @@ def optimizer_step( # type: ignore[override] @override def clip_gradients( self, + module: Module, optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, @@ -109,7 +111,9 @@ def clip_gradients( f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping" " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?" 
) - super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm) + super().clip_gradients( + module=module, optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm + ) def autocast_context_manager(self) -> torch.autocast: return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index cd05cda985df5..bc4b3c0185a85 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import AbstractContextManager -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import Module +from torch.optim import Optimizer from typing_extensions import get_args, override import lightning.pytorch as pl @@ -81,11 +82,9 @@ def convert_module(self, module: Module) -> Module: return module @override - def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: + def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ - raise MisconfigurationException( - f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`" - ) + module.clip_grad_norm_(clip_val) @property def mixed_precision_config(self) -> "TorchMixedPrecision": diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py index 327fb2d4f5a27..08655fafca758 100644 --- a/src/lightning/pytorch/plugins/precision/precision.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -143,6 +143,7 @@ def _clip_gradients( def clip_gradients( self, + module: Module, optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, @@ -153,14 +154,14 @@ def clip_gradients( if gradient_clip_algorithm == GradClipAlgorithmType.VALUE: self.clip_grad_by_value(optimizer, clip_val) elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: - self.clip_grad_by_norm(optimizer, clip_val) + self.clip_grad_by_norm(module, optimizer, clip_val) def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by value.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by norm.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_norm_(parameters, clip_val) diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index cb061c540b2be..809d9f19d0706 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,6 +14,7 @@ from unittest.mock import Mock import pytest +from torch.nn import Module from torch.optim import 
Optimizer from lightning.pytorch.plugins import MixedPrecision @@ -22,22 +23,23 @@ def test_clip_gradients(): """Test that `.clip_gradients()` is a no-op when clipping is disabled.""" + module = Mock(spec=Module) optimizer = Mock(spec=Optimizer) precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) precision.clip_grad_by_value = Mock() precision.clip_grad_by_norm = Mock() - precision.clip_gradients(optimizer) + precision.clip_gradients(module, optimizer) precision.clip_grad_by_value.assert_not_called() precision.clip_grad_by_norm.assert_not_called() - precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE) + precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE) precision.clip_grad_by_value.assert_called_once() precision.clip_grad_by_norm.assert_not_called() precision.clip_grad_by_value.reset_mock() precision.clip_grad_by_norm.reset_mock() - precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM) + precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM) precision.clip_grad_by_value.assert_not_called() precision.clip_grad_by_norm.assert_called_once() From 8fad4235fdbcac1819d6582b40c385584adbe02d Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:08:08 -0500 Subject: [PATCH 03/21] Update CHANGELOG.md --- src/lightning/pytorch/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 5616defeffc8a..c794603990737 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
 ### Added
 - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
+- Support `clip_grad_norm_()` for FSDP ([#20784](https://github.com/Lightning-AI/pytorch-lightning/pull/20784))
 ### Changed

From 04fbaf1f996f49cfffe60da6c16a97534ffcd1d6 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Sat, 3 May 2025 16:17:43 -0500
Subject: [PATCH 04/21] Fix args for certain precisions
---
 src/lightning/pytorch/plugins/precision/deepspeed.py | 1 +
 src/lightning/pytorch/plugins/precision/precision.py | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py
index 9225e3bb9e7be..e09eb67f4fecf 100644
--- a/src/lightning/pytorch/plugins/precision/deepspeed.py
+++ b/src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -141,6 +141,7 @@ def optimizer_step( # type: ignore[override]
     @override
     def clip_gradients(
         self,
+        module: Optional[Module],
         optimizer: Optimizer,
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py
index 08655fafca758..a11182db68f97 100644
--- a/src/lightning/pytorch/plugins/precision/precision.py
+++ b/src/lightning/pytorch/plugins/precision/precision.py
@@ -143,7 +143,7 @@ def _clip_gradients(
     def clip_gradients(
         self,
-        module: Module,
+        module: Optional[Module],
         optimizer: Optimizer,
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
@@ -161,7 +161,7 @@ def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float])
         parameters = self.main_params(optimizer)
         torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
-    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
+    def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
         """Clip gradients by norm."""
         parameters = self.main_params(optimizer)
         torch.nn.utils.clip_grad_norm_(parameters, clip_val)

From bce69ca26a1290653b4fd89edc06f5f506631000 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Sat, 3 May 2025 16:24:10 -0500
Subject: [PATCH 05/21] Standardize precision args
---
 src/lightning/pytorch/plugins/precision/amp.py | 2 +-
 src/lightning/pytorch/plugins/precision/fsdp.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 6746b5dcd2585..f6ec37e7d4edb 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -101,7 +101,7 @@ def optimizer_step( # type: ignore[override]
     @override
     def clip_gradients(
         self,
-        module: Module,
+        module: Optional[Module],
         optimizer: Optimizer,
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index bc4b3c0185a85..aec60f4529740 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -82,7 +82,7 @@ def convert_module(self, module: Module) -> Module:
         return module
     @override
-    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val:
Union[int, float]) -> None: + def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ module.clip_grad_norm_(clip_val) From 0df38f54022088e524c3cbfaffd0bc50a8999afb Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:33:03 -0500 Subject: [PATCH 06/21] Guard for typing --- src/lightning/pytorch/plugins/precision/fsdp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index aec60f4529740..899dc1d623564 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -84,6 +84,8 @@ def convert_module(self, module: Module) -> Module: @override def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ + if module is None or not hasattr(module, "clip_grad_norm_") or not isinstance(module.clip_grad_norm_, Callable): + return module.clip_grad_norm_(clip_val) @property From a42b974389d20eb9553cb8982bf7aa66d6459556 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:38:54 -0500 Subject: [PATCH 07/21] Fix argument typing --- src/lightning/pytorch/plugins/precision/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 899dc1d623564..1facad738ae85 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -84,7 +84,7 @@ def convert_module(self, module: Module) -> Module: @override def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ - if module is None or not hasattr(module, "clip_grad_norm_") or not isinstance(module.clip_grad_norm_, Callable): + if module is None or not hasattr(module, "clip_grad_norm_") or not isinstance(module.clip_grad_norm_, callable): return module.clip_grad_norm_(clip_val) From ed2fe05ad04c43873d8998b8c7782f3a249430fd Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:41:36 -0500 Subject: [PATCH 08/21] Wrap AMP test module in FSDP --- tests/tests_pytorch/plugins/precision/test_amp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index 809d9f19d0706..b009a900446dd 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,6 +14,7 @@ from unittest.mock import Mock import pytest +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn import Module from torch.optim import Optimizer @@ -23,7 +24,7 @@ def test_clip_gradients(): """Test that `.clip_gradients()` is a no-op when clipping is disabled.""" - module = Mock(spec=Module) + module = FSDP(Mock(spec=Module)) optimizer = Mock(spec=Optimizer) precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) precision.clip_grad_by_value = Mock() From 2f62a0a1b7f0c462b693d19b3b7ec3b680c410aa Mon Sep 17 00:00:00 2001 
From: Alex Morehead Date: Sat, 3 May 2025 16:51:14 -0500 Subject: [PATCH 09/21] Simplify guard --- src/lightning/pytorch/plugins/precision/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 1facad738ae85..280bc4351f237 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -84,7 +84,7 @@ def convert_module(self, module: Module) -> Module: @override def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ - if module is None or not hasattr(module, "clip_grad_norm_") or not isinstance(module.clip_grad_norm_, callable): + if module is None: return module.clip_grad_norm_(clip_val) From 7f7987e5225b807196ee2dd878c9f7b095e4505e Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 17:08:37 -0500 Subject: [PATCH 10/21] Remove FSDP traces in AMP precision unit test --- tests/tests_pytorch/plugins/precision/test_amp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index b009a900446dd..900892fad5fdd 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,7 +14,6 @@ from unittest.mock import Mock import pytest -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn import Module from torch.optim import Optimizer @@ -24,7 +23,7 @@ def test_clip_gradients(): """Test that `.clip_gradients()` is a no-op when clipping is disabled.""" - module = FSDP(Mock(spec=Module)) + module = Mock(spec=Module) optimizer = Mock(spec=Optimizer) precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) precision.clip_grad_by_value = Mock() @@ -49,8 +48,9 @@ def test_optimizer_amp_scaling_support_in_step_method(): """Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with gradient clipping (example: fused Adam).""" + module = Mock(spec=Module) optimizer = Mock(_step_supports_amp_scaling=True) precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"): - precision.clip_gradients(optimizer, clip_val=1.0) + precision.clip_gradients(module, optimizer, clip_val=1.0) From dee22253d0f672fb417697db17c39283a069a356 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 4 Sep 2025 22:12:21 +0200 Subject: [PATCH 11/21] Apply suggestions from code review --- src/lightning/pytorch/plugins/precision/fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 280bc4351f237..0432805066142 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -86,6 +86,7 @@ def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ if module is None: return + assert isinstance(module.clip_grad_norm_, Module) module.clip_grad_norm_(clip_val) 
@property From de84676cdede63b3df146eb77d0d70e84eb74efd Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Wed, 10 Sep 2025 10:14:45 -0700 Subject: [PATCH 12/21] Update module.py --- src/lightning/pytorch/core/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 1426b66d5091a..3603f1600e3f6 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1274,7 +1274,7 @@ def clip_gradients( gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm) self.trainer.precision_plugin.clip_gradients( - self.trainer.model, optimizer, gradient_clip_val, gradient_clip_algorithm + optimizer, gradient_clip_val, gradient_clip_algorithm, module=self.trainer.model, ) def configure_gradient_clipping( From 188ca22ce6d8830c5497d48e85445dafe3ff44ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:15:05 +0000 Subject: [PATCH 13/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/core/module.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 3603f1600e3f6..d67430c531f5a 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1274,7 +1274,10 @@ def clip_gradients( gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm) self.trainer.precision_plugin.clip_gradients( - optimizer, gradient_clip_val, gradient_clip_algorithm, module=self.trainer.model, + optimizer, + gradient_clip_val, + gradient_clip_algorithm, + module=self.trainer.model, ) def configure_gradient_clipping( From 7c829b6b24aa9d9dd562f55979293266a25d239f Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Wed, 10 Sep 2025 10:15:44 -0700 Subject: [PATCH 14/21] Update amp.py --- src/lightning/pytorch/plugins/precision/amp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 546290978e315..c1e934039ee12 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -101,10 +101,10 @@ def optimizer_step( # type: ignore[override] @override def clip_gradients( self, - module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + module: Optional[Module] = None, ) -> None: if clip_val > 0 and _optimizer_handles_unscaling(optimizer): raise RuntimeError( @@ -112,7 +112,7 @@ def clip_gradients( " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?" 
) super().clip_gradients( - module=module, optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm + optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm, module=module ) def autocast_context_manager(self) -> torch.autocast: From 161241e2ce230615b2a33b0459170caa8a4d2ee0 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Wed, 10 Sep 2025 10:16:12 -0700 Subject: [PATCH 15/21] Update deepspeed.py --- src/lightning/pytorch/plugins/precision/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index e09eb67f4fecf..5c80b47cf2b46 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -141,9 +141,9 @@ def optimizer_step( # type: ignore[override] @override def clip_gradients( self, - module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + module: Optional[Module] = None, ) -> None: """DeepSpeed handles gradient clipping internally.""" From eea0a9478c12536cd055c6bfc8198bea98b622ab Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Wed, 10 Sep 2025 10:18:17 -0700 Subject: [PATCH 16/21] Update fsdp.py --- src/lightning/pytorch/plugins/precision/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 0432805066142..244b4f1e94570 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -82,7 +82,7 @@ def convert_module(self, module: Module) -> Module: return module @override - def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], module: Optional[Module] = None) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ if module is None: return From 181a355bb5ed9b80260f38d3a4d914456dd716d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:18:37 +0000 Subject: [PATCH 17/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/plugins/precision/fsdp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 244b4f1e94570..a0d5bbd0618a9 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -82,7 +82,9 @@ def convert_module(self, module: Module) -> Module: return module @override - def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], module: Optional[Module] = None) -> None: + def clip_grad_by_norm( + self, optimizer: Optimizer, clip_val: Union[int, float], module: Optional[Module] = None + ) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ if module is None: return From 759c4a2b1b52b4866dbb82137684b2f2b17dfa90 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Wed, 10 Sep 2025 10:20:17 -0700 Subject: [PATCH 18/21] Update 
precision.py --- src/lightning/pytorch/plugins/precision/precision.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py index a11182db68f97..a02d348f9b621 100644 --- a/src/lightning/pytorch/plugins/precision/precision.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -143,10 +143,10 @@ def _clip_gradients( def clip_gradients( self, - module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + module: Optional[Module] = None, ) -> None: """Clips the gradients.""" if clip_val <= 0: @@ -154,14 +154,14 @@ def clip_gradients( if gradient_clip_algorithm == GradClipAlgorithmType.VALUE: self.clip_grad_by_value(optimizer, clip_val) elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: - self.clip_grad_by_norm(module, optimizer, clip_val) + self.clip_grad_by_norm(optimizer, clip_val, module=module) def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by value.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], module: Optional[Module] = None) -> None: """Clip gradients by norm.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_norm_(parameters, clip_val) From 9bc399106a69e08c39dd323eb42d93ff1e9af214 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:20:36 +0000 Subject: [PATCH 19/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/plugins/precision/precision.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py index a02d348f9b621..6d922ab40577e 100644 --- a/src/lightning/pytorch/plugins/precision/precision.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -161,7 +161,9 @@ def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], module: Optional[Module] = None) -> None: + def clip_grad_by_norm( + self, optimizer: Optimizer, clip_val: Union[int, float], module: Optional[Module] = None + ) -> None: """Clip gradients by norm.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_norm_(parameters, clip_val) From 63e9d3aff90dd215639a51afc16704ce368e417c Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Wed, 10 Sep 2025 10:22:03 -0700 Subject: [PATCH 20/21] Update test_amp.py --- tests/tests_pytorch/plugins/precision/test_amp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index 924213661735d..acf238216184b 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -29,18 +29,18 @@ def test_clip_gradients(): precision = 
MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) precision.clip_grad_by_value = Mock() precision.clip_grad_by_norm = Mock() - precision.clip_gradients(module, optimizer) + precision.clip_gradients(optimizer, module=module) precision.clip_grad_by_value.assert_not_called() precision.clip_grad_by_norm.assert_not_called() - precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE) + precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE, module=module) precision.clip_grad_by_value.assert_called_once() precision.clip_grad_by_norm.assert_not_called() precision.clip_grad_by_value.reset_mock() precision.clip_grad_by_norm.reset_mock() - precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM) + precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM, module=module) precision.clip_grad_by_value.assert_not_called() precision.clip_grad_by_norm.assert_called_once() @@ -54,7 +54,7 @@ def test_optimizer_amp_scaling_support_in_step_method(): precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"): - precision.clip_gradients(module, optimizer, clip_val=1.0) + precision.clip_gradients(optimizer, clip_val=1.0, module=module) def test_amp_with_no_grad(): From 46fb1b5918cbb9e9b1e93aeb4bacda24e2bb7c9a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:22:26 +0000 Subject: [PATCH 21/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/plugins/precision/test_amp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index acf238216184b..2a74e0b5f0e5e 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -33,7 +33,9 @@ def test_clip_gradients(): precision.clip_grad_by_value.assert_not_called() precision.clip_grad_by_norm.assert_not_called() - precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE, module=module) + precision.clip_gradients( + optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE, module=module + ) precision.clip_grad_by_value.assert_called_once() precision.clip_grad_by_norm.assert_not_called()
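
Taken together, these patches let an FSDP run request norm-based gradient clipping directly through the Trainer, where it previously raised a MisconfigurationException. The snippet below is a minimal usage sketch, not part of the patch series itself: `DemoModule` is a placeholder for any user-defined LightningModule, and the strategy, precision, and clip values are illustrative.

import lightning.pytorch as pl

model = DemoModule()  # placeholder LightningModule, defined elsewhere

trainer = pl.Trainer(
    strategy="fsdp",                 # wraps the model in FullyShardedDataParallel
    precision="16-mixed",            # selects the FSDP mixed-precision plugin
    gradient_clip_val=1.0,           # maximum gradient norm
    gradient_clip_algorithm="norm",  # the algorithm this series enables for FSDP
)
trainer.fit(model)

Under the hood, `LightningModule.clip_gradients` now forwards `trainer.model` to the precision plugin, and `FSDPPrecision.clip_grad_by_norm` calls the root FSDP module's `clip_grad_norm_` rather than `torch.nn.utils.clip_grad_norm_`, which does not handle sharded parameters correctly.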