Start on SupCon loss #1554 #1877
base: master
Conversation
Hi @KylevdLangemheen, thank you for your contribution! Will have a look and give you feedback on how to proceed.
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #1877      +/-   ##
==========================================
+ Coverage   86.10%   86.14%   +0.03%     
==========================================
  Files         168      169       +1     
  Lines        6979     7028      +49     
==========================================
+ Hits         6009     6054      +45     
- Misses        970      974       +4     
```

☔ View full report in Codecov by Sentry.
The original SupConLoss implementation we're referring to was highly experimental and confusing, in the author's own words. Most importantly, it has some critical issues, pointed out below, that make it unsuitable to be integrated into LightlySSL.
Fortunately, I have a much simpler, working version that I rewrote for my master's thesis, partly inspired by the author's improved version, but without distributed support. We can try starting from there and adding distributed support.
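For orientation, here is a minimal, non-distributed sketch of the $\mathcal{L}^{sup}_{out}$ formulation being discussed; the function name and signature are hypothetical and not the API of this PR, and the mean-over-anchors normalization may differ from the final implementation:

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def supcon_loss_sketch(
    out0: Tensor, out1: Tensor, labels: Tensor, temperature: float = 0.1
) -> Tensor:
    """Minimal sketch of L^sup_out (eq. 2 in arXiv:2004.11362), no distributed support."""
    # Both views form the "multiviewed batch"; every sample keeps its label.
    z = F.normalize(torch.cat([out0, out1], dim=0), dim=1)  # (2B, D)
    labels = torch.cat([labels, labels], dim=0)             # (2B,)

    logits = z @ z.T / temperature                          # (2B, 2B) similarities
    # Exclude self-similarity from the denominator with a large negative value.
    self_mask = torch.eye(len(z), dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, -1e9)

    # log softmax over A(i) = all samples except the anchor i itself.
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

    # P(i): samples sharing the label of i, excluding i itself.
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
    # Average log-probability over the positives, then mean over anchors.
    loss_per_anchor = -(log_prob * pos_mask).sum(dim=1) / pos_mask.sum(dim=1).clamp(min=1)
    return loss_per_anchor.mean()
```

With unique labels (e.g. `labels = torch.arange(batch_size)`), the positives of each anchor reduce to its other view, which is the NTXentLoss-like case discussed further down.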
Force-pushed 8f0e909 to 7a5a73d.
Hi! I have redone the loss from scratch, basing it off your implementation! Is it available online so that I can reference it? I have also included it in the tests. Let me know if this way of handling distributed mode is correct. It's slightly different from how it's handled in NTXentLoss, but conceptually this made sense to me. Is there a way to test it even if you only have a single GPU, or do I need to spin up a multi-GPU instance? 🤔

P.S. I added temperature rescaling as an optional parameter in order to compare it to the NTXentLoss. I can also remove it altogether.
Great, thanks Kyle! Either I or @yutong-xiang-97 will have a look very soon. About the distributed testing: we don't do it in the unit tests, but you can do it locally, even without any GPUs, in a multi-process CPU setting. Example code below (important is mainly the `gloo` backend, which runs on CPU):

```python
# dist_train_cpu.py
import os
import contextlib
from argparse import ArgumentParser

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

MASTER_ADDR = "localhost"
MASTER_PORT = "12355"


@contextlib.contextmanager
def setup_dist(rank: int, world_size: int):
    try:
        os.environ["MASTER_ADDR"] = MASTER_ADDR
        os.environ["MASTER_PORT"] = MASTER_PORT
        # The "gloo" backend runs on CPU, so no GPUs are required.
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        yield
    finally:
        dist.destroy_process_group()


def train_dist(rank: int, world_size: int) -> None:
    # Setup the process group.
    with setup_dist(rank, world_size):
        # insert a test here
        pass


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--world-size", type=int, required=True)
    args = parser.parse_args()
    mp.spawn(
        train_dist,
        args=(args.world_size,),  # args must be a tuple
        nprocs=args.world_size,
    )
```
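If it helps, the script can then be launched with, e.g., `python dist_train_cpu.py --world-size 2`, which spawns two CPU processes that communicate over the `gloo` backend, so no GPUs are needed.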
Force-pushed 7a5a73d to f7d94fb.
Great job @KylevdLangemheen! Much appreciated. However, some changes are needed for the tests and docstrings. Please check my comments.
```python
    Examples:
        >>> # initialize loss function without memory bank
        >>> loss_fn = NTXentLoss(memory_bank_size=0)
```
Please fix the docstring so that the example uses SupConLoss.
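A corrected example could look roughly like the sketch below; the constructor and forward signatures shown here are assumptions about this PR's SupConLoss, not its confirmed API:

```python
    Examples:
        >>> # initialize the supervised contrastive loss (signature assumed)
        >>> loss_fn = SupConLoss(temperature=0.1)
        >>>
        >>> # embed two views of a batch and provide class labels
        >>> out0, out1 = model(view0), model(view1)
        >>> loss = loss_fn(out0, out1, labels)
```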
```python
    ) -> Tensor:
        """Forward pass through Supervised Contrastive Loss.

        Computes the loss based on contrast_mode setting.
```
There is no contrast mode setting anymore.
```python
    def __init__(
        self,
        temperature: float = 0.5,
```
The default temperature in the original paper is 0.1.
```python
    ntxent_loss = ntxent(out1, out2)
    assert (supcon_loss - ntxent_loss).pow(2).item() == pytest.approx(0.0)


@pytest.mark.parametrize("labels", [[0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 2, 3]])
```
Here you shouldn't compare to my implementation. Your code is basically a rewritten version of mine, so it is not legitimately tested (even if it is correct). I would instead compare against plain code rewritten in NumPy, like done here:

lightly/tests/loss/test_ntx_ent_loss.py, line 151 in ee30cd4: `def _calc_ntxent_loss_manual(`
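A plain, loop-based NumPy reference along those lines might look like the sketch below; the helper name mirrors `_calc_ntxent_loss_manual` and is hypothetical, and the normalization (mean over anchors) should be matched to whatever the implementation under test uses:

```python
import numpy as np


def _calc_supcon_loss_manual(
    out0: np.ndarray, out1: np.ndarray, labels: np.ndarray, temperature: float = 0.1
) -> float:
    """Naive loop-based L^sup_out, intended only as a test reference (name hypothetical)."""
    # Multiviewed batch: both views stacked, labels repeated for the second view.
    z = np.concatenate([out0, out1], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    y = np.concatenate([labels, labels], axis=0)
    n = len(z)

    total = 0.0
    for i in range(n):
        # A(i): every sample except the anchor itself.
        a_idx = [a for a in range(n) if a != i]
        denom = sum(np.exp(np.dot(z[i], z[a]) / temperature) for a in a_idx)
        # P(i): samples with the same label as the anchor, excluding the anchor.
        p_idx = [p for p in a_idx if y[p] == y[i]]
        log_probs = [
            np.log(np.exp(np.dot(z[i], z[p]) / temperature) / denom) for p in p_idx
        ]
        total += -sum(log_probs) / len(p_idx)
    return total / n
```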
As per #1554, this PR starts on an implementation of the SupCon loss.
The officially referenced PyTorch implementation does not yet support multi-GPU, but the official TensorFlow implementation does.
Currently, support is implemented for all three contrast modes under the $\mathcal{L}^{sup}_{out}$ definition (equation 2 in https://arxiv.org/abs/2004.11362). Capping the number of positives used is not yet supported.
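For reference, the $\mathcal{L}^{sup}_{out}$ objective from the paper reads (with $I$ the anchor indices, $A(i)$ all indices except $i$, $P(i)$ the positives sharing the label of $i$, and $\tau$ the temperature):

$$\mathcal{L}^{sup}_{out} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp\left(z_i \cdot z_p / \tau\right)}{\sum_{a \in A(i)} \exp\left(z_i \cdot z_a / \tau\right)}$$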
Two very basic tests are also included: one that just runs the loss on some random features and labels, and one that compares the output of this implementation to the existing NTXentLoss when labels is None. More tests are definitely needed, and the implementation is not final (and likely still has some bugs).
Note: this method could be expanded with an altered version of the memory bank that also stores labels.
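Purely as an illustration of that note (not part of this PR), a label-aware memory bank could be sketched as a ring buffer that stores labels alongside features; all names below are hypothetical:

```python
from typing import Tuple

import torch
from torch import Tensor


class LabeledMemoryBankSketch:
    """Hypothetical ring-buffer memory bank storing labels next to features."""

    def __init__(self, size: int, dim: int) -> None:
        self.feats = torch.zeros(size, dim)
        self.labels = torch.full((size,), -1, dtype=torch.long)  # -1 marks empty slots
        self.size = size
        self.ptr = 0

    @torch.no_grad()
    def update(self, feats: Tensor, labels: Tensor) -> None:
        # Overwrite the oldest entries with the new batch.
        idx = (self.ptr + torch.arange(len(feats))) % self.size
        self.feats[idx] = feats.detach().cpu()
        self.labels[idx] = labels.detach().cpu()
        self.ptr = (self.ptr + len(feats)) % self.size

    def get(self) -> Tuple[Tensor, Tensor]:
        # Return only slots that have been filled so far.
        filled = self.labels != -1
        return self.feats[filled], self.labels[filled]
```

The stored features and their labels could then be added to the set of candidates when building the positive and negative masks for the loss.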