Skip to content

Commit 7ed9822

Browse files
Hsieh, Kevin
GitHub Enterprise
Hsieh, Kevin
authored and
GitHub Enterprise
committed
Fix Lora scaling device issue
Signed-off-by: Kevin Hsieh <[email protected]>
1 parent 31fa4c3 commit 7ed9822

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

TrainingExtensions/torch/src/python/aimet_torch/peft.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,14 @@ def __init__(self, lora_layer: PeftLoraLayer):
7171
self.base_layer = lora_layer.base_layer
7272
self.r = lora_layer.r
7373
self.lora_alpha = lora_layer.lora_alpha
74-
self.scaling = [
75-
torch.nn.Parameter(torch.as_tensor(scale), requires_grad=False).to(
76-
self.base_layer.weight.device
77-
)
78-
for scale in lora_layer.scaling.values()
79-
]
74+
self.scaling = torch.nn.ParameterList(
75+
[
76+
torch.nn.Parameter(torch.as_tensor(scale), requires_grad=False).to(
77+
self.base_layer.weight.device
78+
)
79+
for scale in lora_layer.scaling.values()
80+
]
81+
)
8082
self.lora_dropout = nn.ModuleList({})
8183
self.adapter_name_to_index = {}
8284
self.index_to_adapter_name = {}

TrainingExtensions/torch/test/python/test_peft.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,10 +380,23 @@ def forward_pass(model, forward_pass_callback=None):
380380
]
381381
assert sorted(tensor_name) == sorted(tensors)
382382

383+
@pytest.mark.cuda
384+
def test_changing_lora_device(self):
385+
model = one_adapter_model().cuda()
386+
387+
replace_lora_layers_with_quantizable_layers(model)
388+
dummy_inputs = torch.randn(10, 10).cuda()
389+
390+
_ = model(dummy_inputs)
391+
392+
model.cpu()
393+
394+
_ = model(dummy_inputs.cpu())
395+
383396

384397
def _is_frozen(quantizer):
385398
return (
386-
quantizer._allow_overwrite == False
387-
and quantizer.min.requires_grad == False
388-
and quantizer.max.requires_grad == False
399+
not quantizer._allow_overwrite
400+
and not quantizer.min.requires_grad
401+
and not quantizer.max.requires_grad
389402
)

0 commit comments

Comments (0)