
proxy-nca's Introduction

About

This repository contains a PyTorch implementation of No Fuss Distance Metric Learning using Proxies as introduced by Google Research.

The training and evaluation setup is exactly the same as described in the paper, except that Adam is used as the optimizer instead of RMSprop.

I have ported the PyTorch BN-Inception model from PyTorch 0.2 to PyTorch >= 0.4. Its weights are stored inside the repository, in the directory net.

You need Python 3, PyTorch >= 1.1 and torchvision >= 0.3.0 to run the code. I used CUDA version 10.0.130.

Note that the negative log of a softmax is used as the ProxyNCA loss. Therefore, unlike in the loss as written in the paper, the distance between the anchor and its positive proxy is not excluded from the denominator. In practice, I have not noticed a difference.
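To make the difference concrete, here is a minimal NumPy sketch (not the repository's PyTorch code; the helper names are hypothetical). It computes both variants from a matrix D of squared embedding-to-proxy distances: the log-softmax form, where the positive proxy stays in the denominator, and the form where the positive proxy is removed.

```python
import numpy as np

def log_softmax(z, axis=-1):
    # Numerically stable log-softmax: subtract the row-wise max before exponentiating.
    z_max = np.max(z, axis=axis, keepdims=True)
    return z - z_max - np.log(np.sum(np.exp(z - z_max), axis=axis, keepdims=True))

def proxynca_included(D, labels):
    # D: (batch, nb_classes) squared distances to the proxies.
    # Softmax over negative distances; the positive proxy stays in the denominator.
    log_p = log_softmax(-D, axis=1)
    return -log_p[np.arange(len(labels)), labels]

def proxynca_excluded(D, labels):
    # Paper-style variant: the positive proxy is removed from the denominator.
    idx = np.arange(len(labels))
    d_pos = D[idx, labels]
    mask = np.ones_like(D, dtype=bool)
    mask[idx, labels] = False
    neg = np.where(mask, -D, -np.inf)  # exp(-inf) = 0 drops the positive term
    neg_max = np.max(neg, axis=1)
    log_sum_neg = neg_max + np.log(np.sum(np.exp(neg - neg_max[:, None]), axis=1))
    return d_pos + log_sum_neg

D = np.array([[0.5, 4.0, 9.0],
              [2.0, 1.0, 3.0]])  # made-up squared distances: 2 samples x 3 proxies
labels = np.array([0, 1])
print(proxynca_included(D, labels), proxynca_excluded(D, labels))
```

For every sample, the two values satisfy included = log(1 + exp(excluded)), so the log-softmax variant is a softplus of the original loss. It is always positive, while the original can become negative once the positive proxy is much closer than all negatives.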

The importance of scaling the normalized proxies and embeddings is mentioned in the ProxyNCA paper (in the theoretical background), but the exact scaling factors are omitted. I have found that (3, 3) works well for CUB and Cars and (8, 1) works well for SOP, where the first value is the proxy scaling and the second is the embedding scaling.
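Here is a minimal NumPy sketch of how the two factors enter the distance computation. It assumes squared Euclidean distances between embeddings and proxies that are first L2-normalized and then scaled. The function names and sizes are hypothetical.

```python
import numpy as np

def l2_normalize(v, eps=1e-12):
    # Divide each row by its L2 norm; eps guards against division by zero.
    return v / np.maximum(np.linalg.norm(v, axis=-1, keepdims=True), eps)

def scaled_sq_distances(X, P, scaling_x, scaling_p):
    # Scale the unit-length embeddings and proxies, then compute the squared
    # Euclidean distance between every embedding and every proxy.
    Xs = scaling_x * l2_normalize(X)
    Ps = scaling_p * l2_normalize(P)
    diff = Xs[:, None, :] - Ps[None, :, :]
    return np.sum(diff ** 2, axis=-1)

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 64))   # hypothetical batch: 4 embeddings of size 64
P = rng.normal(size=(10, 64))  # hypothetical proxies for 10 classes

for s_p, s_x in [(3.0, 3.0), (8.0, 1.0)]:
    D = scaled_sq_distances(X, P, scaling_x=s_x, scaling_p=s_p)
    # Every entry lies in [(s_p - s_x)^2, (s_p + s_x)^2].
    print(f"proxies x{s_p}, embeddings x{s_x}: D in [{D.min():.2f}, {D.max():.2f}]")
```

Because every scaled vector has a fixed norm, D = s_x^2 + s_p^2 - 2 s_x s_p cos(x, p). The constant s_x^2 + s_p^2 cancels inside a softmax over the proxies. Under this formulation, only the product s_x s_p (9 for (3, 3), 8 for (8, 1)) sets how sharp the softmax over -D is.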

Reproducing Results

You can adjust most training settings (learning rate, optimizer, criterion, dataset, ...) in the config file.

You'll only have to adjust the root paths for the datasets. Then you're ready to go.
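If you prefer to set the paths from a script, here is a hypothetical helper. It assumes a config layout of the form {"dataset": {name: {"root": path, ...}}}, which is an assumption: check the actual config.json and adapt the keys if they differ.

```python
import json

def update_dataset_roots(config, roots):
    # Assumes the hypothetical layout {"dataset": {name: {"root": path, ...}}}.
    # Only the "root" entries are overwritten; all other settings are kept.
    for name, root in roots.items():
        config["dataset"][name]["root"] = root
    return config

def rewrite_config(path, roots):
    # Load the JSON file, update the root paths and write it back in place.
    with open(path) as f:
        config = json.load(f)
    update_dataset_roots(config, roots)
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
```

Example call (paths illustrative): `rewrite_config("config.json", {"cub": "/data/CUB_200_2011", "cars": "/data/cars196", "sop": "/data/Stanford_Online_Products"})`.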

Downloading and Extracting the Datasets

mkdir cars196
cd cars196
wget http://imagenet.stanford.edu/internal/car196/cars_annos.mat
wget http://imagenet.stanford.edu/internal/car196/car_ims.tgz
tar -xzvf car_ims.tgz
pwd # use this path as the root path in the config file
cd ..
wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz
tar -xzvf CUB_200_2011.tgz
cd CUB_200_2011
pwd # use this path as the root path in the config file
cd ..
wget ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip
unzip Stanford_Online_Products.zip
cd Stanford_Online_Products
pwd # use this path as root path for config file
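After extraction, a quick check that each root path contains what the archive should have produced can save a failed training run. This is a hypothetical helper. The expected file names reflect the usual contents of the three archives and are assumptions, not something train.py checks.

```python
from pathlib import Path

# Hypothetical checklist of entries expected directly under each root path.
EXPECTED = {
    "cars": ["cars_annos.mat", "car_ims"],
    "cub": ["images", "images.txt", "classes.txt"],
    "sop": ["Ebay_train.txt", "Ebay_test.txt"],
}

def missing_entries(root, dataset):
    # Return the expected files/directories that do not exist under root.
    root = Path(root)
    return [name for name in EXPECTED[dataset] if not (root / name).exists()]
```

An empty list, for example from `missing_entries("/data/CUB_200_2011", "cub")`, means the layout looks as expected.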

Commands

DATA=cub; SCALING_X=3.0; SCALING_P=3.0; LR=1; python3 train.py --data $DATA \
--log-filename $DATA-scaling_x_$SCALING_X-scaling_p_$SCALING_P-lr_$LR \
--config config.json --epochs=20 --gpu-id 0 --lr-proxynca=$LR \
--scaling-x=$SCALING_X --scaling-p=$SCALING_P --with-nmi
DATA=cars; SCALING_X=3.0; SCALING_P=3.0; LR=1; python3 train.py --data $DATA \
--log-filename $DATA-scaling_x_$SCALING_X-scaling_p_$SCALING_P-lr_$LR \
--config config.json --epochs=50 --gpu-id 1 --lr-proxynca=$LR \
--scaling-x=$SCALING_X --scaling-p=$SCALING_P --with-nmi
DATA=sop; SCALING_X=1; SCALING_P=8; LR=10; python3 train.py --data $DATA \
--log-filename $DATA-scaling_x_$SCALING_X-scaling_p_$SCALING_P-lr_$LR \
--config config.json --epochs=50 --gpu-id 3 --lr-proxynca=$LR \
--scaling-x=$SCALING_X --scaling-p=$SCALING_P
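The commands pass all settings as command-line flags. To illustrate how such flags map onto Python attributes, here is a hypothetical argparse parser that accepts the same flags. Its defaults and types are assumptions and are not taken from train.py.

```python
import argparse

def build_parser():
    # Hypothetical parser for the flags used in the commands above.
    p = argparse.ArgumentParser(description="ProxyNCA training (sketch)")
    p.add_argument("--data", required=True)                 # cub, cars or sop
    p.add_argument("--log-filename", default="example")
    p.add_argument("--config", default="config.json")
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--gpu-id", type=int, default=0)
    p.add_argument("--lr-proxynca", type=float, default=1.0)
    p.add_argument("--scaling-x", type=float, default=3.0)  # embedding scaling
    p.add_argument("--scaling-p", type=float, default=3.0)  # proxy scaling
    p.add_argument("--with-nmi", action="store_true")       # also compute NMI
    return p
```

argparse turns dashes into underscores, so `--lr-proxynca` becomes `args.lr_proxynca`. Both `--epochs 50` and `--epochs=50` are accepted.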

Results

The results were obtained mostly on a single Titan X or a weaker GPU.

How to read the table: this implementation [Google's implementation].

              CUB             Cars            SOP
Duration      00:19h          00:24h          01:55h
Epoch         17              15              16
R@1           52.63 [49.21]   72.19 [73.22]   74.07 [73.73]
R@2           64.63 [61.90]   81.31 [82.42]   79.13 [-]
R@4           75.76 [67.90]   87.54 [86.36]   83.30 [-]
R@8           84.52 [72.40]   92.54 [88.68]   86.66 [-]
NMI           60.64 [59.53]   62.45 [64.90]   -
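R@K is Recall@K: the percentage of test queries for which at least one of the K nearest neighbors (excluding the query itself) has the same class label. Here is a minimal NumPy sketch of that definition. It assumes Euclidean distance on the embeddings; the repository's evaluation code may normalize embeddings first or differ in other details.

```python
import numpy as np

def recall_at_k(embeddings, labels, ks=(1, 2, 4, 8)):
    # A query is a hit at K if any of its K nearest neighbors shares its label.
    X = np.asarray(embeddings, dtype=float)
    y = np.asarray(labels)
    sq = np.sum(X ** 2, axis=1)
    D = sq[:, None] + sq[None, :] - 2.0 * X @ X.T  # pairwise squared distances
    np.fill_diagonal(D, np.inf)                    # never retrieve the query itself
    order = np.argsort(D, axis=1)[:, :-1]          # drop the last column (the query)
    hits = y[order] == y[:, None]                  # (n, n - 1) label matches
    return {k: 100.0 * np.mean(np.any(hits[:, :k], axis=1)) for k in ks}
```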

Referencing this Implementation

If you'd like to reference this ProxyNCA implementation, you can use this bibtex:

@misc{Tschernezki2020,
  author = {Tschernezki, Vadim and Sanakoyeu, Artsiom and Ommer, Bj{\"o}rn},
  title = {PyTorch Implementation of ProxyNCA},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/dichotomies/proxy-nca}},
}

proxy-nca's People

Contributors

dichotomies


proxy-nca's Issues

Dealing with Multi-Label Classifier Models

What if we had target values consisting of multiple labels?

Then this would have to change:
Z_i = torch.index_select( proxies_normed, 0, Variable( torch.LongTensor( np.reshape( np.setdiff1d( np.arange(0, nb_classes), ts.data.cpu().numpy()[i] ), (nb_classes - 1) ) ) ).cuda() )

to something more dynamic instead of forcing (nb_classes - 1), although the same concept applies: taking all "non-applicable" labels.

But then how would I deal with this:
n_dist = torch.sum(torch.exp(-torch.sum(torch.pow(Z_i - xs[i], 2), dim=1)))

Any help is appreciated! Still haven't fully understood all the code but I'm getting there!
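One possible direction is sketched below in NumPy. The function name is hypothetical, and the sketch assumes multi-hot targets of shape (batch, nb_classes). It replaces the per-sample index_select with a boolean mask, so the negative term sums over all proxies whose labels do not apply, however many labels apply to each example.

```python
import numpy as np

def negative_proxy_term(xs, proxies, targets):
    # xs: (batch, dim) embeddings; proxies: (nb_classes, dim) normalized proxies.
    # targets: (batch, nb_classes) multi-hot matrix, 1 where a label applies.
    # With exactly one 1 per row this matches the Z_i selection in the question.
    D = np.sum((xs[:, None, :] - proxies[None, :, :]) ** 2, axis=-1)
    negatives = targets == 0
    # Sum exp(-distance) only over proxies whose label does not apply.
    return np.sum(np.where(negatives, np.exp(-D), 0.0), axis=1)
```

How to combine several positive proxies in the numerator is a separate design choice that this sketch leaves open.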

Early overfit?

Hi, I have a similar implementation in MXNet, but after the first epoch both the test and train accuracy go down (though the loss does as well) instead of going up. I tried your code with my dataset and the same thing happens. Do proxies overfit after one epoch?

What about dynamic assignment

Thanks for your project, great work!
I am new to this and would like to know how you implement the dynamic assignment when the label is not available.
Dynamic assignment needs an argmin function, which cannot be backpropagated through. How do you address this?

Class Labels

Does "T" represent class labels with a size of (batch_size)?
If yes, can I say that the input and output of ProxyNCA and softmax are the same?
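Suppose T is a one-hot (batch_size, nb_classes) matrix built from integer labels of shape (batch_size,). Then sum(-T * log_softmax(-D)) is ordinary softmax cross-entropy whose logits are the negative distances to the proxies. Below is a NumPy check of this equivalence. The helper names are hypothetical and D is a made-up distance matrix.

```python
import numpy as np

def log_softmax(z):
    # Stable log-softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def loss_onehot(D, T):
    # Loss written with a one-hot target matrix T of shape (batch, nb_classes).
    return np.sum(-T * log_softmax(-D), axis=-1)

def loss_integer_labels(D, labels):
    # Cross-entropy with integer labels of shape (batch,) and logits -D.
    return -log_softmax(-D)[np.arange(len(labels)), labels]

D = np.array([[0.3, 2.0, 4.0, 1.5],
              [3.0, 0.4, 2.5, 5.0]])   # squared distances to 4 proxies
labels = np.array([0, 2])
T = np.eye(D.shape[1])[labels]          # one-hot version of the labels
```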

Magic numbers

Apart from trial and error, is there any particular reason for using a multiplicative factor of 3 in the loss function? Is this also the reason why gradients are clipped above the value 10?

Other datasets and a remark

Hi there,

You posted a comment on bnu-wangxun/Deep_Metric#5, and I checked out your work here. Nice job!

During my research around this loss, I found that CUB200 is the least challenging dataset to fit, and I was able to reproduce the results on it quite well. However, I really struggled with the other datasets mentioned in the paper. Have you tried the CARS196 dataset? It has roughly the same number of training examples, but achieving the same Recall@K metrics was more challenging.

One final remark: you mention using an InceptionV2 model, but the authors stated that they used an Inception-BN model, which, to my knowledge, is not the same thing. However, according to the InceptionV2 paper, the 7x7 convolution was factorized into smaller ones. The presence of such a convolution in your code reveals that it is Inception-BN rather than InceptionV2, which is correct (though I didn't check the whole network code thoroughly). It seems you just mislabeled it.

can not reproduce the performance

Hi, @dichotomies
I just tried "python train.py" on cub200, but the performance was not as good as yours.
So, I'd like to know if I have missed some details.
Thanks!

Here is the log
2019-01-17 15:31:42,341 Training parameters: {'cub_root': 'cub200', 'cub_is_extracted': True, 'sz_embedding': 64, 'nb_classes': 100, 'sz_batch': 32, 'lr_embedding': 1e-05, 'lr_inception': 0.001, 'lr_proxynca': 0.001, 'weight_decay': 0.0005, 'epsilon': 0.01, 'gamma': 0.1, 'nb_epochs': 20, 'log_filename': 'example', 'gpu_id': 0, 'nb_workers': 16}
2019-01-17 15:31:42,341 Training for 20 epochs.
2019-01-17 15:31:42,341 Evaluating initial model...
2019-01-17 15:32:00,189 NMI: 47.185
2019-01-17 15:32:01,649 R@1 : 33.355
2019-01-17 15:32:01,761 R@2 : 43.970
2019-01-17 15:32:01,885 R@4 : 54.650
2019-01-17 15:32:02,006 R@8 : 67.203
2019-01-17 15:32:21,369 Epoch: 1, loss: 4.282, time (seconds): 19.36.
2019-01-17 15:32:21,382 Evaluating...
2019-01-17 15:32:39,077 NMI: 51.503
2019-01-17 15:32:40,414 R@1 : 37.200
2019-01-17 15:32:40,530 R@2 : 48.505
2019-01-17 15:32:40,645 R@4 : 59.809
2019-01-17 15:32:40,759 R@8 : 71.607
2019-01-17 15:33:00,144 Epoch: 2, loss: 3.987, time (seconds): 19.38.
2019-01-17 15:33:00,158 Evaluating...
2019-01-17 15:33:17,491 NMI: 56.126
2019-01-17 15:33:18,811 R@1 : 44.298
2019-01-17 15:33:18,921 R@2 : 56.425
2019-01-17 15:33:19,031 R@4 : 67.499
2019-01-17 15:33:19,142 R@8 : 77.982
2019-01-17 15:33:38,986 Epoch: 3, loss: 3.747, time (seconds): 19.84.
2019-01-17 15:33:39,003 Evaluating...
2019-01-17 15:33:56,294 NMI: 55.789
2019-01-17 15:33:57,738 R@1 : 44.791
2019-01-17 15:33:57,851 R@2 : 56.260
2019-01-17 15:33:57,961 R@4 : 67.598
2019-01-17 15:33:58,070 R@8 : 78.114
2019-01-17 15:34:17,360 Epoch: 4, loss: 3.653, time (seconds): 19.29.
2019-01-17 15:34:17,375 Evaluating...
2019-01-17 15:34:39,191 NMI: 56.323
2019-01-17 15:34:40,620 R@1 : 47.355
2019-01-17 15:34:40,737 R@2 : 57.378
2019-01-17 15:34:40,851 R@4 : 67.861
2019-01-17 15:34:40,966 R@8 : 77.588
2019-01-17 15:35:00,658 Epoch: 5, loss: 3.605, time (seconds): 19.69.
2019-01-17 15:35:00,661 Evaluating...
2019-01-17 15:35:17,480 NMI: 57.108
2019-01-17 15:35:18,810 R@1 : 46.434
2019-01-17 15:35:18,924 R@2 : 57.640
2019-01-17 15:35:19,035 R@4 : 68.814
2019-01-17 15:35:19,146 R@8 : 78.377
2019-01-17 15:35:38,904 Epoch: 6, loss: 3.632, time (seconds): 19.76.
2019-01-17 15:35:38,915 Evaluating...
2019-01-17 15:35:57,393 NMI: 57.125
2019-01-17 15:35:58,711 R@1 : 46.960
2019-01-17 15:35:58,822 R@2 : 58.002
2019-01-17 15:35:58,932 R@4 : 68.879
2019-01-17 15:35:59,041 R@8 : 78.672
2019-01-17 15:36:18,225 Epoch: 7, loss: 3.573, time (seconds): 19.18.
2019-01-17 15:36:18,238 Evaluating...
2019-01-17 15:36:35,586 NMI: 56.258
2019-01-17 15:36:37,024 R@1 : 46.763
2019-01-17 15:36:37,135 R@2 : 57.838
2019-01-17 15:36:37,246 R@4 : 68.518
2019-01-17 15:36:37,356 R@8 : 78.114
2019-01-17 15:36:56,823 Epoch: 8, loss: 3.553, time (seconds): 19.47.
2019-01-17 15:36:56,830 Evaluating...
2019-01-17 15:37:14,101 NMI: 56.605
2019-01-17 15:37:15,501 R@1 : 46.204
2019-01-17 15:37:15,613 R@2 : 57.936
2019-01-17 15:37:15,723 R@4 : 67.203
2019-01-17 15:37:15,833 R@8 : 78.114
2019-01-17 15:37:35,404 Epoch: 9, loss: 3.526, time (seconds): 19.57.
2019-01-17 15:37:35,406 Evaluating...
2019-01-17 15:37:50,982 NMI: 56.698
2019-01-17 15:37:52,344 R@1 : 46.960
2019-01-17 15:37:52,457 R@2 : 58.101
2019-01-17 15:37:52,568 R@4 : 68.978
2019-01-17 15:37:52,679 R@8 : 78.935
2019-01-17 15:38:12,034 Epoch: 10, loss: 3.527, time (seconds): 19.35.
2019-01-17 15:38:12,045 Evaluating...
2019-01-17 15:38:29,900 NMI: 56.664
2019-01-17 15:38:31,419 R@1 : 46.632
2019-01-17 15:38:31,530 R@2 : 58.363
2019-01-17 15:38:31,640 R@4 : 68.781
2019-01-17 15:38:31,751 R@8 : 79.264
2019-01-17 15:38:50,909 Epoch: 11, loss: 3.508, time (seconds): 19.16.
2019-01-17 15:38:50,917 Evaluating...
2019-01-17 15:39:09,785 NMI: 56.386
2019-01-17 15:39:11,244 R@1 : 47.157
2019-01-17 15:39:11,360 R@2 : 58.265
2019-01-17 15:39:11,474 R@4 : 68.518
2019-01-17 15:39:11,589 R@8 : 78.640
2019-01-17 15:39:31,209 Epoch: 12, loss: 3.481, time (seconds): 19.62.
2019-01-17 15:39:31,223 Evaluating...
2019-01-17 15:39:54,383 NMI: 57.177
2019-01-17 15:39:56,128 R@1 : 46.172
2019-01-17 15:39:56,262 R@2 : 57.706
2019-01-17 15:39:56,400 R@4 : 68.386
2019-01-17 15:39:56,545 R@8 : 78.409
2019-01-17 15:40:15,812 Epoch: 13, loss: 3.505, time (seconds): 19.26.
2019-01-17 15:40:15,825 Evaluating...
2019-01-17 15:40:32,983 NMI: 56.854
2019-01-17 15:40:34,309 R@1 : 46.664
2019-01-17 15:40:34,421 R@2 : 58.068
2019-01-17 15:40:34,531 R@4 : 69.405
2019-01-17 15:40:34,642 R@8 : 78.048
2019-01-17 15:40:53,608 Epoch: 14, loss: 3.493, time (seconds): 18.97.
2019-01-17 15:40:53,621 Evaluating...
2019-01-17 15:41:11,281 NMI: 54.765
2019-01-17 15:41:12,615 R@1 : 45.120
2019-01-17 15:41:12,728 R@2 : 55.866
2019-01-17 15:41:12,838 R@4 : 67.335
2019-01-17 15:41:12,949 R@8 : 77.062
2019-01-17 15:41:32,569 Epoch: 15, loss: 3.479, time (seconds): 19.62.
2019-01-17 15:41:32,583 Evaluating...
2019-01-17 15:41:49,486 NMI: 58.062
2019-01-17 15:41:50,908 R@1 : 47.223
2019-01-17 15:41:51,020 R@2 : 58.331
2019-01-17 15:41:51,129 R@4 : 68.551
2019-01-17 15:41:51,239 R@8 : 78.475
2019-01-17 15:42:12,802 Epoch: 16, loss: 3.494, time (seconds): 21.56.
2019-01-17 15:42:12,814 Evaluating...
2019-01-17 15:42:30,496 NMI: 55.504
2019-01-17 15:42:31,948 R@1 : 45.186
2019-01-17 15:42:32,061 R@2 : 56.852
2019-01-17 15:42:32,173 R@4 : 66.809
2019-01-17 15:42:32,284 R@8 : 77.982
2019-01-17 15:42:51,355 Epoch: 17, loss: 3.504, time (seconds): 19.07.
2019-01-17 15:42:51,368 Evaluating...
2019-01-17 15:43:08,087 NMI: 57.182
2019-01-17 15:43:09,491 R@1 : 46.204
2019-01-17 15:43:09,608 R@2 : 58.133
2019-01-17 15:43:09,722 R@4 : 69.307
2019-01-17 15:43:09,836 R@8 : 78.212
2019-01-17 15:43:29,346 Epoch: 18, loss: 3.495, time (seconds): 19.51.
2019-01-17 15:43:29,360 Evaluating...
2019-01-17 15:43:45,197 NMI: 57.448
2019-01-17 15:43:46,630 R@1 : 46.927
2019-01-17 15:43:46,743 R@2 : 58.462
2019-01-17 15:43:46,853 R@4 : 69.635
2019-01-17 15:43:46,964 R@8 : 79.001
2019-01-17 15:44:06,537 Epoch: 19, loss: 3.481, time (seconds): 19.57.
2019-01-17 15:44:06,552 Evaluating...
2019-01-17 15:44:29,381 NMI: 57.078
2019-01-17 15:44:30,807 R@1 : 46.632
2019-01-17 15:44:30,920 R@2 : 57.016
2019-01-17 15:44:31,031 R@4 : 68.124
2019-01-17 15:44:31,142 R@8 : 77.949
2019-01-17 15:44:50,334 Epoch: 20, loss: 3.509, time (seconds): 19.19.
2019-01-17 15:44:50,348 Evaluating...
2019-01-17 15:45:12,585 NMI: 57.708
2019-01-17 15:45:14,051 R@1 : 47.387
2019-01-17 15:45:14,165 R@2 : 58.824
2019-01-17 15:45:14,277 R@4 : 69.504
2019-01-17 15:45:14,388 R@8 : 79.428
2019-01-17 15:45:14,389 Total training time (minutes): 13.53.

Result

Do you reach the performance reported in the paper on CUB-200-2011?

Where the proxies are changed during training

First, a big thank you for this. I find it super useful.

Then, my beginner question:

In another thread you wrote:
The proxies are randomly initialized and change during training.

But where and how does this happen? Is this line called every time? I cannot see any other line that updates the proxies.

Thanks again!
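In PyTorch, learnable proxies are typically an nn.Parameter that is passed to the optimizer. They therefore change on every optimizer step rather than in a dedicated line of code. The separate `--lr-proxynca` flag in the commands and the `lr_proxynca` entry in the log above suggest that the proxies get their own learning rate here. The NumPy sketch below makes that update explicit for a single embedding. It uses a hand-derived gradient, a hypothetical helper name, and no normalization or scaling.

```python
import numpy as np

def loss_and_proxy_grad(x, proxies, label):
    # Loss for one embedding x: -log softmax(-d)[label] with d_j = ||x - p_j||^2,
    # plus its hand-derived gradient with respect to every proxy p_j.
    d = np.sum((proxies - x) ** 2, axis=1)
    z = -d - np.max(-d)                          # shift for numerical stability
    log_norm = np.log(np.sum(np.exp(z)))
    loss = -(z[label] - log_norm)
    s = np.exp(z - log_norm)                     # softmax over the proxies
    coeff = -s                                   # dL/dd_j = [j == label] - s_j
    coeff[label] += 1.0
    grad = coeff[:, None] * 2.0 * (proxies - x)  # dd_j/dp_j = 2 (p_j - x)
    return loss, grad

rng = np.random.default_rng(0)
proxies = rng.normal(size=(5, 8))  # hypothetical learnable proxies (5 classes, dim 8)
x = rng.normal(size=8)             # one embedding, treated as fixed here
label = int(np.argmax(np.sum((proxies - x) ** 2, axis=1)))  # farthest proxy as positive

loss_before, grad = loss_and_proxy_grad(x, proxies, label)
proxies = proxies - 1e-3 * grad    # one plain gradient-descent step: the proxies move
loss_after, _ = loss_and_proxy_grad(x, proxies, label)
print(loss_before, loss_after)
```

In real training, the network parameters that produce x are updated in the same step, so both the embeddings and the proxies move.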

function pairwise_distance

Why is the implementation of the function pairwise_distance so complex, and why does it use so much GPU memory? In my opinion, this function is equivalent to:
D = 18 - 2 * torch.mm(X, P.t())
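The shortcut relies on ||x - p||^2 = ||x||^2 + ||p||^2 - 2 x·p. It is exact only if every row of X and P has squared norm 9, that is, both are L2-normalized and scaled by 3 as in the (3, 3) setting from the README. For other scalings, the constant becomes s_x^2 + s_p^2. Here is a NumPy sketch comparing it with direct broadcasting (sizes hypothetical).

```python
import numpy as np

def l2_normalize(v):
    # Divide each row by its L2 norm.
    return v / np.linalg.norm(v, axis=1, keepdims=True)

rng = np.random.default_rng(0)
X = 3.0 * l2_normalize(rng.normal(size=(6, 16)))  # embeddings scaled to norm 3
P = 3.0 * l2_normalize(rng.normal(size=(4, 16)))  # proxies scaled to norm 3

# Direct squared distances via broadcasting: builds an (n, m, dim) intermediate.
D_direct = np.sum((X[:, None, :] - P[None, :, :]) ** 2, axis=-1)
# Expanded form with ||x||^2 = ||p||^2 = 9: only an (n, m) matrix product.
D_expanded = 18.0 - 2.0 * X @ P.T
```

The memory saving comes from skipping the (n, m, dim) intermediate. In floating point the expanded form can return tiny negative values, so clamping with `np.maximum(D, 0)` before any square root is a common safeguard.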

Not runnable

I could not run your code with any PyTorch version. I'm using Python 3.6.
After fixing the call pnca = ProxyNCA(no_top_model, SZ_EMBED, NB_CLASS, SZ_BATCH) in pnca.py and using torch.nn.init.xavier_uniform or torch.nn.init.xavier_uniform_ depending on pytorch version, I receive these errors depending on the pytorch version:

v0.1.12-post2 and v0.2.0-post2:
RuntimeError: Given input size: (128 x 3 x 3). Calculated output size: (768 x -1 x -1). Output size is too small at /b/wheel/pytorch-src/torch/lib/THNN/generic/SpatialConvolutionMM.c:4

v0.3.0 and v0.4.0:
RuntimeError: Calculated padded input size per channel: (3 x 3). Kernel size: (5 x 5). Kernel size can't greater than actual input size at /pytorch/aten/src/THNN/generic/SpatialConvolutionMM.c:48

Any clue?

question about the decrease process of the loss

I noticed that during training, your proxy loss decreases only a little. Although the result is similar to the result in the paper, maybe it would improve further if the loss kept decreasing. Do you have any idea why this happens?

I tried replacing it with a traditional triplet loss, but that loss is also difficult to decrease, and the reproduced rank-1 accuracy is only 34%. Would you please tell me what I should change?

Thank you so much for your effort.
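For reference, here is a minimal NumPy sketch of a standard triplet margin loss on squared Euclidean distances. The margin of 0.2 is a hypothetical choice. Some implementations use unsquared distances, or mine hard triplets instead of taking them as given.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.2):
    # max(0, d(a, p) - d(a, n) + margin) on squared distances, averaged over the batch.
    d_ap = np.sum((anchor - positive) ** 2, axis=1)
    d_an = np.sum((anchor - negative) ** 2, axis=1)
    return np.mean(np.maximum(0.0, d_ap - d_an + margin))
```

Triplets whose negative is already farther than the positive by more than the margin contribute zero, which is one reason this loss can plateau once most sampled triplets become easy.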

Wonder about proxyNCA loss

In loss = torch.sum(- T * F.log_softmax(D, -1), -1), should D be -D, since we want to minimize the distance between an example and the proxy with the same label?
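When D holds (squared) distances, the softmax logits need to be -D. The largest probability then goes to the nearest proxy, and minimizing the negative log-probability of the positive proxy pulls the example toward it. With +D, the largest probability goes to the farthest proxy. Here is a small NumPy check; the distances are made up.

```python
import numpy as np

def softmax(z):
    # Stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

D = np.array([[0.2, 5.0, 9.0]])  # squared distances to three proxies; proxy 0 is nearest
p_neg = softmax(-D)              # highest probability on the nearest proxy
p_pos = softmax(D)               # highest probability on the farthest proxy
```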
