Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
philippmwirth committed Nov 20, 2023
1 parent c916d1e commit 4c57c10
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions lightly/models/modules/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,11 @@ def get_updated_group_features(self, x: torch.Tensor) -> torch.Tensor:
mask = assignments == assigned_class
group_features[assigned_class] = self.beta * self.group_features[
assigned_class
] + (1 - self.beta) * x[mask].mean(
] + (1 - self.beta) * x[
mask
].mean( # type: ignore
axis=0
) # type: ignore
)

return group_features

Expand Down
2 changes: 1 addition & 1 deletion lightly/models/modules/nn_memory_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lightly.loss.memory_bank import MemoryBankModule


class NNMemoryBankModule(MemoryBankModule): # type: ignore # Cannot subclass type Any.
class NNMemoryBankModule(MemoryBankModule):
"""Nearest Neighbour Memory Bank implementation
This class implements a nearest neighbour memory bank as described in the
Expand Down

0 comments on commit 4c57c10

Please sign in to comment.