simclr's Introduction

SimCLR

PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations by T. Chen et al., including support for:

  • Distributed data parallel training
  • Global batch normalization
  • LARS (Layer-wise Adaptive Rate Scaling) optimizer

Link to paper

Open SimCLR in Google Colab Notebook (with TPU support)

Open SimCLR results comparison on tensorboard.dev:

Quickstart (fine-tune linear classifier)

This downloads a pre-trained model and trains the linear classifier, which should reach an accuracy of approximately 82.9% on the STL-10 test set.

git clone https://github.com/spijkervet/SimCLR.git && cd SimCLR
wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar
sh setup.sh || python3 -m pip install -r requirements.txt || exit 1
conda activate simclr
python linear_evaluation.py --dataset=STL10 --model_path=. --epoch_num=100 --resnet resnet50

CPU

wget https://github.com/Spijkervet/SimCLR/releases/download/1.1/checkpoint_100.tar -O checkpoint_100.tar
python linear_evaluation.py --model_path=. --epoch_num=100 --resnet=resnet18 --logistic_batch_size=32

simclr package

SimCLR for PyTorch is now available as a Python package! Install it with pip and use it in your project:

pip install simclr

You can then simply import SimCLR:

from simclr import SimCLR

encoder = ResNet(...)
projection_dim = 64
n_features = encoder.fc.in_features  # get dimensions of last fully-connected layer
model = SimCLR(encoder, projection_dim, n_features)

Training ResNet encoder:

Simply run the following to pre-train a ResNet encoder using SimCLR on the CIFAR-10 dataset:

python main.py --dataset CIFAR10

Distributed Training

With distributed data parallel (DDP) training:

CUDA_VISIBLE_DEVICES=0 python main.py --nodes 2 --nr 0
CUDA_VISIBLE_DEVICES=1 python main.py --nodes 2 --nr 1
CUDA_VISIBLE_DEVICES=2 python main.py --nodes 2 --nr 2
CUDA_VISIBLE_DEVICES=N python main.py --nodes 2 --nr 3

Results

These are the top-1 accuracies of linear classifiers trained on the (frozen) representations learned by SimCLR:

| Method | Batch Size | ResNet | Projection output dimensionality | Epochs | Optimizer | STL-10 | CIFAR-10 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| SimCLR + Linear eval. | 256 | ResNet50 | 64 | 100 | Adam | 0.829 | 0.833 |
| SimCLR + Linear eval. | 256 | ResNet50 | 64 | 100 | LARS | 0.783 | - |
| SimCLR + Linear eval. | 256 | ResNet18 | 64 | 100 | Adam | 0.765 | - |
| SimCLR + Linear eval. | 256 | ResNet18 | 64 | 40 | Adam | 0.719 | - |
| SimCLR + Linear eval. | 512 | ResNet18 | 64 | 40 | Adam | 0.71 | - |
| Logistic Regression | - | - | - | 40 | Adam | 0.358 | 0.389 |

Pre-trained models

| ResNet | (batch_size, epochs) | Optimizer | STL-10 Top-1 |
| --- | --- | --- | --- |
| ResNet50 | (256, 100) | Adam | 0.829 |
| ResNet18 | (256, 100) | Adam | 0.765 |
| ResNet18 | (256, 40) | Adam | 0.719 |

python linear_evaluation.py --model_path=. --epoch_num=100

LARS optimizer

The LARS optimizer is implemented in modules/lars.py. It can be activated by adjusting the config/config.yaml optimizer setting to: optimizer: "LARS". It is still experimental and has not been thoroughly tested.

What is SimCLR?

SimCLR is a "simple framework for contrastive learning of visual representations". The contrastive prediction task is defined on pairs of augmented examples, resulting in 2N examples per minibatch. Two augmented versions of the same image are treated as a correlated, "positive" pair (x_i and x_j). The remaining 2(N - 1) augmented examples are treated as negative examples. Given x_i, the contrastive prediction task aims to identify x_j among the other 2N - 1 examples in the minibatch, that is, x_j plus the negatives.
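The counting in the paragraph above can be made concrete with a small sketch. It assumes the 2N views are laid out as the first augmentation of every image followed by the second augmentation, so that the positive of view i sits N positions away; the function name `contrastive_targets` is made up for this illustration.

```python
def contrastive_targets(n):
    """For N images, map each of the 2N views to (positive, negatives).

    Assumed layout: views 0..N-1 are the first augmentation of each image,
    views N..2N-1 the second, so the positive of view i is (i + N) mod 2N.
    """
    total = 2 * n
    targets = {}
    for i in range(total):
        positive = (i + n) % total
        # Everything except the anchor itself and its positive is a negative.
        negatives = [k for k in range(total) if k not in (i, positive)]
        targets[i] = (positive, negatives)
    return targets


targets = contrastive_targets(4)       # N = 4 images, 8 views
positive, negatives = targets[1]       # anchor: view 1
```

For N = 4, every anchor ends up with exactly one positive and 2(N - 1) = 6 negatives.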

Usage

Run the following commands to set up a conda environment:

sh setup.sh
conda activate simclr

Or alternatively with pip:

pip install -r requirements.txt

Then, for single-GPU or CPU training, simply run:

python main.py

For distributed training (DDP), run the following once for every process, where N is the index of the GPU you would like to dedicate to that process:

CUDA_VISIBLE_DEVICES=0 python main.py --nodes 2 --nr 0
CUDA_VISIBLE_DEVICES=1 python main.py --nodes 2 --nr 1
CUDA_VISIBLE_DEVICES=2 python main.py --nodes 2 --nr 2
CUDA_VISIBLE_DEVICES=N python main.py --nodes 2 --nr 3

--nr corresponds to the process number of the N nodes we make available for training.

Testing

To test a trained model, make sure to set the model_path variable in the config/config.yaml to the log ID of the training (e.g. logs/0). Set the epoch_num to the epoch number you want to load the checkpoints from (e.g. 40).

python linear_evaluation.py

or pass the values on the command line:

python linear_evaluation.py --model_path=./save --epoch_num=40

Configuration

The configuration of training can be found in: config/config.yaml. I personally prefer to use files instead of long strings of arguments when configuring a run. An example config.yaml file:

# train options
batch_size: 256
workers: 16
start_epoch: 0
epochs: 40
dataset_dir: "./datasets"

# model options
resnet: "resnet18"
normalize: True
projection_dim: 64

# loss options
temperature: 0.5

# reload options
model_path: "logs/0" # set to the directory containing `checkpoint_##.tar` 
epoch_num: 40 # set to checkpoint number

# logistic regression options
logistic_batch_size: 256
logistic_epochs: 100
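The commands in this README pass flags such as --model_path and --epoch_num that share names with config keys. One way such file defaults and command-line overrides can be combined is sketched below. This is only an illustration under stated assumptions: the `config` dict stands in for the parsed YAML file, `build_parser` is a hypothetical helper, and how the repository actually merges the two is not shown here.

```python
import argparse

# Stand-in for the parsed config/config.yaml (a subset of the keys shown above).
# Boolean keys are left out: type=bool does not parse "False" as False.
config = {
    "batch_size": 256,
    "epochs": 40,
    "resnet": "resnet18",
    "model_path": "logs/0",
    "epoch_num": 40,
}


def build_parser(defaults):
    """Expose every config key as a --flag whose default is the file's value."""
    parser = argparse.ArgumentParser()
    for key, value in defaults.items():
        parser.add_argument(f"--{key}", default=value, type=type(value))
    return parser


# Flags given on the command line win; everything else keeps the file value.
args = build_parser(config).parse_args(["--model_path=./save", "--epoch_num=100"])
```

With this approach the file stays the single place for defaults, while one-off runs only need to name the values they change.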

Logging and TensorBoard

To view results in TensorBoard, run:

tensorboard --logdir runs

Optimizers and learning rate schedule

This implementation features the Adam optimizer and the LARS optimizer, with the option to decay the learning rate using a cosine decay schedule. The optimizer and weight decay can be configured in the config/config.yaml file.
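For reference, a cosine decay schedule can be written in a few lines. This is a generic sketch of the standard formula, not the repository's scheduler; the base learning rate of 3e-4 is an assumed value, and 40 epochs is taken from the example config.

```python
import math


def cosine_decay_lr(base_lr, step, total_steps, min_lr=0.0):
    """Learning rate after `step` of `total_steps`, decayed along half a cosine."""
    progress = step / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# One value per epoch boundary: starts at base_lr, ends at min_lr.
schedule = [cosine_decay_lr(3e-4, epoch, 40) for epoch in range(41)]
```

The rate falls slowly at first, fastest around the midpoint (where it equals half the base rate), and flattens out again towards the end.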

Dependencies

torch
torchvision
tensorboard
pyyaml

simclr's People

Contributors

alegonz, gauenk, isaaccorley, lxysl, robbiejones96, spijkervet, steermomo

simclr's Issues

I got some problems when I was testing

I was running the code on Google Colab, using your ResNet50 (256, 100) pre-trained model to run linear_evaluation.py. I also set model_path, epoch_num=100 and "resnet50" in config.yaml.
Here is the error message:
Files already downloaded and verified
Traceback (most recent call last):
  File "linear_evaluation.py", line 170, in <module>
    simclr_model = SimCLR(args, encoder, n_features)
  File "/content/drive/MyDrive/Test/simclr/simclr.py", line 26, in __init__
    nn.Linear(self.n_features, projection_dim, bias=False),
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py", line 78, in __init__
    self.weight = Parameter(torch.Tensor(out_features, in_features))
TypeError: new() received an invalid combination of arguments - got (ResNet, int), but expected one of:
 * (*, torch.device device)
      didn't match because some of the arguments have invalid types: (ResNet, int)
 * (torch.Storage storage)
 * (Tensor other)
 * (tuple of ints size, *, torch.device device)
 * (object data, *, torch.device device)

Could you help me to solve it? Many thanks.

inter-gpu communication

Hi, I have a question about the multi-gpu training.

The original SimCLR paper uses a large batch size, and its good performance is largely attributed to that. The batch size is critical when computing the loss. However, your implementation does not seem to take the sub-batches on the other GPUs into account. Specifically, if you set the batch size to 1024 and use 8 GPUs, the sub-batch size on each GPU is 128. According to your implementation, it seems that the nt_xent loss is then computed over just 128 samples, not 1024. I think a correct implementation should use all_gather operations to collect all 1024 samples from all the GPUs.

I am not sure about the actual implementation. If I have misunderstood something, please correct me.
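To put numbers on the concern above, here is a small arithmetic sketch. It only applies the 2(N - 1) negatives-per-anchor count from the README to the two batch sizes mentioned; the helper name is made up for the illustration.

```python
def negatives_per_anchor(batch_size):
    """Number of negatives each anchor sees in a minibatch of `batch_size` images."""
    return 2 * (batch_size - 1)


global_batch, gpus = 1024, 8
per_gpu = global_batch // gpus                          # 128 images per GPU

local_negatives = negatives_per_anchor(per_gpu)         # loss computed per GPU
global_negatives = negatives_per_anchor(global_batch)   # loss after gathering all GPUs
```

Without gathering, each anchor is contrasted against 254 negatives instead of 2046.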

Maybe a small bug in LARS implementation

Hello, thanks for your nice implementation. I think I may have found a small bug in your LARS implementation. The TensorFlow expression

trust_ratio = tf.where(
    tf.greater(w_norm, 0),
    tf.where(tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm), 1.0),
    1.0)

is a little different from

SimCLR/modules/lars.py

Lines 119 to 127 in 654f05f

trust_ratio = torch.where(
    w_norm.ge(0),
    torch.where(
        g_norm.ge(0),
        (self.eeta * w_norm / g_norm),
        torch.Tensor([1.0]).to(device),
    ),
    torch.Tensor([1.0]).to(device),
).item()

greater is > and ge is >=. As a result, a bias parameter that is initialized to 0 is never updated. I think the following may work better:

trust_ratio = torch.where(
    w_norm.gt(0),
    torch.where(
        g_norm.gt(0),
        (self.eeta * w_norm / g_norm),
        torch.Tensor([1.0]).to(device),
    ),
    torch.Tensor([1.0]).to(device),
).item()
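The difference between the two comparisons can be checked with a scalar stand-in for the nested where() expression. This is a sketch, not the optimizer code: `trust_ratio` and its `strict` flag are hypothetical, and the `eeta` default of 0.001 is an assumed value.

```python
def trust_ratio(w_norm, g_norm, eeta=0.001, strict=True):
    """Scalar version of the nested where() above.

    strict=True uses > (like tf.greater / torch.gt),
    strict=False uses >= (like torch.ge).
    """
    def passes(x):
        return x > 0 if strict else x >= 0

    if passes(w_norm) and passes(g_norm):
        return eeta * w_norm / g_norm
    return 1.0


# A parameter initialised to zero has w_norm == 0 but a non-zero gradient.
ratio_ge = trust_ratio(0.0, 2.0, strict=False)   # >= lets w_norm == 0 through
ratio_gt = trust_ratio(0.0, 2.0, strict=True)    # > falls back to 1.0
```

With >=, the ratio comes out as 0, which would scale the update to zero if the optimizer multiplies the step by the trust ratio; with >, the fallback of 1.0 leaves the step unscaled.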

Multi GPU support is ... ornamental?

Has the code been tested on multi GPUs in the DDP mode? -- not in the sense that it "runs" (it certainly does), but in the sense that it actually learns a representation (it likely doesn't)?

query regarding usage of TPU

Hi @Spijkervet, similar to the official version of SimCLR, does your repo need access to a Google Cloud Storage bucket to load the dataset and save model checkpoints? Please let me know.

Issue loading mask_correlated_samples

Hello, it's me again :). I'm having an issue with this line as well:

from utils import mask_correlated_samples, post_config_hook

If I comment out mask_correlated_samples, the line works perfectly, and I can't find the mask_correlated_samples class in the utils file.

Issue about the NT_Xent loss

In the forward function, logits is the concatenation of the constructed positive samples and negative samples, but when calculating the loss, both the positive and the negative samples are given label 0. Is that a problem? The original paper's implementation uses one-hot labels, that is, the label of the positive sample should be 1.
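One way to read the label tensor: nn.CrossEntropyLoss takes class indices rather than one-hot vectors, so a label of 0 selects column 0 of each logits row, which is where the forward function places the positive similarity. The pure-Python sketch below (helper names are made up, and the logits are arbitrary example values) checks that an index target of 0 gives the same loss as a one-hot target with a 1 in the first position.

```python
import math


def cross_entropy_index(logits, target_index):
    """Cross-entropy for one row of logits given an integer class index."""
    log_sum_exp = math.log(sum(math.exp(v) for v in logits))
    return log_sum_exp - logits[target_index]


def cross_entropy_one_hot(logits, one_hot):
    """Cross-entropy for one row of logits given a one-hot target vector."""
    log_sum_exp = math.log(sum(math.exp(v) for v in logits))
    log_probs = [v - log_sum_exp for v in logits]
    return -sum(t * lp for t, lp in zip(one_hot, log_probs))


row = [2.0, 0.5, -1.0, 0.3]   # column 0: positive similarity, the rest: negatives
loss_index = cross_entropy_index(row, 0)
loss_one_hot = cross_entropy_one_hot(row, [1, 0, 0, 0])
```

Both formulations mark the same column as the correct "class", so the two losses agree.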

Small refactoring of your loss

Hi, thanks for making available your implementation.

If I may, I think it would be cleaner if the mask-generating function lived in your loss definition. That would remove the dependency on the args object, and as your loss function already has a batch_size argument, this is easy to include there.

Here is my take on this!

import torch
import torch.nn as nn


class NT_Xent(nn.Module):
    """
    More than inspired from https://github.com/Spijkervet/SimCLR/blob/master/modules/nt_xent.py

    Notes
    =====

    Using this pytorch implementation, you don't actually need to l2-norm the inputs, the results will be
    identical, as shown if you run this file.
    """

    def __init__(self, batch_size, temperature, device):
        super(NT_Xent, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.mask = self.get_correlated_samples_mask()
        self.device = device

        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def forward(self, z_i, z_j):
        """
        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
        """

        p1 = torch.cat((z_i, z_j), dim=0)
        sim = self.similarity_f(p1.unsqueeze(1), p1.unsqueeze(0)) / self.temperature

        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(self.batch_size * 2, 1)
        negative_samples = sim[self.mask].reshape(self.batch_size * 2, -1)

        labels = torch.zeros(self.batch_size * 2).to(self.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= 2 * self.batch_size
        return loss

    def get_correlated_samples_mask(self):
        mask = torch.ones((self.batch_size * 2, self.batch_size * 2), dtype=bool)
        mask = mask.fill_diagonal_(0)
        for i in range(self.batch_size):
            mask[i, self.batch_size + i] = 0
            mask[self.batch_size + i, i] = 0
        return mask


if __name__ == "__main__":
    a, b = torch.rand(8, 12), torch.rand(8, 12)
    a_norm, b_norm = torch.nn.functional.normalize(a), torch.nn.functional.normalize(b)
    cosine_sim = torch.nn.CosineSimilarity()
    ntxent_loss = NT_Xent(8, 0.5, "cpu")
    assert torch.allclose(cosine_sim(a, b), cosine_sim(a_norm, b_norm))
    assert torch.allclose(ntxent_loss(a, b), ntxent_loss(a_norm, b_norm))

About the Cifar-10 accuracy

Hi Spijkervet,
Thanks for your great implementation. In your project, I find that the accuracy on CIFAR-10 is significantly lower than the result reported in the original paper (55% vs. 90%+). Is it because you use a projection_dim of 64 instead of 128, or because of the small batch size?

Mistakes in linear_evaluation.py

I discovered some mistakes in linear_evaluation.py:
In line 198, there should be "len(arr_train_loader)" rather than "len(train_loader)".
In line 206, there should be "len(arr_test_loader)" rather than "len(test_loader)".

Why L2 norm then CosSim?

Hi, this code is really useful for me. Thanks!
But I have a question about the NT-Xent loss. I noticed that you apply an L2 norm to z and then use cos_similarity after that. But cos_similarity already includes the L2 normalization. Why apply the L2 norm first?
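The observation in the question can be verified directly: cosine similarity divides by both vector norms, so normalizing the inputs beforehand does not change the result. A minimal pure-Python check with arbitrary example vectors (helper names are made up for the illustration):

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_similarity(a, b):
    """Dot product divided by the product of the two L2 norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


a, b = [1.0, 2.0, 3.0], [-2.0, 0.5, 4.0]
raw = cosine_similarity(a, b)
pre_normalized = cosine_similarity(l2_normalize(a), l2_normalize(b))
```

The two values agree up to floating-point rounding, so the extra normalization step is redundant for the similarity itself.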

CIFAR10 accuracy

Hi @Spijkervet,

Have you tried more epochs for your experiment settings, like 500 epochs?
The results in the paper show that even a batch size of 256 can achieve about 93% accuracy with 500 epochs of training.
However, my implementation only gets ~89.5% accuracy, even with a batch size of 512.

Best,
Chen

Mismatching resnet50 model for CIFAR-10 experiments

Hi there, thank you for this amazing implementation :)

I've been trying to roll out my own implementation too for the CIFAR-10 experiments, and I noticed that you use the resnet50 as implemented in torchvision and this differs from the one used in the paper in three ways:

  1. The first convolution must be with a kernel 3x3 and a stride of 1.
  2. No max pooling at the beginning.
  3. torchvision's version of resnet is v1.5 which does the downsampling at the 3x3 convolution in the bottleneck blocks, as opposed to v1 which does it at the first 1x1 convolution.

While the 3rd difference might not be that significant, I think the first two could be significant given the tiny size of the CIFAR images. The details are in the Setup section at the beginning of page 17 in the paper.

It also seems that you need to set the color jitter strength to 0.5 when training on CIFAR.

I made a short function that hacks torchvision's resnet50 to match the one used in the paper and I can submit a PR if you like :)

Loss function is giving different results from tensorflow implementation

I noticed that your loss function behaves differently from the original TensorFlow implementation. It also seems like the PyTorch implementations may have copied from each other, because they all give the same result.

Here is a reimplementation of the TensorFlow version if you want to try it:

import torch
import torch.nn.functional as F


def nt_cross_entropy(x, y, temperature=0.5):
    '''
    normalized temperature-scaled cross entropy loss
    
    tensorflow: https://github.com/google-research/simclr/blob/master/objective.py
    '''
    normalized_x = F.normalize(x, dim=1)
    normalized_y = F.normalize(y, dim=1)

    auto_x = normalized_x @ normalized_x.t() / temperature
    auto_y = normalized_y @ normalized_y.t() / temperature

    corr_xy = normalized_x @ normalized_y.t() / temperature
    corr_yx = normalized_y @ normalized_x.t() / temperature

    targets = torch.arange(auto_x.shape[0]).to(auto_x.device)
    big_diagonal = torch.eye(auto_x.shape[0], device=auto_x.device) * 1e9
    return (
        F.cross_entropy(
            torch.cat([corr_xy, auto_x - big_diagonal], dim=1),
            targets,
        )
        + F.cross_entropy(
            torch.cat([corr_yx, auto_y - big_diagonal], dim=1),
            targets,
        )
    )


def test_nt_cross_entropy():
    x = torch.tensor([[1.0, 2.0], [3.0, -2.0], [1.0, 5.0]])
    y = torch.tensor([[1.0, 0.75], [2.8, -1.75], [1.0, 4.7]])
    assert abs(nt_cross_entropy(x, y, temperature=1.0).item() - 2.2654006) <= 1e-5

    z = torch.tensor([[1.0, 1.75], [2.8, -1.75], [1.0, 4.7]])
    assert abs(nt_cross_entropy(x, z, temperature=1.0).item() - 2.1991892) <= 1e-5

The loss doesn't decrease when using multi nodes.

When I use one node, the code runs well. However, when I use 2 nodes and set the batch_size to 64, the loss stays around 5.545 and doesn't decrease. As 5.545 is approximately ln(256), the loss of a uniform guess over 256 candidates, it seems that the network never learns anything during training. I have checked that the parameters are not fixed. I think there may be something wrong with the GatherLayer, but I cannot find it. Have you encountered this problem?
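A quick check of the chance-level loss, as a sketch: a cross-entropy loss that stays at ln(K) means the model spreads its probability uniformly over K candidates. The calculation below assumes that batch_size 64 is the per-node value and that views are gathered across both nodes; `chance_level_loss` is a made-up helper name.

```python
import math


def chance_level_loss(num_candidates):
    """Cross-entropy of a uniform guess over `num_candidates` logits."""
    return math.log(num_candidates)


nodes, batch_size = 2, 64            # values from the report above
views = 2 * nodes * batch_size       # two augmented views per image, all nodes

loss_all_views = chance_level_loss(views)          # uniform over all 256 views
loss_without_self = chance_level_loss(views - 1)   # anchor itself excluded
```

Both values are close to the reported plateau of 5.545, which is consistent with the model output carrying no information about which candidate is the positive.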

torch version

Hi there, thank you for sharing your implementation.

I wonder whether your torch version is 1.3-1.5 or an older version such as 1.1?

Loss decreases slowly on the LFW dataset. Is it normal?

I want to use SSL in my project; LFW is only a demo. The model structure, loss function and optimizer are the same as in this repo. I use a single GPU, load a pretrained ResNet50, and use a learning rate of 0.001 and a batch size of 64.

My dataloader:

from collections import defaultdict
import torchvision.transforms as transforms
import random
from PIL import ImageFilter, Image
from torch.utils.data import Dataset


class GaussianBlur(object):
    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x


class FaceSet(Dataset):
    def __init__(self, label_file, input_shape):
        self.label_file = label_file
        self.imgs = []
        self.size = 0
        self.labels = []
        self.img_2_label = {}
        self.label_2_img = defaultdict(list)
        self.create_data()
        s = 1
        color_jitter = transforms.ColorJitter(
            0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
        )
        self.trainform = transforms.Compose([
            lambda x: Image.open(x).convert("RGB"),  # open image
            transforms.CenterCrop(input_shape[1:]),
            transforms.RandomApply(
                [
                   color_jitter
                ],
                p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),  # with 0.5 probability
            transforms.ToTensor(),  # Converts a PIL Image to [0, 1]
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def create_data(self):
        with open(self.label_file, "r") as f:
            data = f.readlines()
            for line in data:
                line = line.strip()
                line = line.split(" ")
                idx = int(line[0])
                self.labels.append(idx)
                for imgs in line[1:]:
                    self.imgs.append(imgs)
                    self.img_2_label[imgs] = idx
                    self.label_2_img[idx].append(imgs)

        self.size = len(self.imgs)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        pos_file = self.imgs[idx]
        pos_label = self.img_2_label[pos_file]
        pos1 = self.trainform(pos_file)
        pos2 = self.trainform(pos_file)

        # neg_label = pos_label
        # while neg_label == pos_label:
        #     neg_label = random.choice(self.labels)
        # neg_file = random.choice(self.label_2_img[neg_label])
        # neg = self.trainform(neg_file)

        return pos1, pos2

The log of the training process:

 ===>>> Ready to record the trainning process. <<<=== 
 ===>>> All data is 13233, each epoch iters 206.765625 <<<=== 
 ===>>> Load Data <<<=== 
 ===>>> Load Model <<<=== 
 ===>>> Epoch 1 <<<=== 
Avg Loss is 3.577977
Avg Loss is 3.379171
Avg Loss is 3.286532
Avg Loss is 3.239770
Save best model
 ===>>> Epoch 2 <<<=== 
Avg Loss is 3.096657
Avg Loss is 3.103381
Avg Loss is 3.096943
Avg Loss is 3.089485
Save best model
 ===>>> Epoch 3 <<<=== 
Avg Loss is 3.054055
Avg Loss is 3.054430
Avg Loss is 3.058049
Avg Loss is 3.055276
Save best model
 ===>>> Epoch 4 <<<=== 
Avg Loss is 3.041535
Avg Loss is 3.047728
Avg Loss is 3.049511
Avg Loss is 3.048278
Save best model
 ===>>> Epoch 5 <<<=== 
Avg Loss is 3.038077
Avg Loss is 3.037763
Avg Loss is 3.037274
Avg Loss is 3.036270
Save best model
 ===>>> Epoch 6 <<<=== 
Avg Loss is 3.040787
Avg Loss is 3.038092
Avg Loss is 3.033206
Avg Loss is 3.033263
Save best model
 ===>>> Epoch 7 <<<=== 
Avg Loss is 3.029101
Avg Loss is 3.025529
Avg Loss is 3.026146
Avg Loss is 3.025752
Save best model
 ===>>> Epoch 8 <<<=== 
Avg Loss is 3.021876
Avg Loss is 3.023627
Avg Loss is 3.022280
Avg Loss is 3.022333
Save best model
 ===>>> Epoch 9 <<<=== 
Avg Loss is 3.016602
Avg Loss is 3.014577
Avg Loss is 3.013268
Avg Loss is 3.013314
Save best model
 ===>>> Epoch 10 <<<=== 
Avg Loss is 3.009610
Avg Loss is 3.006718
Avg Loss is 3.007279
Avg Loss is 3.005000
Save best model
 ===>>> Epoch 11 <<<=== 
Avg Loss is 3.006585
Avg Loss is 3.003934
Avg Loss is 3.002049
Avg Loss is 2.999707
Save best model
 ===>>> Epoch 12 <<<=== 
Avg Loss is 2.996356
Avg Loss is 2.996086
Avg Loss is 2.996202
Avg Loss is 2.995989
Save best model
 ===>>> Epoch 13 <<<=== 
Avg Loss is 2.986980
Avg Loss is 2.987497
Avg Loss is 2.988292
Avg Loss is 2.988140
Save best model
 ===>>> Epoch 14 <<<=== 
Avg Loss is 2.984407
Avg Loss is 2.982733
Avg Loss is 2.981985
Avg Loss is 2.983570
Save best model
 ===>>> Epoch 15 <<<=== 
Avg Loss is 2.983063
Avg Loss is 2.981772
Avg Loss is 2.982046
Avg Loss is 2.981637
Save best model
 ===>>> Epoch 16 <<<=== 
Avg Loss is 2.973182
Avg Loss is 2.974823
Avg Loss is 2.974841
Avg Loss is 2.974745
Save best model
 ===>>> Epoch 17 <<<=== 
Avg Loss is 2.972452
Avg Loss is 2.971775
Avg Loss is 2.972774
Avg Loss is 2.973590
Save best model
 ===>>> Epoch 18 <<<=== 
Avg Loss is 2.968253
Avg Loss is 2.969223
Avg Loss is 2.969758
Avg Loss is 2.969625
Save best model
 ===>>> Epoch 19 <<<=== 
Avg Loss is 2.965280
Avg Loss is 2.963255
Avg Loss is 2.962672
Avg Loss is 2.962905
Save best model
 ===>>> Epoch 20 <<<=== 
Avg Loss is 2.962090
Avg Loss is 2.963186
Avg Loss is 2.963687
Avg Loss is 2.964000
 ===>>> Epoch 21 <<<=== 
Avg Loss is 2.960579
Avg Loss is 2.962159
Avg Loss is 2.961549
Avg Loss is 2.960474
Save best model
 ===>>> Epoch 22 <<<=== 
Avg Loss is 2.956076
Avg Loss is 2.954066
Avg Loss is 2.952812
Avg Loss is 2.953630
Save best model
 ===>>> Epoch 23 <<<=== 
Avg Loss is 2.957394
Avg Loss is 2.957740
Avg Loss is 2.955869
Avg Loss is 2.954493
 ===>>> Epoch 24 <<<=== 
Avg Loss is 2.950666
Avg Loss is 2.949612
Avg Loss is 2.949405
Avg Loss is 2.949652
Save best model
 ===>>> Epoch 25 <<<=== 
Avg Loss is 2.948632
Avg Loss is 2.948776
Avg Loss is 2.948908
Avg Loss is 2.949181
Save best model
 ===>>> Epoch 26 <<<=== 
Avg Loss is 2.944604
Avg Loss is 2.946091
Avg Loss is 2.947033
Avg Loss is 2.947451
Save best model
 ===>>> Epoch 27 <<<=== 
Avg Loss is 2.943191
Avg Loss is 2.944458
Avg Loss is 2.944051
Avg Loss is 2.944110
Save best model
 ===>>> Epoch 28 <<<=== 
Avg Loss is 2.942361
Avg Loss is 2.940979
Avg Loss is 2.941820
Avg Loss is 2.942666
Save best model
 ===>>> Epoch 29 <<<=== 
Avg Loss is 2.939801
Avg Loss is 2.939919
Avg Loss is 2.940369
Avg Loss is 2.940156
Save best model
 ===>>> Epoch 30 <<<=== 
Avg Loss is 2.935622
Avg Loss is 2.935057
Avg Loss is 2.935227
Avg Loss is 2.935683
Save best model
 ===>>> Epoch 31 <<<=== 
Avg Loss is 2.937815
Avg Loss is 2.935643
Avg Loss is 2.936246
Avg Loss is 2.936065
 ===>>> Epoch 32 <<<=== 
Avg Loss is 2.934247
Avg Loss is 2.934004
Avg Loss is 2.935194
Avg Loss is 2.934972
Save best model
 ===>>> Epoch 33 <<<=== 
Avg Loss is 2.931649
Avg Loss is 2.930846
Avg Loss is 2.930627
Avg Loss is 2.930930
Save best model
 ===>>> Epoch 34 <<<=== 
Avg Loss is 2.929090
Avg Loss is 2.929404
Avg Loss is 2.928846
Avg Loss is 2.929138
Save best model
 ===>>> Epoch 35 <<<=== 
Avg Loss is 2.927705
Avg Loss is 2.927764
Avg Loss is 2.927287
Avg Loss is 2.927516
Save best model
 ===>>> Epoch 36 <<<=== 
Avg Loss is 2.927743
Avg Loss is 2.926714
Avg Loss is 2.926628
Avg Loss is 2.926617
Save best model
 ===>>> Epoch 37 <<<=== 
Avg Loss is 2.924690
Avg Loss is 2.924439
Avg Loss is 2.924409
Avg Loss is 2.924603
Save best model
 ===>>> Epoch 38 <<<=== 
Avg Loss is 2.922124
Avg Loss is 2.922200
Avg Loss is 2.922583
Avg Loss is 2.922992
Save best model
 ===>>> Epoch 39 <<<=== 
Avg Loss is 2.922472
Avg Loss is 2.921901
Avg Loss is 2.921420
Avg Loss is 2.921485
Save best model
 ===>>> Epoch 40 <<<=== 
Avg Loss is 2.921616
Avg Loss is 2.920895
Avg Loss is 2.920781
Avg Loss is 2.920465
Save best model
 ===>>> Epoch 41 <<<=== 
Avg Loss is 2.918778
Avg Loss is 2.918589
Avg Loss is 2.918421
Avg Loss is 2.918423
Save best model
 ===>>> Epoch 42 <<<=== 
Avg Loss is 2.919555
Avg Loss is 2.919151
Avg Loss is 2.918455
Avg Loss is 2.918853
 ===>>> Epoch 43 <<<=== 
Avg Loss is 2.917781
Avg Loss is 2.916509
Avg Loss is 2.916942
Avg Loss is 2.916903
Save best model
 ===>>> Epoch 44 <<<=== 
Avg Loss is 2.917233
Avg Loss is 2.917524
Avg Loss is 2.916481
Avg Loss is 2.916096
Save best model
 ===>>> Epoch 45 <<<=== 
Avg Loss is 2.915788
Avg Loss is 2.915589
Avg Loss is 2.915973
Avg Loss is 2.915928
Save best model
 ===>>> Epoch 46 <<<=== 
Avg Loss is 2.916222
Avg Loss is 2.916418
Avg Loss is 2.916036
Avg Loss is 2.915801
Save best model
 ===>>> Epoch 47 <<<=== 
Avg Loss is 2.914999
Avg Loss is 2.913783
Avg Loss is 2.913962
Avg Loss is 2.914269
Save best model
 ===>>> Epoch 48 <<<=== 
Avg Loss is 2.914487
Avg Loss is 2.914710
Avg Loss is 2.914641
Avg Loss is 2.914529
 ===>>> Epoch 49 <<<=== 
Avg Loss is 2.915182
Avg Loss is 2.914820
Avg Loss is 2.914557
Avg Loss is 2.914477

Code explanation in gather.py

Hi, Janne

The GatherLayer module in gather.py is smart and efficient. I really appreciate this module.

I understand most of the code but am still confused about one point.

line 19 of SimCLR/simclr/modules/gather.py

    grad_out[:] = grads[dist.get_rank()]

Would you mind explaining it in more detail?

Logistic Regression Issue without "eval()"

At https://github.com/Spijkervet/SimCLR/blob/master/testing/logistic_regression.py,

simclr_model needs to be set to "eval" mode.

At line 168:
# load pre-trained model from checkpoint
simclr_model = SimCLR(args, encoder, n_features)
simclr_model.eval() # need to add this line
model_fp = os.path.join(
args.model_path, "checkpoint_{}.tar".format(args.epoch_num)
)
simclr_model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
simclr_model = simclr_model.to(args.device)

Without "eval" mode, the SimCLR network's state was unintentionally updated while we extracted features (especially for the test set). This significantly reduced the test accuracy in my test cases.
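One mechanism that can cause this, shown as a sketch on a single layer rather than the SimCLR model: in train mode, a batch-norm layer updates its running statistics on every forward pass, even under torch.no_grad() and without any optimizer step. The layer size and input values are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(16, 4) + 3.0   # inputs whose mean is far from the initial 0

# Train mode: the forward pass alone moves running_mean towards the batch mean.
before = bn.running_mean.clone()
bn.train()
with torch.no_grad():
    bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)

# Eval mode: the stored statistics are used and left untouched.
frozen = bn.running_mean.clone()
bn.eval()
with torch.no_grad():
    bn(x)
changed_in_eval = not torch.equal(frozen, bn.running_mean)
```

Here changed_in_train ends up True and changed_in_eval False, which is why calling .eval() before feature extraction keeps the loaded checkpoint intact.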

We need to add ".eval()". Please check it out and confirm.

Thank you.

Some questions about the projection head.

Hi, thank you very much for your code. I have a question about the projection head. I think there should be two BN layers, i.e. linear layer, BN, ReLU, linear layer, BN. I don't know if this should be the case, or how much impact the missing BN layers have on the performance of the model. I look forward to your reply. Thanks again.
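The layer order proposed in the question could be written as follows. This is a sketch of that proposal only, not the repository's projection head; the hidden width equal to n_features, the bias=False choice and the example sizes (512 and 64) are assumptions.

```python
import torch
import torch.nn as nn


def projection_head(n_features, projection_dim):
    """Linear -> BN -> ReLU -> Linear -> BN, as proposed in the question above."""
    return nn.Sequential(
        nn.Linear(n_features, n_features, bias=False),
        nn.BatchNorm1d(n_features),
        nn.ReLU(inplace=True),
        nn.Linear(n_features, projection_dim, bias=False),
        nn.BatchNorm1d(projection_dim),
    )


head = projection_head(n_features=512, projection_dim=64)
z = head(torch.randn(8, 512))   # batch of 8 encoder outputs
```

Whether the two BN layers change downstream accuracy is an empirical question that this sketch does not answer.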

Hello, If I want to use vgg16, how can I modify it?

I get vgg16 like this:

import torchvision

def get_vgg(name, pretrained=False):
    vggs = {
        "vgg16": torchvision.models.vgg16(pretrained=pretrained),
    }
    if name not in vggs.keys():
        raise KeyError(f"{name} is not a valid VGG version")
    return vggs[name]

and modify this:

encoder = get_vgg(args.vgg, pretrained=False)
n_features = encoder.classifier._modules['6'].in_features

But it doesn't work. Could you help me?

No upscale in image augmentation?

The SimCLR paper says:

In this work, we sequentially apply three simple augmentations: random
cropping followed by resize back to the original size, random color distortions, and random Gaussian blur

but it seems like the augmentations used in this repository first do a random crop, but do not afterwards resize the crop back to the original size. Why the difference? Am I misunderstanding the SimCLR paper?

Error when loading the datasets

Hi, thank you for the clean code.

I'm having issues, trying to run this cell:

root = "./datasets"

train_sampler = None

if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        root, split="unlabeled", download=True, transform=TransformsSimCLR()
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        root, download=True, transform=TransformsSimCLR()
    )
else:
    raise NotImplementedError

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=(train_sampler is None),
    drop_last=True,
    num_workers=args.workers,
    sampler=train_sampler,
)

This is the error message I'm getting:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-13-bc54d6807dbe> in <module>
      9 elif args.dataset == "CIFAR10":
     10     train_dataset = torchvision.datasets.CIFAR10(
---> 11         root, download=True, transform=TransformsSimCLR()
     12     )
     13 else:

TypeError: __init__() missing 1 required positional argument: 'size'
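The traceback suggests that TransformsSimCLR takes a required size argument. The stub below is a hypothetical stand-in, not the repository's class; it only reproduces the same kind of TypeError and shows the form of call that avoids it. The value 224 is an assumed image size.

```python
class TransformsSimCLRStub:
    """Hypothetical stand-in with the same required argument as in the traceback."""

    def __init__(self, size):
        self.size = size


try:
    TransformsSimCLRStub()            # mirrors the failing call: no size given
    raised = False
except TypeError:
    raised = True

transform = TransformsSimCLRStub(size=224)   # 224 is an assumed image size
```

In the cell above, the equivalent change would be to pass a size when constructing the transform for each dataset.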

Question about multi-gpu training

Thanks for sharing such excellent work.

Now I want to use eight GPUs for DDP training on a server. How should I use the script? I am a little confused about the usage of nr.

Looking forward to your reply!
thanks!

train on Imagenet

Hello, have you trained on ImageNet with a batch size of 256?

I am doing this, and I found that the training loss decreases very slowly and steadily, like this:
[image]

normalize is not used in data augmentation

From the code, data augmentation used here is

torchvision.transforms.RandomResizedCrop(size=size),
torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
torchvision.transforms.RandomApply([color_jitter], p=0.8),
torchvision.transforms.RandomGrayscale(p=0.2),
torchvision.transforms.ToTensor(),

I think the data needs to be normalized here
