
Adversarial Continual Learning

This is the official PyTorch implementation of Adversarial Continual Learning (ECCV 2020).

Abstract

Continual learning aims to learn new tasks without forgetting previously learned ones. We hypothesize that representations learned to solve each task in a sequence have shared structure while containing some task-specific properties. We show that shared features are significantly less prone to forgetting and propose a novel hybrid continual learning framework that learns a disjoint representation for task-invariant and task-specific features required to solve a sequence of tasks. Our model combines architecture growth to prevent forgetting of task-specific skills and an experience replay approach to preserve shared skills. We demonstrate our hybrid approach is effective in avoiding forgetting and show it is superior to both architecture-based and memory-based approaches on class-incremental learning of a single dataset as well as a sequence of multiple datasets in image classification.
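The hybrid design summarized in the abstract can be illustrated with a small, dependency-free toy. This is a structural sketch under stated assumptions, not the repository's PyTorch model: HybridContinualModel and ReplayMemory are hypothetical names, features are plain Python lists, encoders and heads are arbitrary callables, and the reservoir-sampling memory is an assumption chosen for brevity rather than the memory policy used in the paper. It shows one shared encoder reused by all tasks, a private encoder and head added for each new task (architecture growth), a representation formed by concatenating shared and private features, and a bounded memory of past samples that could be replayed when updating the shared encoder.

```python
import random
from typing import Callable, Dict, List, Tuple

Vector = List[float]
Encoder = Callable[[Vector], Vector]
Head = Callable[[Vector], int]


class HybridContinualModel:
    """Hypothetical toy, not the repository's model.

    A single shared encoder is reused by every task, while each new task
    adds its own private encoder and head (architecture growth).
    """

    def __init__(self, shared: Encoder) -> None:
        self.shared = shared
        self.private: Dict[int, Encoder] = {}
        self.heads: Dict[int, Head] = {}

    def add_task(self, task_id: int, private: Encoder, head: Head) -> None:
        if task_id in self.private:
            raise ValueError("task {} already exists".format(task_id))
        # Growth: modules of earlier tasks are kept and never overwritten.
        self.private[task_id] = private
        self.heads[task_id] = head

    def features(self, x: Vector, task_id: int) -> Vector:
        # Disjoint representation: shared features followed by task-specific ones.
        return self.shared(x) + self.private[task_id](x)

    def predict(self, x: Vector, task_id: int) -> int:
        return self.heads[task_id](self.features(x, task_id))


class ReplayMemory:
    """Hypothetical bounded episodic memory filled by reservoir sampling.

    Samples drawn from it could be mixed into later batches so the shared
    encoder keeps seeing earlier tasks; the sampling policy is an assumption.
    """

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self.capacity = capacity
        self.items: List[Tuple[Vector, int, int]] = []
        self.seen = 0
        self.rng = random.Random(seed)

    def add(self, x: Vector, y: int, task_id: int) -> None:
        self.seen += 1
        if len(self.items) < self.capacity:
            self.items.append((x, y, task_id))
            return
        # Each of the `seen` samples ends up stored with probability capacity / seen.
        j = self.rng.randrange(self.seen)
        if j < self.capacity:
            self.items[j] = (x, y, task_id)

    def sample(self, k: int) -> List[Tuple[Vector, int, int]]:
        # Draw without replacement; never ask for more than is stored.
        return self.rng.sample(self.items, min(k, len(self.items)))
```

For example, with shared=lambda x: [sum(x)] and a task-0 private encoder lambda x: [max(x)], model.features([1.0, 2.0], 0) returns [3.0, 2.0]: the shared feature followed by the task-specific one. Adding task 1 leaves task 0's private encoder and head untouched.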

Authors:

Sayna Ebrahimi (UC Berkeley, FAIR), Franziska Meier (FAIR), Roberto Calandra (FAIR), Trevor Darrell (UC Berkeley), Marcus Rohrbach (FAIR)

Citation

If using this code, parts of it, or developments from it, please cite our paper:

@article{ebrahimi2020adversarial,
  title={Adversarial Continual Learning},
  author={Ebrahimi, Sayna and Meier, Franziska and Calandra, Roberto and Darrell, Trevor and Rohrbach, Marcus},
  journal={arXiv preprint arXiv:2003.09553},
  year={2020}
}

Prerequisites:

  • Linux-64
  • Python 3.6
  • PyTorch 1.3.1
  • CPU or NVIDIA GPU + CUDA 10 and cuDNN 7.5

Installation

  • Clone this repo:
mkdir ACL
cd ACL
git clone git@github.com:facebookresearch/Adversarial-Continual-Learning.git
cd Adversarial-Continual-Learning
  • Create a conda environment and install the required packages:
conda create -n <env> python=3.6
conda activate <env>
pip install -r requirements.txt
  • The following structure is expected in the main directory:
./src                     : main directory containing all scripts
./data                    : data directory
./src/checkpoints         : directory where results are saved
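A quick pre-flight check of the directory layout listed above can be sketched with pathlib. The helpers missing_dirs and ensure_checkpoint_dir are hypothetical and not part of the repository; this README does not say whether main.py creates src/checkpoints on its own, so creating it up front is an assumption that is harmless if the folder already exists.

```python
from pathlib import Path
from typing import List

# Directories this README expects under the repository's main directory.
EXPECTED_DIRS = ("src", "data")
CHECKPOINT_DIR = Path("src") / "checkpoints"


def missing_dirs(repo_root: str) -> List[str]:
    """Return the expected top-level directories that do not exist yet."""
    root = Path(repo_root)
    return [name for name in EXPECTED_DIRS if not (root / name).is_dir()]


def ensure_checkpoint_dir(repo_root: str) -> Path:
    """Create src/checkpoints if needed and return its path (safe to call twice)."""
    path = Path(repo_root) / CHECKPOINT_DIR
    path.mkdir(parents=True, exist_ok=True)
    return path
```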
For each dataset, run the following command from the src directory. The config file for each experiment contains the hyperparameters used in the paper:

Split MNIST (5 Tasks):

python main.py --config ./configs/config_mnist5.yml

Permuted MNIST (10 Tasks):

python main.py --config ./configs/config_pmnist.yml

Split CIFAR100 (20 Tasks):

python main.py --config ./configs/config_cifar100.yml

Split MiniImageNet (20 Tasks):

python main.py --config ./configs/config_miniimagenet.yml

Sequence of 5 Tasks (CIFAR10, MNIST, notMNIST, Fashion MNIST, SVHN):

python main.py --config ./configs/config_multidatasets.yml
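To reproduce all five experiments in one go, the commands above can be wrapped in a small driver. This is a hypothetical convenience script, not something shipped with the repository: only main.py, the --config flag, and the config file names come from this README, while the dictionary keys, build_command, and run_all are illustrative names. It defaults to a dry run that just prints each command; running for real executes the experiments one after another from the src directory.

```python
import subprocess
import sys
from typing import Dict, List

# Config files named in this README; the dictionary keys are hypothetical labels.
EXPERIMENTS: Dict[str, str] = {
    "split_mnist_5_tasks": "config_mnist5.yml",
    "permuted_mnist_10_tasks": "config_pmnist.yml",
    "split_cifar100_20_tasks": "config_cifar100.yml",
    "split_miniimagenet_20_tasks": "config_miniimagenet.yml",
    "multidatasets_5_tasks": "config_multidatasets.yml",
}


def build_command(config_file: str, python: str = "python") -> List[str]:
    """Command for one experiment; it must be run from the src directory."""
    return [python, "main.py", "--config", "./configs/" + config_file]


def run_all(src_dir: str = ".", dry_run: bool = True) -> List[List[str]]:
    """Run every experiment in sequence, or only print the commands if dry_run."""
    commands = [build_command(cfg, sys.executable) for cfg in EXPERIMENTS.values()]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            # cwd=src_dir follows the "run from the src directory" instruction;
            # check=True stops at the first experiment that exits with an error.
            subprocess.run(cmd, cwd=src_dir, check=True)
    return commands
```

Running the experiments sequentially keeps GPU memory usage predictable; launching them in parallel would require enough devices for all five runs.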

ACL with ResNet18 backbone

See here.

Datasets

miniImageNet data should be downloaded and pickled as a dictionary (data.pkl) with images and labels keys, and placed in a sub-folder named miniimagenet inside args.data_dir. The script used to split data.pkl into training and test sets is included in the data directory (data/).
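The expected pickle layout, and one way of splitting it, can be sketched with the standard library. This is a hypothetical illustration, not the split script shipped in data/: write_data_pkl, split_train_test, the test_fraction=0.2 default, and the per-class shuffling are all assumptions. Images can be any picklable objects here; only the dictionary keys (images, labels) and the miniimagenet/data.pkl location come from this README.

```python
import pickle
import random
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple


def write_data_pkl(images: Sequence[Any], labels: Sequence[int], data_dir: str) -> Path:
    """Store {'images': ..., 'labels': ...} as <data_dir>/miniimagenet/data.pkl."""
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    folder = Path(data_dir) / "miniimagenet"
    folder.mkdir(parents=True, exist_ok=True)
    path = folder / "data.pkl"
    with open(path, "wb") as f:
        pickle.dump({"images": list(images), "labels": list(labels)}, f)
    return path


def split_train_test(path: Path, test_fraction: float = 0.2,
                     seed: int = 0) -> Tuple[Dict[str, list], Dict[str, list]]:
    """Hypothetical per-class split; NOT the script included in data/."""
    with open(path, "rb") as f:
        data = pickle.load(f)
    # Group sample indices by class so every class appears in both splits.
    by_class: Dict[int, List[int]] = {}
    for idx, label in enumerate(data["labels"]):
        by_class.setdefault(label, []).append(idx)
    rng = random.Random(seed)
    train_idx: List[int] = []
    test_idx: List[int] = []
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        n_test = int(len(idxs) * test_fraction)
        test_idx.extend(idxs[:n_test])
        train_idx.extend(idxs[n_test:])

    def subset(ids: List[int]) -> Dict[str, list]:
        # Keep images and labels aligned by selecting the same indices from both.
        return {"images": [data["images"][i] for i in ids],
                "labels": [data["labels"][i] for i in ids]}

    return subset(train_idx), subset(test_idx)
```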

The notMNIST dataset used in our experiments is included in ./data/notMNIST.

Other datasets will be automatically downloaded and extracted to ./data if they do not exist.

Questions/Bugs

License

This source code is released under The MIT License found in the LICENSE file in the root directory of this source tree.

Acknowledgements

Our code structure is inspired by HAT.
