Hierarchy SeparateEMD For Few-Shot Learning

This is the PyTorch code repository for "Hierarchy SeparateEMD For Few-Shot Learning". If you use any content of this repository in your work, please cite the following BibTeX entry:

@inproceedings{separate2022fewshot,
  author    = {Yaqiang Sun and
               Jie Hao and
               Zhuojun Zou and
               Lin Shu and
               Shengjie Hu},
  title     = {Hierarchy SeparateEMD For Few-Shot Learning},
  booktitle = {Methods and Applications for Modeling and Simulation of Complex Systems},
  year      = {2022},
  publisher = {Springer Nature Singapore},
  address   = {Singapore},
  pages     = {548--560},
  isbn      = {978-981-19-9198-1}
}

SeparateEMD

We propose a novel model-based approach to the few-shot classification task, which we call Hierarchy SeparateEMD.

Standard Few-shot Learning Results

Experimental results on few-shot learning datasets with a ResNet-12 backbone. We report average accuracy (%) over 10,000 randomly sampled few-shot learning episodes to stabilize the evaluation.
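The evaluation protocol above averages per-episode accuracy over many sampled tasks. A minimal sketch of that aggregation follows; the helper name summarize_episodes is hypothetical, and the 95% confidence half-width (1.96 times the standard error) is a convention common in few-shot papers that the table here does not report.

```python
import math
from statistics import mean, stdev


def summarize_episodes(accuracies):
    """Summarize per-episode accuracies (in %) as (mean, 95% CI half-width).

    Hypothetical helper: the half-width uses the normal approximation
    1.96 * sample_std / sqrt(num_episodes), a common few-shot convention.
    """
    if len(accuracies) < 2:
        raise ValueError("need at least two episodes to estimate spread")
    avg = mean(accuracies)
    half_width = 1.96 * stdev(accuracies) / math.sqrt(len(accuracies))
    return avg, half_width


if __name__ == "__main__":
    # Toy per-episode accuracies, not results from this repository.
    avg, ci = summarize_episodes([50.0, 60.0, 70.0])
    print(f"{avg:.2f} +/- {ci:.2f}")
```

In the actual protocol the list would hold 10,000 entries, one accuracy per sampled episode.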

MiniImageNet Dataset

Method         1-Shot 5-Way   5-Shot 5-Way   (accuracy, %)
ProtoNet       62.39          80.53
BILSTM         63.90          80.63
DEEPSETS       64.14          80.93
GCN            64.50          81.65
FEAT           66.78          82.05
DeepEMD        68.77          84.13
SeparateEMD    69.03          85.27

Prerequisites

The following packages are required to run the scripts:

Dataset

MiniImageNet Dataset

The MiniImageNet dataset is a subset of ImageNet containing 100 classes with 600 images per class. Following the standard setup from previous work, we use 64 classes as SEEN categories, and 16 and 20 classes as two disjoint sets of UNSEEN categories for model validation and evaluation, respectively.
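To illustrate how the 64/16/20 class split feeds episodic training and evaluation, here is a small sketch of N-way K-shot episode sampling. It is not this repository's dataloader: the class ids are integers shuffled at random, whereas the repository reads predefined split files from its data folder, and the names make_class_splits and sample_episode are hypothetical.

```python
import random

# Class counts from the MiniImageNet split described above.
NUM_CLASSES, IMAGES_PER_CLASS = 100, 600
SPLIT_SIZES = {"train": 64, "val": 16, "test": 20}


def make_class_splits(seed=0):
    """Partition class ids 0..99 into disjoint train/val/test sets.

    Illustrative only: real experiments use the predefined split files.
    """
    classes = list(range(NUM_CLASSES))
    random.Random(seed).shuffle(classes)
    splits, start = {}, 0
    for name, size in SPLIT_SIZES.items():
        splits[name] = classes[start:start + size]
        start += size
    return splits


def sample_episode(class_ids, way=5, shot=1, query=15, rng=random):
    """Sample one N-way K-shot episode as (class_id, image_index, episode_label) triples."""
    chosen = rng.sample(class_ids, way)
    support, queries = [], []
    for label, cls in enumerate(chosen):
        # Draw shot + query distinct images so support and query never overlap.
        idx = rng.sample(range(IMAGES_PER_CLASS), shot + query)
        support += [(cls, i, label) for i in idx[:shot]]
        queries += [(cls, i, label) for i in idx[shot:]]
    return support, queries


if __name__ == "__main__":
    splits = make_class_splits()
    support, queries = sample_episode(splits["test"], rng=random.Random(1))
    print(len(support), len(queries))
```

With the settings used in the training commands (--way 5 --shot 1 --query 15), each episode contains 5 support images and 75 query images; with --shot 5 it contains 25 support images.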

Code Structures

To reproduce our experiments, please use train_fsl.py. The code is organized into four parts:

  • model: the main files of the code, including the few-shot learning trainer, the dataloader, the network architectures, and the baseline and comparison models.
  • data: images and splits for the datasets.
  • saves: pre-trained weights of different networks.
  • checkpoints: directory where trained models are saved.

Model Training and Evaluation

Please use train.py and follow the instructions below. After the given number of epochs, the script automatically evaluates the model on the meta-test set with 10,000 tasks.

Training scripts for SeparateEMD

For example, to train the 5-way SeparateEMD models with a ResNet-12 backbone on MiniImageNet, use the first command for 1-shot and the second for 5-shot:

$ python train.py --max_epoch 60 --model_class SeparateEMD --backbone_class Res12 --dataset MiniImageNet \
    --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 \
    --balance 0.01 --temperature 64 --temperature2 64 \
    --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 \
    --gpu 1 --init_weights ./saves/initialization/miniimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean
$ python train.py --max_epoch 60 --model_class SeparateEMD --backbone_class Res12 --dataset MiniImageNet \
    --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 \
    --balance 0.1 --temperature 64 --temperature2 32 \
    --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 \
    --gpu 0 --init_weights ./saves/initialization/miniimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean
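Both commands share a step learning-rate schedule (--lr 0.0002 --lr_scheduler step --step_size 40 --gamma 0.5 over --max_epoch 60). Assuming the scheduler is stepped once per epoch, the sketch below computes the resulting rate: 0.0002 for epochs 0 to 39 and 0.0001 for epochs 40 to 59. The handling of --lr_mul is also an assumption, borrowed from the FEAT codebase, where a flag of the same name multiplies the learning rate of the non-backbone parameters; the helper names step_lr and head_lr are hypothetical.

```python
def step_lr(epoch, base_lr=0.0002, step_size=40, gamma=0.5):
    """Learning rate at a 0-indexed epoch for a step schedule.

    Same closed form as torch.optim.lr_scheduler.StepLR, assuming the
    scheduler is stepped once per epoch.
    """
    return base_lr * gamma ** (epoch // step_size)


def head_lr(epoch, lr_mul=10, **schedule):
    """Assumed rate for non-backbone parameters: lr_mul times the base schedule."""
    return lr_mul * step_lr(epoch, **schedule)


if __name__ == "__main__":
    max_epoch = 60
    for epoch in (0, 39, 40, max_epoch - 1):
        print(epoch, step_lr(epoch), head_lr(epoch))
```

Because --step_size 40 is smaller than --max_epoch 60, the rate is halved exactly once during training.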

Acknowledgment

We thank the following repositories for providing helpful components and functions used in our work.