Skip to content

Commit

Permalink
Update code
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorSusmelj committed Jan 11, 2024
1 parent dc99515 commit 8b19156
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions benchmarks/imagenet/resnet50/wmse.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
resnet = resnet50()
resnet.fc = Identity() # Ignore classification head
self.backbone = resnet
self.projection_head = WMSEProjectionHead()
self.criterion_WMSE4loss = WMSELoss(num_samples=4)

# we use a projection head with output dimension 64
# and w_size of 128 to support a batch size of 256
self.projection_head = WMSEProjectionHead(output_dim=64)

self.criterion_WMSE4loss = WMSELoss(
w_size=128, embedding_dim=64, num_samples=4, gather_distributed=True
)

self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

Expand All @@ -39,8 +45,7 @@ def training_step(
views, targets = batch[0], batch[1]
features = self.forward(torch.cat(views)).flatten(start_dim=1)
z = self.projection_head(features)
z0, z1 = z.chunk(len(views))
loss = self.criterion_WMSE4loss(z0, z1)
loss = self.criterion_WMSE4loss(z)
self.log(
"train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
)
Expand Down

0 comments on commit 8b19156

Please sign in to comment.