Universal Representation Learning from Multiple Domains and Cross-domain Few-shot Learning with Task-specific Adapters
This is the implementation of Universal Representation Learning from Multiple Domains for Few-shot Classification (ICCV'21) and Cross-domain Few-shot Learning with Task-specific Adapters (CVPR'22) introduced by Wei-Hong Li, Xialei Liu, and Hakan Bilen.
- March'22, Code for Cross-domain Few-shot Learning with Task-specific Adapters (CVPR'22) is now available! See TSA.
- Oct'21, Code and models for Universal Representation Learning from Multiple Domains for Few-shot Classification (ICCV'21) are now available!
This code requires the following:
- Python 3.6 or greater
- PyTorch 1.0 or greater
- TensorFlow 1.14 or greater
- Clone or download this repository.
- Configure Meta-Dataset:
  - Follow the "User instructions" in the Meta-Dataset repository for "Installation" and "Downloading and converting datasets".
  - Edit `./meta-dataset/data/reader.py` in the meta-dataset repository to change `dataset = dataset.batch(batch_size, drop_remainder=False)` to `dataset = dataset.batch(batch_size, drop_remainder=True)`. (The code can run with `drop_remainder=False`, but in our work we drop the remainder so that we never use a very small batch for some domains; see the toy illustration after this list.)
  - To test unseen-domain (out-of-domain) performance on additional datasets, i.e. MNIST, CIFAR-10 and CIFAR-100, follow the installation instructions in the CNAPs repository to get these datasets.
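To see what the `drop_remainder` flag changes, here is a toy, self-contained TensorFlow illustration (independent of the Meta-Dataset code):

```python
import tensorflow as tf

# With 10 examples and batch_size=4, drop_remainder=False yields batches of
# sizes 4, 4, 2, while drop_remainder=True keeps only the two full batches.
dataset = tf.data.Dataset.range(10)
print([b.numpy().tolist() for b in dataset.batch(4, drop_remainder=False)])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print([b.numpy().tolist() for b in dataset.batch(4, drop_remainder=True)])
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```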
- Before doing anything, first run the following commands:

  ulimit -n 50000
  export META_DATASET_ROOT=<root directory of the cloned or downloaded Meta-Dataset repository>
  export RECORDS=<the directory where tf-records of MetaDataset are stored>

  Note that the above commands need to be run every time you open a new command shell.
- Enter the root directory of this project, i.e. the directory where this project was cloned or downloaded.
Figure 1. URL - Universal Representation Learning.
- The easiest way is to download our pre-trained URL model and evaluate its features using our Pre-classifier Alignment (PA). To download the pretrained URL model, one can use `gdown` (installed by `pip install gdown`) and execute the following command in the root directory of this project:

  gdown https://drive.google.com/uc?id=1Dv8TX6iQ-BE2NMpfd0sQmH2q4mShmo1A && md5sum url.zip && unzip url.zip -d ./saved_results/ && rm url.zip

  This will download the URL model and place it in the `./saved_results` directory. One can evaluate this model with our PA (see the Meta-Testing step); a sketch for inspecting the downloaded checkpoint follows this list.
- Alternatively, one can train the model from scratch: 1) train 8 single domain learning networks; 2) train the universal feature extractor as follows.
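For a quick sanity check of the downloaded model, a minimal sketch like the following can be used. The checkpoint filename and the `'state_dict'` key are assumptions about the archive layout, not guarantees; check the extracted `./saved_results` directory for the actual names.

```python
import torch

# Hypothetical path inside the extracted archive; adjust to the actual layout.
ckpt_path = './saved_results/url/model_best.pth.tar'

# Load on CPU so this works without a GPU.
ckpt = torch.load(ckpt_path, map_location='cpu')

# Many PyTorch checkpoints store weights under a 'state_dict' key; fall back
# to the object itself if this one does not.
state_dict = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
print(f'{len(state_dict)} parameter tensors, e.g. {next(iter(state_dict))}')
```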
- The easiest way is to download our pre-trained models and use them to obtain a universal set of features directly. To download the single domain learning networks, execute the following command in the root directory of this project:

  gdown https://drive.google.com/uc?id=1MvUcvQ8OQtoOk1MIiJmK6_G8p4h8cbY9 && md5sum sdl.zip && unzip sdl.zip -d ./saved_results/ && rm sdl.zip

  This will download all single domain learning models and place them in the `./saved_results` directory of this project.
- Alternatively, instead of using the pretrained models, one can train the models from scratch. To train 8 single domain learning networks, run:

  ./scripts/train_resnet18_sdl.sh

  To learn the universal feature extractor by distilling the knowledge from the pre-trained single domain learning networks, run:

  ./scripts/train_resnet18_url.sh
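To give a feel for the distillation stage, here is a simplified sketch of a multi-teacher distillation step under assumed names (`student`, `teachers`, `adaptors`). URL's full objective also aligns classifier predictions and uses small learned projections between feature spaces; this sketch keeps only a plain feature-matching term, so treat it as the gist rather than the exact loss in the paper or scripts.

```python
import torch
import torch.nn.functional as F

def distill_step(student, teachers, adaptors, batches, opt):
    """One simplified multi-teacher distillation step.

    `batches` maps a domain name to a batch of images from that domain;
    `teachers` maps the same names to frozen single domain networks, and
    `adaptors` maps them to small heads projecting student features into
    each teacher's feature space.
    """
    opt.zero_grad()
    loss = torch.zeros(())
    for domain, x in batches.items():
        with torch.no_grad():
            t_feat = teachers[domain](x)        # frozen teacher features
        s_feat = adaptors[domain](student(x))   # projected student features
        # Pull the projected student features towards the teacher's.
        loss = loss + F.mse_loss(s_feat, t_feat)
    loss.backward()
    opt.step()
    return loss.item()
```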
Figure 2. PA - Pre-classifier Alignment for Adapting Features in Meta-test.
This step runs our Pre-classifier Alignment (PA) procedure for each task to adapt the features to a discriminative space and builds a Nearest Centroid Classifier (NCC) on the support set to classify the query samples. To run it:
./scripts/test_resnet18_pa.sh
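For intuition, below is a simplified sketch of PA: a single linear transformation, initialized to the identity, is optimized on the support set with an NCC loss under cosine similarity, and queries are then classified with NCC in the transformed space. The iteration count, optimizer, and softmax temperature here are illustrative assumptions; the script above uses the paper's actual settings.

```python
import torch
import torch.nn.functional as F

def ncc_log_probs(feats, centroids, temp=10.0):
    # Cosine-similarity NCC scores turned into log-probabilities;
    # the temperature value is an assumption for this sketch.
    sims = F.normalize(feats, dim=-1) @ F.normalize(centroids, dim=-1).t()
    return F.log_softmax(sims * temp, dim=-1)

def pa_adapt(support_feats, support_labels, n_iter=40, lr=0.1):
    """Learn a linear map on top of frozen features (Pre-classifier Alignment)."""
    d = support_feats.size(1)
    A = torch.eye(d, requires_grad=True)   # identity-initialized transform
    opt = torch.optim.Adadelta([A], lr=lr)
    n_way = int(support_labels.max()) + 1
    for _ in range(n_iter):
        z = support_feats @ A
        # Class centroids recomputed in the transformed space each step.
        centroids = torch.stack(
            [z[support_labels == c].mean(0) for c in range(n_way)])
        loss = F.nll_loss(ncc_log_probs(z, centroids), support_labels)
        opt.zero_grad(); loss.backward(); opt.step()
    return A.detach()
```

Queries are then classified by `ncc_log_probs(query_feats @ A, centroids).argmax(-1)`, with the centroids recomputed from the adapted support features.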
Figure 3. Cross-domain Few-shot Learning with Task-specific Adapters (TSA).
We provide code for attaching task-specific adapters (TSA) to a single universal network learned from meta-training and learning the task-specific adapters on the support set. One can download our pre-trained URL model and evaluate its features, adapted by residual adapters in matrix form together with pre-classifier alignment; to do so, run:
./scripts/test_resnet18_tsa.sh
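As a rough picture of what "residual adapters in matrix form" means, the sketch below attaches a 1x1 convolution in parallel to a frozen convolution, so only the adapter is optimized per task. The zero initialization (so the adapted layer starts out identical to the frozen one) and the class name are illustrative assumptions; see the TSA code for the exact parameterization.

```python
import torch.nn as nn

class MatrixResidualAdapter(nn.Module):
    """A frozen conv layer with a parallel task-specific 1x1 conv adapter:
    y = conv(x) + alpha(x), where only alpha is trained per task."""
    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        self.conv = conv
        for p in self.conv.parameters():
            p.requires_grad_(False)   # the universal backbone stays frozen
        self.alpha = nn.Conv2d(conv.in_channels, conv.out_channels,
                               kernel_size=1, stride=conv.stride, bias=False)
        nn.init.zeros_(self.alpha.weight)  # no-op at the start of adaptation

    def forward(self, x):
        return self.conv(x) + self.alpha(x)
```

At meta-test time the adapters (and the PA transform) are optimized on the support set with the NCC loss, while all pretrained backbone weights stay frozen.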
Below are the results extracted from our papers. The results will vary from run to run by a percent or two, up or down, because the Meta-Dataset reader generates different tasks each run and there is randomness in training the networks and in the TSA and PA optimization. Note that the results are updated with the up-to-date evaluation from Meta-Dataset. Make sure that you use the up-to-date code from the Meta-Dataset repository to convert the dataset, and set `shuffle_buffer_size=1000` as mentioned in google-research/meta-dataset#54.
Models trained on all datasets
Test Datasets | TSA (Ours) | URL (Ours) | MDL | Best SDL | tri-M [8] | FLUTE [7] | URT [6] | SUR [5] | Transductive CNAPS [4] | Simple CNAPS [3] | CNAPS [2] |
---|---|---|---|---|---|---|---|---|---|---|---|
Avg rank | 1.5 | 2.7 | 7.1 | 6.7 | 5.5 | 5.1 | 6.7 | 6.9 | 5.7 | 7.2 | - |
ImageNet | 57.4±1.1 | 57.5±1.1 | 52.9±1.2 | 54.3±1.1 | 58.6±1.0 | 51.8±1.1 | 55.0±1.1 | 54.5±1.1 | 57.9±1.1 | 56.5±1.1 | 50.8±1.1 |
Omniglot | 95.0±0.4 | 94.5±0.4 | 93.7±0.5 | 93.8±0.5 | 92.0±0.6 | 93.2±0.5 | 93.3±0.5 | 93.0±0.5 | 94.3±0.4 | 91.9±0.6 | 91.7±0.5 |
Aircraft | 89.3±0.4 | 88.6±0.5 | 84.9±0.5 | 84.5±0.5 | 82.8±0.7 | 87.2±0.5 | 84.5±0.6 | 84.3±0.5 | 84.7±0.5 | 83.8±0.6 | 83.7±0.6 |
Birds | 81.4±0.7 | 80.5±0.7 | 79.2±0.8 | 70.6±0.9 | 75.3±0.8 | 79.2±0.8 | 75.8±0.8 | 70.4±1.1 | 78.8±0.7 | 76.1±0.9 | 73.6±0.9 |
Textures | 76.7±0.7 | 76.2±0.7 | 70.9±0.8 | 72.1±0.7 | 71.2±0.8 | 68.8±0.8 | 70.6±0.7 | 70.5±0.7 | 66.2±0.8 | 70.0±0.8 | 59.5±0.7 |
Quick Draw | 82.0±0.6 | 81.9±0.6 | 81.7±0.6 | 82.6±0.6 | 77.3±0.7 | 79.5±0.7 | 82.1±0.6 | 81.6±0.6 | 77.9±0.6 | 78.3±0.7 | 74.7±0.8 |
Fungi | 67.4±1.0 | 68.8±0.9 | 63.2±1.1 | 65.9±1.0 | 48.5±1.0 | 58.1±1.1 | 63.7±1.0 | 65.0±1.0 | 48.9±1.2 | 49.1±1.2 | 50.2±1.1 |
VGG Flower | 92.2±0.5 | 92.1±0.5 | 88.7±0.6 | 86.7±0.6 | 90.5±0.5 | 91.6±0.6 | 88.3±0.6 | 82.2±0.8 | 92.3±0.4 | 91.3±0.6 | 88.9±0.5 |
Traffic Sign | 83.5±0.9 | 63.3±1.2 | 49.2±1.0 | 47.1±1.1 | 63.0±1.0 | 58.4±1.1 | 50.1±1.1 | 49.8±1.1 | 59.7±1.1 | 59.2±1.0 | 56.5±1.1 |
MSCOCO | 55.8±1.1 | 54.0±1.0 | 47.3±1.1 | 49.7±1.0 | 52.8±1.1 | 50.0±1.0 | 48.9±1.1 | 49.4±1.1 | 42.5±1.1 | 42.4±1.1 | 39.4±1.0 |
MNIST | 96.7±0.4 | 94.5±0.5 | 94.2±0.4 | 91.0±0.5 | 96.2±0.3 | 95.6±0.5 | 90.5±0.4 | 94.9±0.4 | 94.7±0.3 | 94.3±0.4 | - |
CIFAR-10 | 80.6±0.8 | 71.9±0.7 | 63.2±0.8 | 65.4±0.8 | 75.4±0.8 | 78.6±0.7 | 65.1±0.8 | 64.2±0.9 | 73.6±0.7 | 72.0±0.8 | - |
CIFAR-100 | 69.6±1.0 | 62.6±1.0 | 54.7±1.1 | 56.2±1.0 | 62.0±1.0 | 67.1±1.0 | 57.2±1.0 | 57.1±1.1 | 61.8±1.0 | 60.9±1.1 | - |
[1] Eleni Triantafillou, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, Utku Evci, Kelvin Xu, Ross Goroshin, Carles Gelada, Kevin Swersky, Pierre-Antoine Manzagol, Hugo Larochelle; Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples; ICLR 2020.
[2] James Requeima, Jonathan Gordon, John Bronskill, Sebastian Nowozin, Richard E. Turner; Fast and Flexible Multi-Task Classification Using Conditional Neural Adaptive Processes; NeurIPS 2019.
[3] Peyman Bateni, Raghav Goyal, Vaden Masrani, Frank Wood, Leonid Sigal; Improved Few-Shot Visual Classification; CVPR 2020.
[4] Peyman Bateni, Jarred Barber, Jan-Willem van de Meent, Frank Wood; Enhancing Few-Shot Image Classification with Unlabelled Examples; WACV 2022.
[5] Nikita Dvornik, Cordelia Schmid, Julien Mairal; Selecting Relevant Features from a Multi-domain Representation for Few-shot Classification; ECCV 2020.
[6] Lu Liu, William Hamilton, Guodong Long, Jing Jiang, Hugo Larochelle; Universal Representation Transformer Layer for Few-Shot Image Classification; ICLR 2021.
[7] Eleni Triantafillou, Hugo Larochelle, Richard Zemel, Vincent Dumoulin; Learning a Universal Template for Few-shot Dataset Generalization; ICML 2021.
[8] Yanbin Liu, Juho Lee, Linchao Zhu, Ling Chen, Humphrey Shi, Yi Yang; A Multi-Mode Modulator for Multi-Domain Few-Shot Classification; ICCV 2021.
To train a vanilla multi-domain learning network (MDL) on Meta-Dataset, run:
./scripts/train_resnet18_mdl.sh
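For orientation, here is a simplified sketch of what vanilla multi-domain training does: a shared backbone with one classifier head per domain, trained on the summed cross-entropy over one batch per domain. The names and the exact sampling scheme are assumptions; the script above is the authoritative recipe.

```python
import torch
import torch.nn.functional as F

def mdl_step(backbone, heads, batches, opt):
    """One simplified MDL step: `heads` maps each domain name to its own
    linear classifier on top of the shared `backbone`; `batches` maps the
    same names to an (images, labels) batch from that domain."""
    opt.zero_grad()
    loss = torch.zeros(())
    for domain, (x, y) in batches.items():
        logits = heads[domain](backbone(x))   # domain-specific classifier
        loss = loss + F.cross_entropy(logits, y)
    loss.backward()
    opt.step()
    return loss.item()
```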
One can use other classifiers for meta-testing, e.g. use `--test.loss-opt` to select the nearest centroid classifier (`ncc`, default), support vector machine (`svm`), logistic regression (`lr`), Mahalanobis distance from Simple CNAPS (`scm`), or k-nearest neighbor (`knn`); use `--test.feature-norm` to normalize features (`l2`) or not (`none`) for `svm` and `lr`; use `--test.distance` to specify the feature similarity function (`l2` or `cos`) for NCC.
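As a reference for what the default `ncc` option computes, here is a simplified sketch of a nearest centroid classifier supporting both distances; the function and argument names are illustrative, not the repository's API.

```python
import torch
import torch.nn.functional as F

def ncc_predict(support_feats, support_labels, query_feats, distance='cos'):
    """Assign each query to the nearest class centroid of the support set."""
    n_way = int(support_labels.max()) + 1
    centroids = torch.stack(
        [support_feats[support_labels == c].mean(0) for c in range(n_way)])
    if distance == 'cos':
        # Cosine similarity between normalized queries and centroids.
        scores = F.normalize(query_feats, dim=-1) @ F.normalize(centroids, dim=-1).t()
    else:
        # Negative Euclidean distance, so larger is better in both branches.
        scores = -torch.cdist(query_feats, centroids)
    return scores.argmax(dim=-1)
```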
To evaluate the feature extractor with NCC and cosine similarity, run:
python test_extractor.py --test.loss-opt ncc --test.feature-norm none --test.distance cos --model.name=url --model.dir <directory of url>
One can evaluate the feature extractor in meta-testing under the five-shot or five-way-one-shot setting by setting `--test.type` to `5shot` or `1shot`, respectively.
To test the feature extractor for varying-way-five-shot on the test splits of all datasets, run:
python test_extractor.py --test.type 5shot --test.loss-opt ncc --test.feature-norm none --test.distance cos --model.name=url --model.dir <directory of url>
We thank the authors of Meta-Dataset, SUR, and Residual Adapters for their source code.
If you use this code, please cite our papers:
@article{li2022Universal,
author = {Li, Wei-Hong and Liu, Xialei and Bilen, Hakan},
title = {Universal Representations: A Unified Look at Multiple Task and Domain Learning},
journal = {arXiv preprint arXiv:2204.02744},
year = {2022}
}
@inproceedings{li2022TaskSpecificAdapter,
author = {Li, Wei-Hong and Liu, Xialei and Bilen, Hakan},
title = {Cross-domain Few-shot Learning with Task-specific Adapters},
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2022}
}
@inproceedings{li2021Universal,
author = {Li, Wei-Hong and Liu, Xialei and Bilen, Hakan},
title = {Universal Representation Learning From Multiple Domains for Few-Shot Classification},
booktitle = {IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2021},
pages = {9526-9535}
}
@inproceedings{li2020knowledge,
author = {Li, Wei-Hong and Bilen, Hakan},
title = {Knowledge distillation for multi-task learning},
booktitle = {European Conference on Computer Vision (ECCV) Workshop},
year = {2020},
xcode = {https://github.com/VICO-UoE/KD4MTL}
}