ifsl's Introduction

Interventional Few-Shot Learning

This project provides strong baselines with WRN-28-10 and ResNet-10 backbones for the following few-shot learning methods:

  • Fine-tuning
  • Matching Networks
  • Model-Agnostic Meta-Learning (MAML)
  • Meta-Transfer Learning for Few-Shot Learning (MTL)
  • Meta-Learning with Latent Embedding Optimization (LEO)
  • Synthetic Information Bottleneck (SIB)

This repository also includes the implementation of our NeurIPS 2020 paper Interventional Few-Shot Learning, which proposes an IFSL classifier based on the intervention P(Y|do(X)) to remove the confounding bias from pre-trained knowledge. The IFSL classifier is generally applicable to all fine-tuning and meta-learning methods, is easy to plug in, and involves no additional training steps.

The code is organized into four folders according to method. The folder MAML_MN_FT contains the baseline and IFSL implementations for fine-tuning, Matching Networks, and MAML.
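
For readers unfamiliar with the do-notation, the intervention P(Y|do(X)) can be written as a backdoor adjustment. The block below is a sketch of that formula, assuming D denotes a stratification of the pre-trained knowledge, d one stratum, and g(x, d) the representation of x under stratum d; these symbols are not defined in this README and are taken here as assumptions about the paper's notation.

```latex
P\bigl(Y \mid do(X = x)\bigr)
  = \sum_{d} P\bigl(Y \mid X = x,\; D = d,\; C = g(x, d)\bigr)\, P(D = d)
```

Under this reading, the classifier averages its prediction over the strata of pre-trained knowledge instead of conditioning on whichever stratum happens to co-occur with x.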

Dependencies

Recommended versions:

  • Python 3.7.6
  • PyTorch 1.4.0

Preparation

After downloading the weights and datasets, follow the instructions in each folder to modify the code and finish the preparation.

TODO

Apologies in advance for the messy code, which I will clean up gradually.

  • Code refactoring
  • Improve documentation and optimize project setup procedures

References

The implementation is based on the following repositories (to keep the baselines correct, most of our code is based on the officially released code).

Citation

If you find our work or the code useful, please consider citing our paper using:

@inproceedings{yue2020interventional,
  title={Interventional Few-Shot Learning},
  author={Yue, Zhongqi and Zhang, Hanwang and Sun, Qianru and Hua, Xian-Sheng},
  booktitle={NeurIPS},
  year={2020}
}

ifsl's People

Contributors

yue-zhongqi

ifsl's Issues

How to train MAML+IFSL on miniImagenet with ResNet-10

Hello, I want to reproduce the 76.37% result from your paper (MAML + IFSL on miniImagenet with ResNet-10). But when I run the example command (python main.py --dataset miniImagenet --model na --method metatrain --train_aug --test maml5_ifsl_resnet), I get an accuracy of only 32. Could you tell me how to obtain the 76.37% reported in your paper? If my command is wrong, what is the correct command? Looking forward to your reply. Thank you!

Need file model_best.pth.tar

Hi, I've encountered a problem while trying to reproduce the results from your paper. When I run the example command from the README for MAML_MN_FT (python main.py --method metatrain --train_aug --test maml5_resnet), line 116 of PretrainedModel.py (checkpoint = torch.load(model_dir)) requires the file model_best.pth.tar. Could you tell me whether you can provide this file, or how we can create it ourselves?
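
A missing checkpoint like the one described above is easier to diagnose when the path is validated before loading. The following is a minimal sketch in plain Python; resolve_checkpoint is a hypothetical helper that does not exist in the repository, and the assumption that model_dir may be either the file itself or its parent directory is ours.

```python
from pathlib import Path


def resolve_checkpoint(model_dir, filename="model_best.pth.tar"):
    """Return the checkpoint path, or raise a descriptive error if it is missing.

    Assumption: `model_dir` points either at the checkpoint file itself or at
    the directory that is expected to contain `filename`.
    """
    path = Path(model_dir)
    if path.is_dir():
        # A directory was given, so look for the checkpoint inside it.
        path = path / filename
    if not path.is_file():
        raise FileNotFoundError(
            f"Pre-trained checkpoint not found at '{path}'. "
            "Download or train the backbone weights before meta-training."
        )
    return path
```

It could then be used in front of the existing call, for example checkpoint = torch.load(str(resolve_checkpoint(model_dir))), so that a missing file fails with an explicit message rather than deep inside the loader.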

RuntimeError: Tensor size mismatch

Hello. I ran python main.py --method metatrain --train_aug --test maml5_ifsl_wrn_tiered and got the following error:

  File "E:\LinZD\IFSL\MAML_MN_FT\methods\DMAML.py", line 168, in train_loop
    loss = self.set_forward_loss(x)
  File "E:\LinZD\IFSL\MAML_MN_FT\methods\DMAML.py", line 152, in set_forward_loss
    scores = self.set_forward(x, is_feature=False)
  File "E:\LinZD\IFSL\MAML_MN_FT\methods\DMAML.py", line 106, in set_forward
    split_support, support_d, split_query, query_d = self.feature_processor.get_features(support, query)
  File "E:\LinZD\IFSL\MAML_MN_FT\methods\meta_toolkits.py", line 145, in get_features
    support_d = self.get_d_feature(support)
  File "E:\LinZD\IFSL\MAML_MN_FT\methods\meta_toolkits.py", line 139, in get_d_feature
    d_feature[i] = pd
RuntimeError: The expanded size of the tensor (64) must match the existing size (351) at non-singleton dimension 1. Target sizes: [25, 64]. Tensor sizes: [25, 351]

Any help would be appreciated! Thank you!!

My environment:
Python 3.7.6, PyTorch 1.4.0, and the necessary datasets and packages
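
The traceback in this issue reports that a tensor of shape [25, 351] is being written into a slot of shape [25, 64]. A check before the assignment can surface this kind of mismatch with a clearer message. The sketch below works on plain shape tuples rather than tensors; check_assignable is a hypothetical helper, not part of the repository. As a side note, 64 and 351 coincide with the number of training classes in miniImageNet and tieredImageNet respectively, so it may be worth checking that the dataset option and the pre-trained weights refer to the same dataset; this is a guess, not a confirmed diagnosis.

```python
def check_assignable(target_shape, source_shape):
    """Raise ValueError if `source_shape` cannot be broadcast into `target_shape`.

    Follows the usual broadcasting convention: shapes are aligned on their
    trailing dimensions, and a source dimension of size 1 may expand.
    """
    if len(source_shape) > len(target_shape):
        raise ValueError(
            f"source has {len(source_shape)} dimensions, "
            f"target only {len(target_shape)}"
        )
    offset = len(target_shape) - len(source_shape)
    for i, src in enumerate(source_shape):
        dim = i + offset
        tgt = target_shape[dim]
        if src != tgt and src != 1:
            raise ValueError(
                f"dimension {dim}: target size {tgt} != source size {src}"
            )


# Shapes taken from the error message above.
try:
    check_assignable((25, 64), (25, 351))
except ValueError as err:
    print(f"shape mismatch: {err}")
```

With tensors, the same check could be called as check_assignable(tuple(d_feature[i].shape), tuple(pd.shape)) just before the failing assignment.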

Hyperparameter problems

Following your instructions, I ran python main.py --dataset miniImagenet --method metatrain --train_aug --test maml5_resnet
and got a final result of 66.5%, well below the roughly 70% reported in your paper. Is this a hyperparameter problem?
P.S. I did not use the novel.hdf5 file because I did not think it provided anything useful. Could that file make the difference?

How to use save_feature.py

Hello,

I have a problem: how do I use save_feature.py? The script is missing a lot of dependencies.
Thanks for your advice.

Where is the code for function c = g(x, d) in the paper?

Thanks for the amazing work and the open-source code.

However, I have a question:
In the supplementary materials, for fine-tuning + IFSL, you calculate c = g(x, d). Where is this function in your code?
Thanks very much, and I look forward to your reply!

Question about running the code with tieredwrn_1_ifsl.yaml in the SIB folder

Hello authors!
I ran into a problem while running the code with the config file tieredwrn_1_ifsl.yaml: an assertion error occurred in the class LinearDiag in networks.py. The failing statement is "assert(X.dim()==2 and X.size(1)==self.weight.size(0))". I found that X.size(1)==self.weight.size(0) evaluates to False, because X.size(1) is 640 while self.weight.size(0) is 512. Could you please give me some advice on how to solve this problem? Thanks a lot for your help!

fine-tuning

Is there any method for fine-tuning included here?

Can't get the reported performance

Sorry, but I cannot reproduce the reported performance.
I am running the experiment SIB/minires_1_ifsl.yaml. My validation accuracy is below 60%, but the reported accuracy is 68.85%.

Looking forward to your reply! Thanks!

How to configure IFSL?

Hello, nice work. I am trying to run your code in MTL with "--config=mini_5_resnet_d". Are you sure the current config is right? How do I reach line 68 of mtl.py?

    if args.param.learner == "IFSL":
        self.base_learner = IFSL(args.way, args.shot, self.pretrain, **args.param.ifsl_params)

There seems to be a bug at line 68 of mtl.py: "No IFSL". Thanks.

How to plug and play

Sorry for my limited programming skills. I cannot get this code to run: I tried the code in all four folders, but every attempt failed. I also read the code, but I cannot tell which part implements feature intervention and which part implements class intervention, so I do not know how to plug IFSL into other models. Has anyone done this? Can you help me? Thank you very much.
