
CNAPs: Fast and Flexible Multi-Task Classification Using Conditional Neural Adaptive Processes

This repository contains the code to reproduce the few-shot classification experiments carried out in Fast and Flexible Multi-Task Classification Using Conditional Neural Adaptive Processes and TASKNORM: Rethinking Batch Normalization for Meta-Learning.

The code has been authored by: John Bronskill, Jonathan Gordon, and James Requeima.

Dependencies

This code requires the following:

  • Python 3.5 or greater
  • PyTorch 1.0 or greater
  • TensorFlow 1.15 or greater

This code has been recently verified on PyTorch 1.7 and TensorFlow 2.3.

GPU Requirements

  • To train or test a CNAPs model with auto-regressive FiLM adaptation on Meta-Dataset, 2 GPUs with 16GB or more memory are required.
  • To train or test a CNAPs model with FiLM only adaptation plus TaskNorm on Meta-Dataset, 2 GPUs with 16GB or more memory are required.
  • It is not currently possible to run a CNAPs model with auto-regressive FiLM adaptation plus TaskNorm on Meta-Dataset (even using 2 GPUs with 16GB of memory). It may be possible (we have not tried) to run this configuration on 2 GPUs with 24GB of memory.
  • The other modes require only a single GPU with at least 16 GB of memory.
  • If you want to run any of the modes on a single GPU, you can train on a single dataset with fixed shot and way. Provided shot and way are not too large, this configuration requires a single GPU with less than 16GB of memory. An example command line follows (note that it will not reproduce the Meta-Dataset results):

python run_cnaps.py --feature_adaptation film -i 20000 -lr 0.001 --batch_normalization task_norm-i --dataset omniglot --way 5 --shot 5 --data_path <path to directory containing Meta-Dataset records>
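
For readers unfamiliar with FiLM, here is a minimal, self-contained PyTorch sketch of feature-wise linear modulation. It is illustrative only and is not this repository's adaptation network; the class and variable names are placeholders.

    import torch
    import torch.nn as nn

    class FiLMLayer(nn.Module):
        # Scale and shift each channel of a feature map with task-conditional
        # parameters gamma and beta (illustrative only).
        def forward(self, x, gamma, beta):
            # x: (batch, channels, height, width); gamma, beta: (channels,)
            return gamma.view(1, -1, 1, 1) * x + beta.view(1, -1, 1, 1)

    # Identity modulation (gamma = 1, beta = 0) leaves the features unchanged.
    features = torch.randn(4, 64, 8, 8)
    out = FiLMLayer()(features, torch.ones(64), torch.zeros(64))

In CNAPs, gamma and beta are generated per task by an adaptation network conditioned on the support set; the --feature_adaptation flag selects how that adaptation is done.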

Installation

  1. Clone or download this repository.
  2. Configure Meta-Dataset: follow the user instructions in the Meta-Dataset repository for installing it and for downloading and converting the datasets.
  3. Install additional test datasets (MNIST, CIFAR10, CIFAR100):
    • Change to the $DATASRC directory: cd $DATASRC
    • Download the MNIST test images: wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
    • Download the MNIST test labels: wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
    • Download the CIFAR10 dataset: wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
    • Extract the CIFAR10 dataset: tar -zxvf cifar-10-python.tar.gz
    • Download the CIFAR100 dataset: wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
    • Extract the CIFAR100 dataset: tar -zxvf cifar-100-python.tar.gz
    • Change to the cnaps/src directory in the repository.
    • Run: python prepare_extra_datasets.py
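
If you prefer to script the download and extraction steps above, the following minimal Python sketch does the same thing. It assumes the DATASRC environment variable is set as in the Meta-Dataset instructions and is offered as a convenience only; it is not part of the repository.

    import os
    import tarfile
    import urllib.request

    data_src = os.environ["DATASRC"]  # same $DATASRC directory as above
    urls = [
        "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
        "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
        "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
        "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz",
    ]
    for url in urls:
        target = os.path.join(data_src, os.path.basename(url))
        urllib.request.urlretrieve(url, target)       # download
        if target.endswith(".tar.gz"):
            with tarfile.open(target) as archive:     # extract the CIFAR archives
                archive.extractall(data_src)

Note that this only automates the downloads and extraction; the prepare_extra_datasets.py step above is still required.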

Usage

To train and test CNAPs on Meta-Dataset:

  1. First run the following two commands.

    ulimit -n 50000

    export META_DATASET_ROOT=<root directory of the cloned or downloaded Meta-Dataset repository>

    Note that the above commands need to be run every time you open a new command shell.

  2. Execute the run_cnaps.py script from the src directory following the instructions at the beginning of the file.
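
As a quick sanity check before launching a long run, the following small Python sketch (hypothetical, not part of the repository) verifies that step 1 took effect:

    import os
    import resource

    # META_DATASET_ROOT must be exported in the current shell (see step 1).
    assert "META_DATASET_ROOT" in os.environ, "export META_DATASET_ROOT first"

    # The open-file limit should reflect the earlier ulimit -n 50000.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    print(f"open-file soft limit: {soft} (expected around 50000)")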

Expected Results

The FiLM + TaskNorm configuration consistently yields the best results and trains in much less time than the other configurations. A meta-trained FiLM + TaskNorm-i model that produced the results shown below is included in the models folder. The model was trained for 40,000 iterations on two 16GB GPUs. Note that these results differ from those published in our paper, as they were obtained after fixing the shuffle-buffer bug described in Meta-Dataset issue #54. In particular, the results for the Traffic Signs dataset are considerably worse, while the results for the other datasets are comparable (some slightly better, some slightly worse).

Model trained on all datasets

Dataset        FiLM + TaskNorm (accuracy %)
ILSVRC         50.8±1.1
Omniglot       91.7±0.5
Aircraft       83.7±0.6
Birds          73.6±0.9
Textures       59.5±0.7
Quick Draw     74.7±0.8
Fungi          50.2±1.1
VGG Flower     88.9±0.5
Traffic Signs  56.5±1.1
MSCOCO         39.4±1.0
MNIST          92.3±0.4
CIFAR10        68.5±0.9
CIFAR100       56.1±1.1

Contact

To ask questions or report issues, please open an issue on the issue tracker.

Citation

If you use this code, please cite our CNAPs and TaskNorm papers:

@incollection{requeima2019cnaps,
  title      = {Fast and Flexible Multi-Task Classification using Conditional Neural Adaptive Processes},
  author     = {Requeima, James and Gordon, Jonathan and Bronskill, John and Nowozin, Sebastian and Turner, Richard E},
  booktitle  = {Advances in Neural Information Processing Systems 32},
  editor     = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d'Alch\'{e}-Buc and E. Fox and R. Garnett},
  pages      = {7957--7968},
  year       = {2019},
  publisher  = {Curran Associates, Inc.},
}

@incollection{bronskill2020tasknorm,
  title     = {TaskNorm: Rethinking Batch Normalization for Meta-Learning},
  author    = {Bronskill, John and Gordon, Jonathan and Requeima, James and Nowozin, Sebastian and Turner, Richard},
  booktitle = {Proceedings of the 37th International Conference on Machine Learning},
  volume    = {119},
  series    = {Proceedings of Machine Learning Research},
  publisher = {PMLR},
  year      = {2020}
}

cnaps's Issues

Number of parameters in the adaptation module

Dear authors,

First, I would like to thank you for such great work and clean implementation.

I am now working on a related project and I would like to compare our method to yours (CNAPs with AR) in terms of the total number of parameters used to perform adaptation. Specifically, I want to know the ratio between all adaptation parameters and all parameters of ResNet-18. Do you have this number by any chance?

Thank you,
Nikita
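
(Not an answer from the authors, but for anyone wanting to compute such a ratio themselves, a generic PyTorch sketch follows. The attribute names in the commented usage are placeholders, apart from classifier_adaptation_network, which appears in model.py; check the model definition for the actual sub-module names.)

    import torch.nn as nn

    def count_params(module: nn.Module) -> int:
        # Total number of learnable parameters in a module.
        return sum(p.numel() for p in module.parameters())

    # Hypothetical usage against a constructed CNAPs model:
    # adaptation_modules = [model.classifier_adaptation_network,
    #                       model.feature_adaptation_network]   # placeholder name
    # backbone_params = count_params(model.feature_extractor)   # placeholder name
    # ratio = sum(count_params(m) for m in adaptation_modules) / backbone_params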

Second Question

cnaps/src/model.py, lines 108 to 110 (commit 5a11dc0):

    def _get_classifier_params(self, train_features, train_labels):
        classifier_params = self.classifier_adaptation_network(self.class_representations)
        return classifier_params

Why doesn't _get_classifier_params use its train_features and train_labels parameters?
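
(One plausible reading, not an official answer: the class representations are computed from the support features and labels earlier in the forward pass and cached on the model, so by the time _get_classifier_params runs it only needs self.class_representations. The helper below is an illustrative sketch of that earlier step, not the repository's code.)

    import torch

    def build_class_representations(train_features, train_labels):
        # Illustrative only: mean-pool the support features of each class,
        # assuming integer labels 0..C-1. Something along these lines would
        # populate self.class_representations before _get_classifier_params
        # is called, which is why that method ignores its arguments.
        reps = []
        for c in torch.unique(train_labels, sorted=True):
            reps.append(train_features[train_labels == c].mean(dim=0))
        return torch.stack(reps)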

License for the code

Thanks a lot for the code. I am planning to try it. Can you please specify the License for its use?

pretrained file

Hello, when I decompressed the pretrained model archive, it turned out to be corrupted. If it is convenient for you, could you please upload it again? Thank you very much.

What is the identity class used for?

Hi authors,
Thank you very much for your great work! I am reading your code and have a question about set_encoder.py.

cnaps/src/set_encoder.py, lines 42 to 47 (commit 5a11dc0):

    class Identity(nn.Module):
        def __init__(self):
            super(Identity, self).__init__()

        def forward(self, x):
            return x

What is the Identity class used for? Is it the rho(.) function from line 29 (see below)?

g(X) = rho(mean(phi(x)))

It seems to do nothing, since Identity just returns its input unchanged. Why is the post_pooling_fn step needed?
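
(A minimal sketch of the pattern, not the repository's implementation: keeping rho as an explicit post-pooling hook, even when it is just the identity, lets a non-trivial post-pooling network be swapped in later without changing the forward pass.)

    import torch
    import torch.nn as nn

    class SimpleSetEncoder(nn.Module):
        # Sketch of g(X) = rho(mean(phi(x))), with rho defaulting to identity.
        def __init__(self, phi, rho=None):
            super().__init__()
            self.phi = phi                               # per-element embedding
            self.post_pooling_fn = rho or nn.Identity()  # rho(.)

        def forward(self, x):
            embeddings = self.phi(x)             # phi(x) for every set element
            pooled = embeddings.mean(dim=0)      # permutation-invariant mean
            return self.post_pooling_fn(pooled)  # rho(mean(phi(x)))

    # Example: phi is a small MLP; rho is left as the identity.
    encoder = SimpleSetEncoder(phi=nn.Sequential(nn.Linear(16, 32), nn.ReLU()))
    task_representation = encoder(torch.randn(10, 16))  # a set of 10 elements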

About the training time

Hi, thank you for sharing your work! I have recently run some experiments on Meta-Dataset using your code. I used an RTX 3090 to train a model only on ImageNet, but found that every 500 episodes take about 38 minutes, which is very slow. Is this speed reasonable?

Question about classifier bias

In https://github.com/cambridge-mlg/cnaps/blob/master/src/adaptation_networks.py#L238, the bias_means_processor is:

    (bias_means_processor): DenseResidualBlock(
      (linear1): Linear(in_features=512, out_features=1, bias=True)
      (linear2): Linear(in_features=1, out_features=1, bias=True)
      (linear3): Linear(in_features=1, out_features=1, bias=True)
      (elu): ELU(alpha=1.0)
    )

which is a little different from Table 13 of the paper.

Is this a mistake or do I misunderstand something?

Questions about changing the dataset

Hello, I am a few-shot learning beginner. I am very interested in your work, but I ran into this problem while trying to change the training dataset to miniImagenet:
RuntimeError: No dataset_spec file found in directory /home/cnaps/filelists/miniImagenet
Can you help me solve it? Looking forward to your reply!

Pre-trained ImageNet model

Do you have any advice on how I could train a new pre-trained ImageNet model myself using this framework, if I would like to try a different architecture?

Meta-Dataset version

Could you clarify which version of Meta-Dataset is used for this repo? The Meta-Dataset repository is updated very frequently.
