
Conversation

@KylevdLangemheen
Contributor

As per #1554, this PR starts on an implementation of SupCon loss.

The officially referenced PyTorch implementation does not yet support multi-GPU training, but the official TensorFlow implementation does.

Support for all three contrast modes is currently implemented under the definition of $\mathcal{L}^{sup}_{out}$ (equation 2 in https://arxiv.org/abs/2004.11362). Capping the number of positives used is not yet supported.
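For reference, equation 2 of the paper defines

$$\mathcal{L}^{sup}_{out} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}$$

where $P(i)$ is the set of positives sharing anchor $i$'s label and $A(i)$ is the set of all samples other than $i$.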

Also implemented are two very basic tests: one which just runs the loss with some random features and labels, and one which compares the output of this implementation to the existing NTXentLoss when labels is None. More tests are definitely needed, and the implementation is not final (and likely still has some bugs).
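For illustration, the NTXent-equivalence check could look roughly like this (a sketch; SupConLoss's actual forward signature may differ, and the optional temperature rescaling mentioned later in the thread may be needed for exact agreement):

# sketch of the equivalence check; names and signatures are assumed
import torch
from lightly.loss import NTXentLoss
from lightly.loss.supcon_loss import SupConLoss  # module path from the Codecov report

out0, out1 = torch.randn(8, 32), torch.randn(8, 32)
supcon = SupConLoss(temperature=0.5)
ntxent = NTXentLoss(temperature=0.5)
# Without labels, the only positives are the two views of each sample,
# which is exactly the NTXent setting.
assert torch.allclose(supcon(out0, out1), ntxent(out0, out1), atol=1e-6)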

Note: This method could be expanded with an altered version of a memory bank which also stores labels.

@yutong-xiang-97
Contributor

Hi @KylevdLangemheen, thank you for your contribution! Will have a look and give you feedback on how to proceed.

codecov bot commented Aug 13, 2025

Codecov Report

❌ Patch coverage is 91.83673% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 86.14%. Comparing base (ee30cd4) to head (f7d94fb).

Files with missing lines      Patch %   Lines
lightly/loss/supcon_loss.py   91.66%    4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1877      +/-   ##
==========================================
+ Coverage   86.10%   86.14%   +0.03%     
==========================================
  Files         168      169       +1     
  Lines        6979     7028      +49     
==========================================
+ Hits         6009     6054      +45     
- Misses        970      974       +4     

☔ View full report in Codecov by Sentry.

@yutong-xiang-97 left a comment
Contributor

The original SupConLoss implementation we're referring to was, in the author's own words, highly experimental and confusing. Most importantly, it has some critical issues, pointed out below, that make it unsuitable for integration into LightlySSL.

Fortunately, I have a much simpler, working version that I rewrote for my master's thesis, partly inspired by the author's improved version, but without distributed support. We can start from there and add distributed support.

KylevdLangemheen added a commit to KylevdLangemheen/lightly that referenced this pull request Aug 25, 2025
@KylevdLangemheen KylevdLangemheen marked this pull request as ready for review August 25, 2025 17:30
@KylevdLangemheen
Contributor Author

Hi! I have redone the loss from scratch, basing it on your implementation! Is it available online so that I can reference it? I have also included it in the test.

Let me know if this way of doing distributed is correct. It's slightly different from how it's handled in NTXentLoss, but conceptually this made sense to me. Is there a way to test it even if you only have a single GPU, or do I need to spin up a multi-GPU instance 🤔

p.s. I added temperature rescaling as an optional parameter in order to compare it to the NTXentLoss. I can also remove it altogether.

@liopeer
Contributor

liopeer commented Aug 25, 2025

Great, thanks Kyle!

Either I or @yutong-xiang-97 will have a look very soon. About the distributed testing: we don't do it in the unit tests, but you can run it locally without any GPUs in a multi-process CPU setting. Example code below (the main point is that the "gloo" backend is used instead of "nccl", which is for CUDA GPUs).

# dist_train_cpu.py
import contextlib
import os
from argparse import ArgumentParser

import torch.distributed as dist
import torch.multiprocessing as mp

MASTER_ADDR = "localhost"
MASTER_PORT = "12355"

@contextlib.contextmanager
def setup_dist(rank: int, world_size: int):
    """Initialize and tear down the process group around the test body."""
    try:
        os.environ["MASTER_ADDR"] = MASTER_ADDR
        os.environ["MASTER_PORT"] = MASTER_PORT
        # "gloo" runs on CPU; swap in "nccl" for CUDA GPUs.
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        yield
    finally:
        dist.destroy_process_group()

def train_dist(rank: int, world_size: int) -> None:
    # Set up the process group; rank is passed in by mp.spawn.
    with setup_dist(rank, world_size):
        # insert a test here
        pass

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--world-size", type=int, required=True)
    args = parser.parse_args()

    mp.spawn(
        train_dist,
        args=(args.world_size,),  # args must be a tuple, note the trailing comma
        nprocs=args.world_size,
    )
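Run it with, for example, python dist_train_cpu.py --world-size 2; mp.spawn then starts one CPU process per rank, each of which executes train_dist with its own rank.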

@yutong-xiang-97 left a comment
Contributor

Great job @KylevdLangemheen! Much appreciated. However, some changes are needed for the tests and docstrings. Please check my comments.


Examples:
>>> # initialize loss function without memory bank
>>> loss_fn = NTXentLoss(memory_bank_size=0)

Please fix the docstring with the example for SupConLoss
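For example, something along these lines (a sketch; the argument names are assumed, and 0.1 matches the paper's default temperature mentioned in a later comment):

>>> # initialize the supervised contrastive loss
>>> loss_fn = SupConLoss(temperature=0.1)
>>> # embed two views and compute the loss using class labels
>>> loss = loss_fn(out0, out1, labels)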

) -> Tensor:
"""Forward pass through Supervised Contrastive Loss.

Computes the loss based on contrast_mode setting.

There is no contrast mode setting anymore


def __init__(
self,
temperature: float = 0.5,

The default temperature in the original paper is 0.1

ntxent_loss = ntxent(out1, out2)
assert (supcon_loss - ntxent_loss).pow(2).item() == pytest.approx(0.0)

@pytest.mark.parametrize("labels", [[0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 2, 3]])

Here you shouldn't compare to my implementation. Your code is basically a rewritten version of mine, so testing against it isn't a legitimate check (even if both are correct). I would instead compare against plain code rewritten in NumPy, like done here:

def _calc_ntxent_loss_manual(
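A manual reference in that spirit might look roughly like this (a hypothetical sketch, not the repo's helper: it computes $\mathcal{L}^{sup}_{out}$ with explicit loops so that it shares no code with the implementation under test):

# hypothetical manual reference written with plain loops
import numpy as np

def _calc_supcon_loss_manual(z: np.ndarray, labels: np.ndarray, temperature: float = 0.1) -> float:
    """Naive L_out, averaged over anchors: z is (N, D) and row-normalized, labels is (N,)."""
    n = z.shape[0]
    sim = z @ z.T / temperature  # pairwise similarities scaled by temperature
    total, n_anchors = 0.0, 0
    for i in range(n):
        positives = [p for p in range(n) if p != i and labels[p] == labels[i]]
        if not positives:
            continue  # anchors without positives are skipped
        # The denominator runs over all samples except the anchor itself.
        denom = sum(np.exp(sim[i, a]) for a in range(n) if a != i)
        total -= sum(np.log(np.exp(sim[i, p]) / denom) for p in positives) / len(positives)
        n_anchors += 1
    # Assumes at least one anchor has a positive.
    return total / n_anchors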
