Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
philippmwirth committed Nov 20, 2023
1 parent 4c57c10 commit abbb9cd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 2 additions & 4 deletions lightly/models/modules/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,9 @@ def get_updated_group_features(self, x: torch.Tensor) -> torch.Tensor:
mask = assignments == assigned_class
group_features[assigned_class] = self.beta * self.group_features[
assigned_class
] + (1 - self.beta) * x[
mask
].mean( # type: ignore
] + (1 - self.beta) * x[mask].mean(
axis=0
)
) # type: ignore

return group_features

Expand Down
6 changes: 4 additions & 2 deletions lightly/models/modules/nn_memory_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ class NNMemoryBankModule(MemoryBankModule):
Attributes:
size:
Number of keys the memory bank can store. If set to 0,
memory bank is not used.
Number of keys the memory bank can store.
Examples:
>>> model = NNCLR(backbone)
Expand All @@ -40,6 +39,8 @@ class NNMemoryBankModule(MemoryBankModule):
"""

def __init__(self, size: int = 2**16):
if size <= 0:
raise ValueError(f"Memory bank size must be positive, got {size}.")

Check warning on line 43 in lightly/models/modules/nn_memory_bank.py

View check run for this annotation

Codecov / codecov/patch

lightly/models/modules/nn_memory_bank.py#L43

Added line #L43 was not covered by tests
super(NNMemoryBankModule, self).__init__(size)

def forward(
Expand All @@ -57,6 +58,7 @@ def forward(
"""

output, bank = super(NNMemoryBankModule, self).forward(output, update=update)
assert bank is not None
bank = bank.to(output.device).t()

output_normed = torch.nn.functional.normalize(output, dim=1)
Expand Down

0 comments on commit abbb9cd

Please sign in to comment.