check param is of nn.Parameter type for pruning sanitization #20783

Merged · 7 commits · Jun 18, 2025

5 changes: 4 additions & 1 deletion src/lightning/pytorch/callbacks/pruning.py
@@ -458,7 +458,10 @@ def sanitize_parameters_to_prune(
 
         if not parameters_to_prune:
             parameters_to_prune = [
-                (m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
+                (m, p)
+                for p in parameters
+                for m in current_modules
+                if getattr(m, p, None) is not None and isinstance(getattr(m, p, None), nn.Parameter)
             ]
         elif (
             isinstance(parameters_to_prune, (list, tuple))
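For context, a minimal sketch of the failure mode this hunk guards against (the module below is illustrative, not from the PR): boolean module attributes such as `training` pass the old `is not None` test, so they could be collected into `parameters_to_prune` and break pruning downstream, while the added `isinstance` check filters them out.

from torch import nn

module = nn.Linear(4, 4)

# Old filter: anything that is not None slips through, including plain bools.
print(getattr(module, "training", None) is not None)  # True -> would be collected

# New filter: only real nn.Parameter attributes survive.
print(isinstance(getattr(module, "weight", None), nn.Parameter))    # True  -> kept
print(isinstance(getattr(module, "training", None), nn.Parameter))  # False -> dropped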
67 changes: 67 additions & 0 deletions tests/tests_pytorch/callbacks/test_pruning.py
@@ -338,3 +338,70 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
    assert not hasattr(model.layer.mlp_3, "weight_orig")
    model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
    assert not hasattr(model.layer.mlp_3, "weight_orig")


def test_sanitize_parameters_explicit_check():
    """Test the sanitize_parameters_to_prune method with various attribute types."""

    class TestModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(5, 5))
            self.bias = nn.Parameter(torch.randn(5))
            self.some_bool = True
            self.some_tensor = torch.randn(3, 3)  # Regular tensor, not a parameter
            self.some_string = "test"
            self.some_none = None

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.test_module = TestModule()

    model = TestModel()

    parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
        model,
        parameters_to_prune=(),
        parameter_names=["weight", "bias", "some_bool", "some_tensor", "some_string", "some_none"],
    )

    param_names_found = set()
    for module, param_name in parameters_to_prune:
        param = getattr(module, param_name)
        assert isinstance(param, nn.Parameter), f"Expected Parameter, got {type(param)}"
        param_names_found.add(param_name)

    assert "weight" in param_names_found
    assert "bias" in param_names_found
    assert "some_bool" not in param_names_found
    assert "some_tensor" not in param_names_found
    assert "some_string" not in param_names_found
    assert "some_none" not in param_names_found


def test_original_issue_reproduction():
    """Issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10835."""

    class ProblematicModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer = Sequential(
                OrderedDict([
                    ("mlp_1", nn.Linear(32, 32)),
                    ("mlp_2", nn.Linear(32, 2)),
                ])
            )
            # Add boolean attributes that would cause the original error
            self.layer.mlp_1.training = True
            self.layer.mlp_2.requires_grad = True

    model = ProblematicModel()

    parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
        model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"]
    )

    for module, param_name in parameters_to_prune:
        param = getattr(module, param_name)
        assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"
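For reference, a minimal sketch of how the sanitized list is exercised end to end through the callback (the pruning function and hyperparameters here are illustrative, not from the PR):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelPruning

# The callback sanitizes its parameter list before pruning starts, so with
# this fix, non-Parameter module attributes no longer abort pruning setup.
trainer = Trainer(
    callbacks=[ModelPruning("l1_unstructured", parameter_names=["weight"], amount=0.5)],
    max_epochs=2,
)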