Skip to content

Commit

Permalink
Add tico model (#1398)
Browse files Browse the repository at this point in the history
* Add tico model

* Fix view

* Fix wrong hyperparam

* Fix hyperparam and make naming consistent

* Fix wrong loss

* Minor changes for debugging

* Cleanup

* Log individual losses. Detach B.

* Fix issues in code. Remove debugging logs.

* Add TiCo benchmarks results and checkpoints

* Update codebase and use naming from paper.

* Remove misleading comment about parameters for lr.
  • Loading branch information
IgorSusmelj authored Jan 11, 2024
1 parent 67dc269 commit deb3c31
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 54 deletions.
90 changes: 43 additions & 47 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

![Lightly SSL self-supervised learning Logo](docs/logos/lightly_SSL_logo_crop.png)

![GitHub](https://img.shields.io/github/license/lightly-ai/lightly)
Expand All @@ -8,7 +7,6 @@
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Discord](https://img.shields.io/discord/752876370337726585?logo=discord&logoColor=white&label=discord&color=7289da)](https://discord.gg/xvNJW94)


Lightly SSL is a computer vision framework for self-supervised learning.

- [Documentation](https://docs.lightly.ai/self-supervised-learning/)
Expand All @@ -20,7 +18,6 @@ and [data curation](https://docs.lightly.ai/docs/what-is-lightly). If you're int
Lightly Worker Solution to easily process millions of samples and run [powerful algorithms](https://docs.lightly.ai/docs/selection)
on your data, check out [lightly.ai](https://www.lightly.ai). It's free to get started!


## Features

This self-supervised learning framework offers the following features:
Expand All @@ -31,7 +28,6 @@ This self-supervised learning framework offers the following features:
- Supports custom backbone models for self-supervised pre-training.
- Support for distributed training using PyTorch Lightning.


### Supported Models

You can [find sample code for all the supported models here.](https://docs.lightly.ai/self-supervised-learning/examples/models.html) We provide PyTorch, PyTorch Lightning,
Expand All @@ -57,7 +53,6 @@ and PyTorch Lightning distributed examples for all models to kickstart your proj
- VICReg, 2022 [paper](https://arxiv.org/abs/2105.04906) [docs](https://docs.lightly.ai/self-supervised-learning/examples/vicreg.html)
- VICRegL, 2022 [paper](https://arxiv.org/abs/2210.01571) [docs](https://docs.lightly.ai/self-supervised-learning/examples/vicregl.html)


## Tutorials

Want to jump to the tutorials and see Lightly in action?
Expand All @@ -72,7 +67,6 @@ Community and partner projects:

- [On-Device Deep Learning with Lightly on an ARM microcontroller](https://github.com/ARM-software/EndpointAI/tree/master/ProofOfConcepts/Vision/OpenMvMaskDefaults)


## Quick Start

Lightly requires **Python 3.6+** but we recommend using **Python 3.7+**. We recommend installing Lightly in a **Linux** or **OSX** environment.
Expand All @@ -88,15 +82,16 @@ Lightly is compatible with PyTorch and PyTorch Lightning v2.0+!
Vision transformer based models require Torchvision v0.12+.

### Installation

You can install Lightly and its dependencies from PyPI with:

```
pip3 install lightly
```

We strongly recommend that you install Lightly in a dedicated virtualenv, to avoid
conflicting with your system packages.


### Lightly in Action

With Lightly, you can use the latest self-supervised learning methods in a modular
Expand Down Expand Up @@ -199,7 +194,6 @@ criterion = loss.NegativeCosineSimilarity()

You can [find a more complete example for SimSiam here.](https://docs.lightly.ai/self-supervised-learning/examples/simsiam.html)


Use PyTorch Lightning to train the model:

```python
Expand Down Expand Up @@ -239,6 +233,7 @@ trainer.fit(model, dataloader)
See [our docs for a full PyTorch Lightning example.](https://docs.lightly.ai/self-supervised-learning/examples/simclr.html)

Or train the model on 4 GPUs:

```python

# Use distributed version of loss functions.
Expand All @@ -258,54 +253,51 @@ trainer.fit(model, dataloader)
We provide multi-GPU training examples with distributed gather and synchronized BatchNorm.
[Have a look at our docs regarding distributed training.](https://docs.lightly.ai/self-supervised-learning/getting_started/distributed_training.html)


## Benchmarks

Implemented models and their performance on various datasets. Hyperparameters are not
tuned for maximum accuracy. For detailed results and more information about the benchmarks click
[here](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html).


### ImageNet1k

[ImageNet1k benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenet1k)

**Note**: Evaluation settings are based on these papers:
* Linear: [SimCLR](https://arxiv.org/abs/2002.05709)
* Finetune: [SimCLR](https://arxiv.org/abs/2002.05709)
* KNN: [InstDisc](https://arxiv.org/abs/1805.01978)

See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.

- Linear: [SimCLR](https://arxiv.org/abs/2002.05709)
- Finetune: [SimCLR](https://arxiv.org/abs/2002.05709)
- KNN: [InstDisc](https://arxiv.org/abs/1805.01978)

| Model | Backbone | Batch Size | Epochs | Linear Top1 | Finetune Top1 | kNN Top1 | Tensorboard | Checkpoint |
|----------------|----------|------------|--------|-------------|---------------|----------|-------------|------------|
| BarlowTwins | Res50 | 256 | 100 | 62.9 | 72.6 | 45.6 | [link](https://tensorboard.dev/experiment/NxyNRiQsQjWZ82I9b0PvKg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_barlowtwins_2023-08-18_00-11-03/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| BYOL | Res50 | 256 | 100 | 62.4 | 74.0 | 45.6 | [link](https://tensorboard.dev/experiment/Z0iG2JLaTJe5nuBD7DK1bg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_byol_2023-07-10_10-37-32/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| DINO | Res50 | 128 | 100 | 68.2 | 72.5 | 49.9 | [link](https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt) |
| MoCoV2 | Res50 | 256 | 100 | 61.5 | 74.4 | 41.3 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_mocov2_2023-12-06_15-06-19/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* | Res50 | 256 | 100 | 63.2 | 73.9 | 44.8 | [link](https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCL | Res50 | 256 | 100 | 65.1 | 73.5 | 49.6 | [link](https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR* + DCLW | Res50 | 256 | 100 | 64.5 | 73.2 | 48.5 | [link](https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SwAV | Res50 | 256 | 100 | 67.2 | 75.4 | 49.5 | [link](https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| VICReg | Res50 | 256 | 100 | 63.0 | 73.7 | 46.3 | [link](https://tensorboard.dev/experiment/qH5uywJbTJSzgCEfxc7yUw) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_vicreg_2023-09-11_10-53-08/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
See the [benchmarking scripts](./benchmarks/imagenet/resnet50/) for details.

*\*We use square root learning rate scaling instead of linear scaling as it yields
better results for smaller batch sizes. See Appendix B.1 in the [SimCLR paper](https://arxiv.org/abs/2002.05709).*
| Model | Backbone | Batch Size | Epochs | Linear Top1 | Finetune Top1 | kNN Top1 | Tensorboard | Checkpoint |
| --------------- | -------- | ---------- | ------ | ----------- | ------------- | -------- | ------------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| BarlowTwins | Res50 | 256 | 100 | 62.9 | 72.6 | 45.6 | [link](https://tensorboard.dev/experiment/NxyNRiQsQjWZ82I9b0PvKg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_barlowtwins_2023-08-18_00-11-03/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| BYOL | Res50 | 256 | 100 | 62.4 | 74.0 | 45.6 | [link](https://tensorboard.dev/experiment/Z0iG2JLaTJe5nuBD7DK1bg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_byol_2023-07-10_10-37-32/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| DINO | Res50 | 128 | 100 | 68.2 | 72.5 | 49.9 | [link](https://tensorboard.dev/experiment/DvKHX9sNSWWqDrRksllPLA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt) |
| MoCoV2 | Res50 | 256 | 100 | 61.5 | 74.4 | 41.3 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_mocov2_2023-12-06_15-06-19/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR\* | Res50 | 256 | 100 | 63.2 | 73.9 | 44.8 | [link](https://tensorboard.dev/experiment/Ugol97adQdezgcVibDYMMA) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR\* + DCL | Res50 | 256 | 100 | 65.1 | 73.5 | 49.6 | [link](https://tensorboard.dev/experiment/k4ZonZ77QzmBkc0lXswQlg/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dcl_2023-07-04_16-51-40/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SimCLR\* + DCLW | Res50 | 256 | 100 | 64.5 | 73.2 | 48.5 | [link](https://tensorboard.dev/experiment/TrALnpwFQ4OkZV3uvaX7wQ/) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dclw_2023-07-07_14-57-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| SwAV | Res50 | 256 | 100 | 67.2 | 75.4 | 49.5 | [link](https://tensorboard.dev/experiment/Ipx4Oxl5Qkqm5Sl5kWyKKg) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_swav_2023-05-25_08-29-14/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |
| TiCo | Res50 | 256 | 100 | 49.7 | 72.7 | 26.6 | - | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_tico_2024-01-07_18-40-57/pretrain/version_0/checkpoints/epoch%3D99-step%3D250200.ckpt) |
| VICReg | Res50 | 256 | 100 | 63.0 | 73.7 | 46.3 | [link](https://tensorboard.dev/experiment/qH5uywJbTJSzgCEfxc7yUw) | [link](https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_vicreg_2023-09-11_10-53-08/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt) |

_\*We use square root learning rate scaling instead of linear scaling as it yields
better results for smaller batch sizes. See Appendix B.1 in the [SimCLR paper](https://arxiv.org/abs/2002.05709)._

### ImageNet100
[ImageNet100 benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenet100)

[ImageNet100 benchmarks detailed results](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenet100)

### Imagenette

[Imagenette benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenette)

[Imagenette benchmarks detailed results](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#imagenette)

### CIFAR-10

[CIFAR-10 benchmarks](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#cifar-10)

[CIFAR-10 benchmarks detailed results](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html#cifar-10)

## Terminology

Expand All @@ -314,53 +306,54 @@ The terms in bold are explained in more detail in our [documentation](https://do

<img src="/docs/source/getting_started/images/lightly_overview.png" alt="Overview of the Lightly pip package"/></a>


### Next Steps

Head to the [documentation](https://docs.lightly.ai/self-supervised-learning/) and see the things you can achieve with Lightly!


## Development

To install dev dependencies (for example to contribute to the framework) you can use the following command:

```
pip3 install -e ".[dev]"
```

For more information about how to contribute have a look [here](CONTRIBUTING.md).


### Running Tests

Unit tests are within the [tests directory](tests/) and we recommend running them using
Unit tests are within the [tests directory](tests/) and we recommend running them using
[pytest](https://docs.pytest.org/en/stable/). There are two test configurations
available. By default, only a subset will be run:

```
make test-fast
```

To run all tests (including the slow ones) you can use the following command:

```
make test
```

To test a specific file or directory use:

```
pytest <path to file or directory>
```


### Code Formatting

To format code with [black](https://black.readthedocs.io/en/stable/) and [isort](https://docs.pytest.org) run:

```
make format
```


## Further Reading

**Self-Supervised Learning**:

- Have a look at our [#papers channel on discord](https://discord.com/channels/752876370337726585/815153188487299083)
for the newest self-supervised learning papers.
- [A Cookbook of Self-Supervised Learning, 2023](https://arxiv.org/abs/2304.12210)
Expand All @@ -371,26 +364,27 @@ make format
- [A Simple Framework for Contrastive Learning of Visual Representations, 2020](https://arxiv.org/abs/2002.05709)
- [Momentum Contrast for Unsupervised Visual Representation Learning, 2020](https://arxiv.org/abs/1911.05722)


## FAQ

- Why should I care about self-supervised learning? Aren't pre-trained models from ImageNet much better for transfer learning?
- Self-supervised learning has become increasingly popular among scientists over the last years because the learned representations perform extraordinarily well on downstream tasks. This means that they capture the important information in an image better than other types of pre-trained models. By training a self-supervised model on *your* dataset, you can make sure that the representations have all the necessary information about your images.

- Self-supervised learning has become increasingly popular among scientists over the last years because the learned representations perform extraordinarily well on downstream tasks. This means that they capture the important information in an image better than other types of pre-trained models. By training a self-supervised model on _your_ dataset, you can make sure that the representations have all the necessary information about your images.

- How can I contribute?

- Create an issue if you encounter bugs or have ideas for features we should implement. You can also add your own code by forking this repository and creating a PR. More details about how to contribute with code is in our [contribution guide](CONTRIBUTING.md).

- Is this framework for free?

- Yes, this framework is completely free to use and we provide the source code. We believe that we need to make training deep learning models more data efficient to achieve widespread adoption. One step to achieve this goal is by leveraging self-supervised learning. The company behind Lightly is committed to keep this framework open-source.

- If this framework is free, how is the company behind Lightly making money?
- Training self-supervised models is only one part of our solution.
[The company behind Lightly](https://lightly.ai/) focuses on processing and analyzing embeddings created by self-supervised models.
By building, what we call a self-supervised active learning loop we help companies understand and work with their data more efficiently.
As the [Lightly Solution](https://docs.lightly.ai) is a freemium product, you can try it out for free. However, we will charge for some features.
- Training self-supervised models is only one part of our solution.
[The company behind Lightly](https://lightly.ai/) focuses on processing and analyzing embeddings created by self-supervised models.
By building, what we call a self-supervised active learning loop we help companies understand and work with their data more efficiently.
As the [Lightly Solution](https://docs.lightly.ai) is a freemium product, you can try it out for free. However, we will charge for some features.
- In any case this framework will always be free to use, even for commercial purposes.


## Lightly in Research

- [Reverse Engineering Self-Supervised Learning, 2023](https://arxiv.org/abs/2305.15614)
Expand All @@ -401,7 +395,8 @@ make format
- [solo-learn: A Library of Self-supervised Methods for Visual Representation Learning, 2021](https://www.jmlr.org/papers/volume23/21-1155/21-1155.pdf)

## Company behind this Open Source Framework
[Lightly](https://www.lightly.ai) is a spin-off from ETH Zurich that helps companies

[Lightly](https://www.lightly.ai) is a spin-off from ETH Zurich that helps companies
build efficient active learning pipelines to select the most relevant data for their models.

You can find out more about the company and it's services by following the links below:
Expand All @@ -411,6 +406,7 @@ You can find out more about the company and it's services by following the links
- [Lightly Solution Documentation (Lightly Worker & API)](https://docs.lightly.ai/)

## BibTeX

If you want to cite the framework feel free to use this:

```bibtex
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/imagenet/resnet50/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import mocov2
import simclr
import swav
import tico
import torch
import vicreg
from pytorch_lightning import LightningModule, Trainer
Expand Down Expand Up @@ -61,6 +62,7 @@
"mocov2": {"model": mocov2.MoCoV2, "transform": mocov2.transform},
"simclr": {"model": simclr.SimCLR, "transform": simclr.transform},
"swav": {"model": swav.SwAV, "transform": swav.transform},
"tico": {"model": tico.TiCo, "transform": tico.transform},
"vicreg": {"model": vicreg.VICReg, "transform": vicreg.transform},
}

Expand Down
Loading

0 comments on commit deb3c31

Please sign in to comment.