A PyTorch implementation of SimCLR, as described in "SimCLR: A Simple Framework for Contrastive Learning of Visual Representations".
SimCLR is a framework for contrastive learning of visual representations. It consists of three main modules:
- Generate two different augmented views of each input image (implemented in "simclr/augmentation_simclr.py").
- A deep model that outputs learned representations of the two views (implemented in "models/simclr_backbone.py").
- Maximize the agreement between the learned representations of the two views with a contrastive loss (implemented in "simclr/contrastive_loss.py").
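The third module's objective can be sketched as the NT-Xent (normalized temperature-scaled cross-entropy) loss from the SimCLR paper: for each of the 2N projections in a batch, the other view of the same image is the positive and the remaining 2N - 2 projections are negatives. The function name `nt_xent_loss` and the default `temperature=0.5` are assumptions; "simclr/contrastive_loss.py" may use different names or values.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1, z2, temperature=0.5):
    """Hypothetical NT-Xent loss for N positive pairs (z1[i], z2[i])."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, d), unit-norm rows
    sim = z @ z.t() / temperature                       # cosine similarity / tau
    # Exclude self-similarity so each softmax runs over the other 2N - 1 samples.
    mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float("-inf"))
    # The positive for sample i is sample i + N, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```

Expressing the loss as a cross-entropy over similarity logits keeps it numerically stable and averages over all 2N anchors, so both views of each image contribute symmetrically.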
Run the following command to pre-train an encoder with SimCLR:
python train_simclr.py --dataset cifar10 --backbone resnet18 --projection_size 128
Then run the following command to train a logistic regression classifier on the features produced by the pre-trained model:
python train_linear_model.py --dataset cifar10 --backbone resnet18 --projection_size 128 --resume ./pretrain_results/resnet18_cifar10_model.pth.tar
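The linear evaluation step can be sketched as follows: features are extracted once with the frozen encoder, and a single linear layer trained with softmax cross-entropy (multinomial logistic regression) is fit on them. Full-batch training with Adam, the learning rate, and the helper names `extract_features` and `train_linear_probe` are assumptions for illustration; `train_linear_model.py` may use a different optimizer, schedule, or checkpoint-loading logic.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def extract_features(encoder, images, batch_size=256):
    """Hypothetical helper: run a frozen encoder over images in batches."""
    encoder.eval()
    return torch.cat([encoder(images[i:i + batch_size])
                      for i in range(0, len(images), batch_size)])


def train_linear_probe(features, labels, num_classes, epochs=100, lr=1e-2):
    """Hypothetical helper: multinomial logistic regression on fixed features."""
    clf = nn.Linear(features.size(1), num_classes)
    optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(clf(features), labels)  # full-batch step for simplicity
        loss.backward()
        optimizer.step()
    return clf
```

Because the encoder runs under `torch.no_grad()`, only the linear layer's weights are updated, which is what makes this a measure of representation quality rather than fine-tuning.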
Requirements:
- Python 3.7.6
- PyTorch 1.4.0
- torchvision
