gnn_few_shot_cifar100's Introduction

Few-Shot Learning with Graph Neural Networks on CIFAR-100

This is a PyTorch 0.4.0 implementation of few-shot learning on CIFAR-100 with graph neural networks (GNNs). The code builds on an existing paper, GitHub repository, and course.

Besides training the CNN-embedding layers and the graph convolution layers end to end, I also tried two other training methods, which sometimes gave better results:

  1. Pretrain the CNN-embedding layers on CIFAR-100 classification (excluding the few-shot data), then fine-tune them while training the graph convolution layers.
  2. Pretrain the CNN-embedding layers in the same way, but keep them fixed while training the graph convolution layers.
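
The difference between the two methods comes down to whether the pretrained CNN parameters still receive gradient updates. A minimal sketch follows, assuming a hypothetical `build_optimizer` helper and tiny stand-in modules named `cnn` and `gnn`. None of these names or layer sizes come from this repo, and loading the pretrained weights is omitted.

```python
import torch
import torch.nn as nn


def build_optimizer(cnn, gnn, freeze_cnn, lr=1e-3):
    """Return an Adam optimizer over the parameters that should be trained."""
    # Method 2 freezes the CNN: its weights get no gradients and no updates.
    # Method 1 leaves them trainable so they are fine-tuned with the GNN.
    for p in cnn.parameters():
        p.requires_grad = not freeze_cnn
    params = list(gnn.parameters())
    if not freeze_cnn:
        params += list(cnn.parameters())
    return torch.optim.Adam(params, lr=lr)


# Hypothetical stand-ins for the real networks; layer sizes are arbitrary.
cnn = nn.Conv2d(3, 8, kernel_size=3)
gnn = nn.Linear(8, 5)

frozen_opt = build_optimizer(cnn, gnn, freeze_cnn=True)
n_frozen = sum(len(g["params"]) for g in frozen_opt.param_groups)

finetune_opt = build_optimizer(cnn, gnn, freeze_cnn=False)
n_finetune = sum(len(g["params"]) for g in finetune_opt.param_groups)

print(n_frozen, n_finetune)
```

If the CNN contains batch normalization or dropout, calling `cnn.eval()` while it is frozen would also keep those layers in inference mode.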

Dataset

CIFAR-100 (download the dataset through torchvision)

Requirements

python == 3.6
pytorch == 0.4.0
torchvision == 0.2.1
scikit-image == 0.14.0

OS == Linux (Ubuntu 16.04 LTS)

Execution

training

Train for N way M shots

python3 main.py --data_root 'data' --nway N --shots M
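
Here "N way M shots" means that each task asks the model to tell N classes apart given M labeled examples per class. A rough sketch of how one such episode could be drawn is shown below. The `sample_episode` helper, the `n_query` parameter, and the toy `labels` list are all hypothetical and are not taken from main.py.

```python
import random
from collections import defaultdict


def sample_episode(labels, nway, shots, n_query=1, seed=None):
    """Pick `nway` classes, then `shots` support and `n_query` query indices per class."""
    rng = random.Random(seed)
    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    classes = rng.sample(sorted(by_class), nway)
    support, query = [], []
    for cls in classes:
        # Draw without replacement so support and query never share a sample.
        picked = rng.sample(by_class[cls], shots + n_query)
        support += [(i, cls) for i in picked[:shots]]
        query += [(i, cls) for i in picked[shots:]]
    return support, query


# Toy label list: 10 classes with 20 samples each.
labels = [c for c in range(10) for _ in range(20)]
support, query = sample_episode(labels, nway=5, shots=1, seed=0)
print(len(support), len(query))  # prints "5 5"
```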

Pretrain CNN-embedding layer and train for N way M shots

python3 main.py --data_root 'data' --nway N --shots M --pretrain

Pretrain CNN-embedding layer and train for N way M shots while fixing the CNN-embedding layer

python3 main.py --data_root 'data' --nway N --shots M --pretrain --freeze_cnn

testing

Test for N way M shots

python3 main.py --todo 'test' --data_root 'data' --nway N --shots M --load --load_dir [path to the folder containing model.pth]
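
The flags used in the commands above could be declared with `argparse` roughly as follows. Only the flag names come from the commands. The types, defaults (including `'train'` for `--todo`), and help strings are assumptions, and the actual main.py may differ.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags shown in the commands."""
    parser = argparse.ArgumentParser(description="Few-shot GNN on CIFAR-100 (sketch)")
    parser.add_argument("--todo", default="train",
                        help="'train' or 'test'; the default is an assumption")
    parser.add_argument("--data_root", default="data", help="dataset folder")
    parser.add_argument("--nway", type=int, default=5, help="classes per task")
    parser.add_argument("--shots", type=int, default=1, help="examples per class")
    parser.add_argument("--pretrain", action="store_true",
                        help="pretrain the CNN-embedding layers first")
    parser.add_argument("--freeze_cnn", action="store_true",
                        help="keep the CNN-embedding layers fixed")
    parser.add_argument("--load", action="store_true", help="load a saved model")
    parser.add_argument("--load_dir", default=None,
                        help="folder containing model.pth")
    return parser


args = build_parser().parse_args(
    ["--data_root", "data", "--nway", "5", "--shots", "1", "--pretrain", "--freeze_cnn"]
)
print(args.nway, args.shots, args.pretrain, args.freeze_cnn)  # prints "5 1 True True"
```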

Experiment Result

  • 5 way (validation / test)

| Training method | 1-shot | 5-shot | 10-shot |
| --- | --- | --- | --- |
| end-to-end | 37.68 % / 35.60 % | 61.42 % / 63.20 % | 70.05 % / 68.60 % |
| pretrain and fine-tune | 49.66 % / 49.60 % | 63.84 % / 59.00 % | 69.69 % / 67.20 % |
| pretrain and fixed | 48.30 % / 42.80 % | 63.74 % / 65.60 % | 68.01 % / 67.20 % |

  • 20 way (validation / test)

| Training method | 1-shot | 5-shot | 10-shot |
| --- | --- | --- | --- |
| end-to-end | 19.85 % / 16.55 % | 36.58 % / 35.85 % | 38.34 % / 44.05 % |
| pretrain and fine-tune | 20.85 % / 22.95 % | 35.50 % / 37.15 % | 42.61 % / 41.10 % |
| pretrain and fixed | 22.68 % / 19.25 % | 35.66 % / 30.75 % | 41.67 % / 42.25 % |

  1. The random seed used to choose the few-shot classes is 1.
  2. Each experiment was run only once. Running each one multiple times would give more convincing results.
  3. If the model is made too deep, end-to-end training may fail. Pretraining the CNN-embedding layers can fix this problem.
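
Note 1 refers to a seed that selects which classes serve as few-shot classes. A minimal sketch of such a seeded split is given below. The `split_classes` helper and the number of held-out classes are assumptions for illustration, not the repo's actual split logic.

```python
import random


def split_classes(num_classes, num_few_shot, seed):
    """Return (base_classes, few_shot_classes) as disjoint sorted lists."""
    rng = random.Random(seed)
    few_shot = sorted(rng.sample(range(num_classes), num_few_shot))
    held_out = set(few_shot)
    base = [c for c in range(num_classes) if c not in held_out]
    return base, few_shot


# CIFAR-100 has 100 classes; holding out 20 of them is an assumed value.
base, few_shot = split_classes(num_classes=100, num_few_shot=20, seed=1)
assert not set(base) & set(few_shot)  # no class appears on both sides
print(len(base), len(few_shot))  # prints "80 20"
```

Because the split depends only on the seed, reusing the same seed for pretraining and for evaluation would keep the held-out classes consistent across both stages.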

Contact

Yi-Lin Sung, [email protected]

gnn_few_shot_cifar100's Issues

class overlap between train and test data

I cannot find how the code prevents class overlap between the training and test data. Typically, in few-shot learning evaluation, a few classes are never shown to the model during training, and at test time the model is evaluated by giving it N examples from those unseen classes. I am not sure whether this separation is maintained in the code.
