fix bug in distillation function
Changbin ZHANG 张长彬 authored and Changbin ZHANG 张长彬 committed Apr 28, 2022
1 parent 79a992d commit 7a3172c
Showing 1 changed file with 4 additions and 4 deletions.
train.py: 8 changes (4 additions & 4 deletions)
@@ -25,8 +25,8 @@ def features_distillation8(
     nb_new_classes=1
 ):
     loss = torch.tensor(0.).to(list_attentions_a[0].device)
-    list_attentions_a = list_attentions_a[:-1]
-    list_attentions_b = list_attentions_b[:-1]
+    # list_attentions_a = list_attentions_a[:-1]
+    # list_attentions_b = list_attentions_b[:-1]
     for i, (a, b) in enumerate(zip(list_attentions_a, list_attentions_b)):
         n, c, h, w = a.shape
         layer_loss = torch.tensor(0.).to(a.device)
@@ -82,8 +82,8 @@ def features_distillation_channel(
     nb_new_classes=1
 ):
     loss = torch.tensor(0.).to(list_attentions_a[0].device)
-    list_attentions_a = list_attentions_a[:-2]
-    list_attentions_b = list_attentions_b[:-2]
+    list_attentions_a = list_attentions_a[:-1]
+    list_attentions_b = list_attentions_b[:-1]
     for i, (a, b) in enumerate(zip(list_attentions_a, list_attentions_b)):
         n, c, h, w = a.shape
         layer_loss = torch.tensor(0.).to(a.device)
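
The two hunks change which feature maps feed the distillation loss: features_distillation8 now keeps the last attention map (its slicing is commented out), and features_distillation_channel drops only the last map instead of the last two. For context, below is a minimal sketch of this style of feature-map distillation; the pooling choice and the features_distillation_demo name are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.nn.functional as F

def features_distillation_demo(list_attentions_a, list_attentions_b, drop_last=1):
    """Illustrative feature-map distillation loss (hypothetical, not the repo's code).

    list_attentions_a / list_attentions_b are lists of (n, c, h, w) tensors from
    the old and new models. drop_last mirrors the [:-k] slicing touched by this
    commit: it controls how many of the deepest feature maps are excluded.
    """
    loss = torch.tensor(0.).to(list_attentions_a[0].device)

    if drop_last > 0:
        # The commit changes exactly this slicing in both functions.
        list_attentions_a = list_attentions_a[:-drop_last]
        list_attentions_b = list_attentions_b[:-drop_last]

    for a, b in zip(list_attentions_a, list_attentions_b):
        n, c, h, w = a.shape
        # One common choice: collapse channels into a spatial attention map,
        # then compare normalized maps between the old (a) and new (b) models.
        a_pooled = F.normalize(a.pow(2).sum(dim=1).view(n, -1), dim=1)
        b_pooled = F.normalize(b.pow(2).sum(dim=1).view(n, -1), dim=1)
        layer_loss = torch.norm(a_pooled - b_pooled, dim=-1).mean()
        loss = loss + layer_loss

    return loss / max(len(list_attentions_a), 1)

Under these assumptions, drop_last=1 corresponds to the post-commit behaviour of features_distillation_channel, and drop_last=0 keeps every map, as features_distillation8 now does.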
