Universal Representation Learning from Multiple Domains and Cross-domain Few-shot Learning with Task-specific Adapters

This is the implementation of Universal Representation Learning from Multiple Domains for Few-shot Classification (ICCV'21) and Cross-domain Few-shot Learning with Task-specific Adapters (CVPR'22) introduced by Wei-Hong Li, Xialei Liu, and Hakan Bilen.

Updates

Dependencies

This code requires the following:

  • Python 3.6 or greater
  • PyTorch 1.0 or greater
  • TensorFlow 1.14 or greater

Installation

  • Clone or download this repository.
  • Configure Meta-Dataset:
    • Follow the "User instructions" in the Meta-Dataset repository for "Installation" and "Downloading and converting datasets".
    • Edit ./meta-dataset/data/reader.py in the meta-dataset repository to change dataset = dataset.batch(batch_size, drop_remainder=False) to dataset = dataset.batch(batch_size, drop_remainder=True). (The code can run with drop_remainder=False, but in our work we drop the remainder so that we do not use very small batches for some domains; see the snippet after this list.)
    • To test unseen domain (out-of-domain) performance on additional datasets, i.e. MNIST, CIFAR-10 and CIFAR-100, follow the installation instructions in the CNAPs repository to get these datasets.
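
For reference, the one-line change in ./meta-dataset/data/reader.py looks like this:

    # ./meta-dataset/data/reader.py (Meta-Dataset repository)
    # before:
    dataset = dataset.batch(batch_size, drop_remainder=False)
    # after: drop incomplete batches so that no domain trains on a very small batch
    dataset = dataset.batch(batch_size, drop_remainder=True)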

Initialization

  1. Before doing anything else, run the following commands.

    ulimit -n 50000
    export META_DATASET_ROOT=<root directory of the cloned or downloaded Meta-Dataset repository>
    export RECORDS=<the directory where tf-records of MetaDataset are stored>
    

    Note that the above commands need to be run every time you open a new command shell.

  2. Enter the root directory of this project, i.e. the directory where this project was cloned or downloaded.

Universal Representation Learning from Multiple Domains for Few-shot Classification

Figure 1. URL - Universal Representation Learning.

Train the Universal Representation Learning Network (URL)

  1. The easiest way is to download our pre-trained URL model and evaluate its features using our Pre-classifier Alignment (PA). To download the pretrained URL model, one can use gdown (install it with pip install gdown) and execute the following command in the root directory of this project:

    gdown https://drive.google.com/uc?id=1Dv8TX6iQ-BE2NMpfd0sQmH2q4mShmo1A && md5sum url.zip && unzip url.zip -d ./saved_results/ && rm url.zip
    
    

    This will download the URL model and place it in the ./saved_results directory. One can evaluate this model with our PA (see the Meta-Testing step below).

  2. Alternatively, one can train the model from scratch: 1) train 8 single domain learning networks; 2) train the universal feature extractor, as follows.

Train Single Domain Learning Networks

  1. The easiest way is to download our pre-trained models and use them to obtain a universal set of features directly. To download single domain learning networks, execute the following command in the root directory of this project:

    gdown https://drive.google.com/uc?id=1MvUcvQ8OQtoOk1MIiJmK6_G8p4h8cbY9 && md5sum sdl.zip && unzip sdl.zip -d ./saved_results/ && rm sdl.zip
    

    This will download all single domain learning models and place them in the ./saved_results directory of this project.

  2. Alternatively, one can train the models from scratch. To train the 8 single domain learning networks, run:

    ./scripts/train_resnet18_sdl.sh
    

Train the Universal Feature Extractor

To learn the universal feature extractor by distilling the knowledge from pre-trained single domain learning networks, run:

./scripts/train_resnet18_url.sh
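
For intuition, the sketch below shows the core of the distillation step: the universal network is trained to match the features of each frozen single-domain teacher through a small per-domain adaptor. The adaptor form, loss, and optimizer here are illustrative assumptions; the full objective is implemented by the script above.

    import torch
    import torch.nn.functional as F

    def distill_step(universal_net, domain_nets, adaptors, batches, optimizer):
        """One hypothetical distillation step over a batch from each domain.

        universal_net: the student being trained; domain_nets: frozen teachers;
        adaptors: small per-domain heads mapping student features into each
        teacher's feature space. Sketch only, not the official objective.
        """
        optimizer.zero_grad()
        loss = 0.0
        for domain, x in batches.items():
            with torch.no_grad():
                teacher_feat = domain_nets[domain](x)          # frozen teacher features
            student_feat = adaptors[domain](universal_net(x))  # map to teacher space
            # penalize feature misalignment between student and teacher
            loss = loss + (1 - F.cosine_similarity(student_feat, teacher_feat).mean())
        loss.backward()
        optimizer.step()
        return float(loss)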

Meta-Testing with Pre-classifier Alignment (PA)

Figure 2. PA - Pre-classifier Alignment for Adapting Features in Meta-test.

This step runs our Pre-classifier Alignment (PA) procedure per task to adapt the features to a discriminative space and builds a Nearest Centroid Classifier (NCC) on the support set to classify query samples. To run it:

./scripts/test_resnet18_pa.sh
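
Conceptually, PA learns one linear transformation per task on top of the frozen universal features by minimizing the NCC loss on the support set; queries are then classified by their nearest centroid. Below is a minimal sketch; the optimizer and hyperparameters are illustrative, not the exact settings of the official script.

    import torch
    import torch.nn.functional as F

    def ncc_logits(feats, prototypes):
        # cosine-similarity logits against class centroids
        return F.normalize(feats, dim=-1) @ F.normalize(prototypes, dim=-1).t()

    def pa_adapt(support_feats, support_labels, n_iters=40, lr=0.1):
        """Fit a linear pre-classifier transformation on the support set by
        minimizing the NCC cross-entropy. Assumes labels are 0..C-1. Sketch
        only; the official implementation is in this repository."""
        d = support_feats.size(-1)
        A = torch.eye(d, requires_grad=True)       # identity-initialized map
        opt = torch.optim.Adadelta([A], lr=lr)
        for _ in range(n_iters):
            opt.zero_grad()
            z = support_feats @ A                  # transformed support features
            prototypes = torch.stack([z[support_labels == c].mean(0)
                                      for c in support_labels.unique()])
            loss = F.cross_entropy(ncc_logits(z, prototypes), support_labels)
            loss.backward()
            opt.step()
        return A.detach()                          # apply to query features too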

Cross-domain Few-shot Learning with Task-specific Adapters

Figure 3. Cross-domain Few-shot Learning with Task-specific Adapters (TSA).

We provide code for attaching task-specific adapters (TSA) to a single universal network learned during meta-training and learning the adapters on the support set. One can download our pre-trained URL model and evaluate its features, adapted by residual adapters in matrix form and pre-classifier alignment, by running:

./scripts/test_resnet18_tsa.sh
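
The adapters themselves are lightweight: residual branches in matrix (1x1 convolution) form attached to the frozen backbone's convolutional layers and trained on the support set only, together with pre-classifier alignment as above. A minimal sketch of one such attachment follows; the exact attachment points and initialization in the official code may differ.

    import torch.nn as nn

    class ConvWithTSA(nn.Module):
        """Wraps a frozen conv layer with a task-specific residual adapter in
        matrix (1x1-conv) form: y = conv(x) + alpha(x). Sketch only."""
        def __init__(self, conv: nn.Conv2d):
            super().__init__()
            self.conv = conv
            for p in self.conv.parameters():       # backbone weights stay frozen
                p.requires_grad = False
            self.alpha = nn.Conv2d(conv.in_channels, conv.out_channels,
                                   kernel_size=1, stride=conv.stride, bias=False)
            nn.init.zeros_(self.alpha.weight)      # adapter starts as a no-op

        def forward(self, x):
            # only self.alpha is trained, per task, on the support set
            return self.conv(x) + self.alpha(x)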

Expected Results

Below are the results extracted from our papers. The results will vary from run to run by a percent or two, up or down, because the Meta-Dataset reader generates different tasks each run and there is randomness in training the networks and in the TSA and PA optimization. Note that the results are updated with the up-to-date evaluation from Meta-Dataset. Make sure that you use the up-to-date code from the Meta-Dataset repository to convert the datasets and set shuffle_buffer_size=1000 as mentioned in google-research/meta-dataset#54.

Models trained on all datasets

| Test Datasets | TSA (Ours) | URL (Ours) | MDL | Best SDL | tri-M [8] | FLUTE [7] | URT [6] | SUR [5] | Transductive CNAPS [4] | Simple CNAPS [3] | CNAPS [2] |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Avg rank | 1.5 | 2.7 | 7.1 | 6.7 | 5.5 | 5.1 | 6.7 | 6.9 | 5.7 | 7.2 | - |
| ImageNet | 57.4±1.1 | 57.5±1.1 | 52.9±1.2 | 54.3±1.1 | 58.6±1.0 | 51.8±1.1 | 55.0±1.1 | 54.5±1.1 | 57.9±1.1 | 56.5±1.1 | 50.8±1.1 |
| Omniglot | 95.0±0.4 | 94.5±0.4 | 93.7±0.5 | 93.8±0.5 | 92.0±0.6 | 93.2±0.5 | 93.3±0.5 | 93.0±0.5 | 94.3±0.4 | 91.9±0.6 | 91.7±0.5 |
| Aircraft | 89.3±0.4 | 88.6±0.5 | 84.9±0.5 | 84.5±0.5 | 82.8±0.7 | 87.2±0.5 | 84.5±0.6 | 84.3±0.5 | 84.7±0.5 | 83.8±0.6 | 83.7±0.6 |
| Birds | 81.4±0.7 | 80.5±0.7 | 79.2±0.8 | 70.6±0.9 | 75.3±0.8 | 79.2±0.8 | 75.8±0.8 | 70.4±1.1 | 78.8±0.7 | 76.1±0.9 | 73.6±0.9 |
| Textures | 76.7±0.7 | 76.2±0.7 | 70.9±0.8 | 72.1±0.7 | 71.2±0.8 | 68.8±0.8 | 70.6±0.7 | 70.5±0.7 | 66.2±0.8 | 70.0±0.8 | 59.5±0.7 |
| Quick Draw | 82.0±0.6 | 81.9±0.6 | 81.7±0.6 | 82.6±0.6 | 77.3±0.7 | 79.5±0.7 | 82.1±0.6 | 81.6±0.6 | 77.9±0.6 | 78.3±0.7 | 74.7±0.8 |
| Fungi | 67.4±1.0 | 68.8±0.9 | 63.2±1.1 | 65.9±1.0 | 48.5±1.0 | 58.1±1.1 | 63.7±1.0 | 65.0±1.0 | 48.9±1.2 | 49.1±1.2 | 50.2±1.1 |
| VGG Flower | 92.2±0.5 | 92.1±0.5 | 88.7±0.6 | 86.7±0.6 | 90.5±0.5 | 91.6±0.6 | 88.3±0.6 | 82.2±0.8 | 92.3±0.4 | 91.3±0.6 | 88.9±0.5 |
| Traffic Sign | 83.5±0.9 | 63.3±1.2 | 49.2±1.0 | 47.1±1.1 | 63.0±1.0 | 58.4±1.1 | 50.1±1.1 | 49.8±1.1 | 59.7±1.1 | 59.2±1.0 | 56.5±1.1 |
| MSCOCO | 55.8±1.1 | 54.0±1.0 | 47.3±1.1 | 49.7±1.0 | 52.8±1.1 | 50.0±1.0 | 48.9±1.1 | 49.4±1.1 | 42.5±1.1 | 42.4±1.1 | 39.4±1.0 |
| MNIST | 96.7±0.4 | 94.5±0.5 | 94.2±0.4 | 91.0±0.5 | 96.2±0.3 | 95.6±0.5 | 90.5±0.4 | 94.9±0.4 | 94.7±0.3 | 94.3±0.4 | - |
| CIFAR-10 | 80.6±0.8 | 71.9±0.7 | 63.2±0.8 | 65.4±0.8 | 75.4±0.8 | 78.6±0.7 | 65.1±0.8 | 64.2±0.9 | 73.6±0.7 | 72.0±0.8 | - |
| CIFAR-100 | 69.6±1.0 | 62.6±1.0 | 54.7±1.1 | 56.2±1.0 | 62.0±1.0 | 67.1±1.0 | 57.2±1.0 | 57.1±1.1 | 61.8±1.0 | 60.9±1.1 | - |

[1] Eleni Triantafillou, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, Utku Evci, Kelvin Xu, Ross Goroshin, Carles Gelada, Kevin Swersky, Pierre-Antoine Manzagol, Hugo Larochelle; Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples; ICLR 2020.

[2] James Requeima, Jonathan Gordon, John Bronskill, Sebastian Nowozin, Richard E. Turner; Fast and Flexible Multi-Task Classification Using Conditional Neural Adaptive Processes; NeurIPS 2019.

[3] Peyman Bateni, Raghav Goyal, Vaden Masrani, Frank Wood, Leonid Sigal; Improved Few-Shot Visual Classification; CVPR 2020.

[4] Peyman Bateni, Jarred Barber, Jan-Willem van de Meent, Frank Wood; Enhancing Few-Shot Image Classification with Unlabelled Examples; WACV 2022.

[5] Nikita Dvornik, Cordelia Schmid, Julien Mairal; Selecting Relevant Features from a Multi-domain Representation for Few-shot Classification; ECCV 2020.

[6] Lu Liu, William Hamilton, Guodong Long, Jing Jiang, Hugo Larochelle; Universal Representation Transformer Layer for Few-Shot Image Classification; ICLR 2021.

[7] Eleni Triantafillou, Hugo Larochelle, Richard Zemel, Vincent Dumoulin; Learning a Universal Template for Few-shot Dataset Generalization; ICML 2021.

[8] Yanbin Liu, Juho Lee, Linchao Zhu, Ling Chen, Humphrey Shi, Yi Yang; A Multi-Mode Modulator for Multi-Domain Few-Shot Classification; ICCV 2021.

Other Usage

Train a Vanilla Multi-domain Learning Network (optional)

To train a vanilla multi-domain learning network (MDL) on Meta-Dataset, run:

./scripts/train_resnet18_mdl.sh

Other Classifiers for Meta-Testing (optional)

One can use other classifiers for meta-testing: use --test.loss-opt to select the nearest centroid classifier (ncc, default), support vector machine (svm), logistic regression (lr), Mahalanobis distance from Simple CNAPS (scm), or k-nearest neighbor (knn); use --test.feature-norm to normalize features (l2) or not (none) for svm and lr; and use --test.distance to specify the feature similarity function (l2 or cos) for NCC.

To evaluate the feature extractor with NCC and cosine similarity, run:

python test_extractor.py --test.loss-opt ncc --test.feature-norm none --test.distance cos --model.name=url --model.dir <directory of url> 

Five-shot and Five-way-one-shot Meta-test (optional)

One can evaluate the feature extractor in meta-testing for five-shot or five-way-one-shot setting by setting --test.type as '5shot' or '1shot', respectively.

To test the feature extractor in the varying-way five-shot setting on the test splits of all datasets, run:

python test_extractor.py --test.type 5shot --test.loss-opt ncc --test.feature-norm none --test.distance cos --model.name=url --model.dir <directory of url>

Acknowledgements

We thank the authors of Meta-Dataset, SUR, and Residual Adapters for their source code.

Citation

If you use this code, please cite our papers:

@article{li2022Universal,
    author    = {Li, Wei-Hong and Liu, Xialei and Bilen, Hakan},
    title     = {Universal Representations: A Unified Look at Multiple Task and Domain Learning},
    journal   = {arXiv preprint arXiv:2204.02744},
    year      = {2022}
}

@inproceedings{li2022TaskSpecificAdapter,
    author    = {Li, Wei-Hong and Liu, Xialei and Bilen, Hakan},
    title     = {Cross-domain Few-shot Learning with Task-specific Adapters},
    booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022}
}

@inproceedings{li2021Universal,
    author    = {Li, Wei-Hong and Liu, Xialei and Bilen, Hakan},
    title     = {Universal Representation Learning From Multiple Domains for Few-Shot Classification},
    booktitle = {IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {9526-9535}
}

@inproceedings{li2020knowledge,
    author    = {Li, Wei-Hong and Bilen, Hakan},
    title     = {Knowledge distillation for multi-task learning},
    booktitle = {European Conference on Computer Vision (ECCV) Workshop},
    year      = {2020},
    xcode     = {https://github.com/VICO-UoE/KD4MTL}
}
