
ielt's Introduction

Fine-Grained Visual Classification via Internal Ensemble Learning Transformer

Official PyTorch implementation of:

Article: Fine-Grained Visual Classification via Internal Ensemble Learning Transformer

Published in: IEEE Transactions on Multimedia (Early Access)

If this article is helpful to your work, please cite the following entry:

@ARTICLE{10042971,
  author={Xu, Qin and Wang, Jiahui and Jiang, Bo and Luo, Bin},
  journal={IEEE Transactions on Multimedia}, 
  title={Fine-Grained Visual Classification Via Internal Ensemble Learning Transformer}, 
  year={2023},
  volume={},
  number={},
  pages={1-14},
  doi={10.1109/TMM.2023.3244340}}

or:

Q. Xu, J. Wang, B. Jiang and B. Luo, "Fine-Grained Visual Classification Via Internal Ensemble Learning Transformer," in IEEE Transactions on Multimedia, doi: 10.1109/TMM.2023.3244340.

Abstract

Recently, vision transformers (ViTs) have been investigated in fine-grained visual recognition (FGVC) and are now considered state of the art. However, most ViT-based works ignore the different learning performances of the heads in the multihead self-attention (MHSA) mechanism and its layers. To address these issues, in this paper, we propose a novel internal ensemble learning transformer (IELT) for FGVC. The proposed IELT involves three main modules: multi-head voting (MHV) module, cross-layer refinement (CLR) module, and dynamic selection (DS) module. To solve the problem of the inconsistent performances of multiple heads, we propose the MHV module, which considers all of the heads in each layer as weak learners and votes for tokens of discriminative regions as cross-layer feature based on the attention maps and spatial relationships. To effectively mine the cross-layer feature and suppress the noise, the CLR module is proposed, where the refined feature is extracted and the assist logits operation is developed for the final prediction. In addition, a newly designed DS module adjusts the token selection number at each layer by weighting their contributions of the refined feature. In this way, the idea of ensemble learning is combined with the ViT to improve fine-grained feature representation. The experiments demonstrate that our method achieves competitive results compared with the state of the art on five popular FGVC datasets.
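The voting idea behind the MHV module, and the budget idea behind the DS module, can be illustrated with a toy NumPy sketch. It is not the authors' implementation. The use of per-head CLS-to-patch attention scores, the top-k vote per head, the tie-breaking, and the proportional token allocation in `allocate_tokens` are all simplifying assumptions; the spatial-relationship step of MHV and the CLR module are left out.

```python
import numpy as np


def multi_head_vote(cls_attn: np.ndarray, k: int) -> np.ndarray:
    """Toy multi-head voting over patch tokens (assumption, not IELT code).

    cls_attn has shape (num_heads, num_patches): each head's attention from
    the CLS token to every patch token in one layer. Each head acts as a weak
    learner and votes for its k highest-attention patches; the k patches with
    the most votes are returned. The paper's spatial step is omitted.
    """
    num_heads, num_patches = cls_attn.shape
    votes = np.zeros(num_patches, dtype=np.int64)
    for head in range(num_heads):
        top = np.argsort(cls_attn[head])[::-1][:k]  # this head's k largest weights
        votes[top] += 1
    # Stable sort on negated votes: ties go to the lower patch index.
    return np.argsort(-votes, kind="stable")[:k]


def allocate_tokens(contributions, total: int) -> list:
    """Hypothetical stand-in for dynamic selection (DS).

    Splits a total token budget across layers in proportion to each layer's
    contribution weight, so layers that contribute more select more tokens.
    """
    w = np.asarray(contributions, dtype=np.float64)
    w = w / w.sum()
    exact = w * total
    counts = np.floor(exact).astype(int)
    # Give the leftover tokens to the layers with the largest remainders.
    remainder = total - counts.sum()
    order = np.argsort(-(exact - counts), kind="stable")
    counts[order[:remainder]] += 1
    return counts.tolist()
```

For example, with three heads over four patches, patch 1 is in every head's top two and wins the vote; `allocate_tokens([1, 1, 2], 10)` gives the third layer half of the budget.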

[Figure: Network (figures/Network.jpg)]

Experimental Results

Dataset              Accuracy (%)   Model   Log
CUB_200_2011         91.81          -       link
Stanford Dogs        91.84          -       link
NABirds              90.78          -       link
Oxford 102 Flowers   99.64          -       link
Oxford-IIIT Pet      95.29          -       link

Running the Code

Requirements

Python >= 3.9

PyTorch >= 1.8.1

Apex (optional)

Training

  1. Put the pre-trained ViT model in pretrained/ and rename it to ViT-B_16.npz (it can be downloaded from ViT pretrained).
  2. Select an experiment settings file in configs/ and modify the dataset path.
  3. Modify the path on line 5 of setup.py. You can also change the log name and the visible CUDA devices on lines 13 and 14.
  4. Run the command that matches your setup and PyTorch version:

Single GPU

python main.py

Multiple GPUs

If PyTorch < 1.12.0:

python -m torch.distributed.launch --nproc_per_node 4 main.py 

If PyTorch >= 1.12.0:

torchrun --nproc_per_node 4 main.py

Change the number after --nproc_per_node to the number of GPUs you want to use.
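The choice between the three launch commands above can be written as a small helper. This is a hypothetical sketch, not part of the repository: `launch_command` only encodes the rule stated in this README (plain `python main.py` for one GPU, `torch.distributed.launch` before PyTorch 1.12.0, `torchrun` from 1.12.0 on).

```python
import re


def launch_command(torch_version: str, num_gpus: int) -> list:
    """Return the README's launch command for a PyTorch version (hypothetical helper)."""
    # Only major.minor matter for the 1.12.0 threshold, e.g. "1.12.1+cu113" -> (1, 12).
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognised version string: {torch_version!r}")
    major_minor = (int(match.group(1)), int(match.group(2)))
    if num_gpus <= 1:
        return ["python", "main.py"]
    if major_minor >= (1, 12):
        return ["torchrun", "--nproc_per_node", str(num_gpus), "main.py"]
    return ["python", "-m", "torch.distributed.launch",
            "--nproc_per_node", str(num_gpus), "main.py"]
```

In practice it could be called as `launch_command(torch.__version__, torch.cuda.device_count())` and the result passed to `subprocess.run`.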

ielt's People

Contributors

mobulan


ielt's Issues

result on cub_200_2011

Hi, I tried to reproduce this work, but I only got 86% accuracy. Where is the problem? Thanks a lot!

CUB dataset won't run and the Google Drive link can't be accessed

I get an error when running on CUB_200_2011. The error message is:

RuntimeError: The MD5 checksum of the download file /workspace/IELTS/dataset/CUB_200_2011/CUB_200_2011.tgz does not match the one on record.Please delete the file and try again. If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues.

The message then shows a Google Drive link for downloading the .tgz archive (or the dataset), but the link doesn't work. This is the link that was shown:
(https://drive.usercontent.google.com/download)
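torchvision raises this RuntimeError when the MD5 checksum of the archive it finds differs from the value on record. One way to inspect the file yourself is to hash it and confirm it really is a gzip archive: every gzip stream starts with the magic bytes 0x1f 0x8b, so a file that begins with anything else (for example an HTML page saved in place of the archive) cannot be the dataset. Both helpers below are hypothetical and belong to neither IELT nor torchvision; the expected checksum should come from the dataset source, not from this sketch.

```python
import hashlib


def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the MD5 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def looks_like_gzip(path: str) -> bool:
    """Return True if the file starts with the gzip magic bytes 0x1f 0x8b."""
    with open(path, "rb") as handle:
        return handle.read(2) == b"\x1f\x8b"
```

If `looks_like_gzip` returns False for CUB_200_2011.tgz, deleting the file and downloading the archive again from a working source is the first thing to try.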

Test.py

I would like to ask how to do the test and how to apply this code to my custom dataset.

string indices must be integers

Thanks for your excellent work! However, when training the model on the Pet dataset, I get the error "string indices must be integers".

Traceback (most recent call last):
  File "main.py", line 409, in <module>
    main(config)
  File "main.py", line 66, in main
    model, model_without_ddp = build_model(config, num_classes)
  File "main.py", line 16, in build_model
    model = build_models(config, num_classes)
  File "/home/guojiao/codes/disease_or_excellent_v2/methods/IELT/models/build.py", line 32, in build_models
    model.load_from(config.model.pretrained)
  File "/home/guojiao/codes/disease_or_excellent_v2/methods/IELT/models/IELT.py", line 93, in load_from
    np2th(weights["embedding/kernel"], conv=True)
TypeError: string indices must be integers

The pretrained .npz file has already been downloaded, renamed, and moved to the pretrained directory.
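Python raises `TypeError: string indices must be integers` when a `str` is indexed with another string, so at the failing line `weights` is a string. Because build.py calls `model.load_from(config.model.pretrained)`, one plausible reading, based only on this traceback, is that `load_from` received the checkpoint path instead of the loaded arrays. The sketch below shows the difference; `as_vit_weights` is a hypothetical helper, not part of the IELT code.

```python
import os

import numpy as np


def as_vit_weights(weights):
    """Return a name -> array mapping for ViT weights (hypothetical helper).

    A path (str or os.PathLike) is loaded with np.load, which returns an
    NpzFile that supports lookups such as weights["embedding/kernel"].
    Anything else is assumed to be an already-loaded mapping.
    """
    if isinstance(weights, (str, os.PathLike)):
        return np.load(weights)
    return weights
```

Under that assumption, passing `as_vit_weights(config.model.pretrained)` (or `np.load(config.model.pretrained)`) to `load_from` would give it a mapping rather than a path string.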

Using IELT modules

Hello, can you provide code for using only the multihead voting, CLR, or DS module?

y.cuda()

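When training on my own dataset, I found that y is a list and cannot be moved to CUDA. I changed the data type of y, but then an error occurred in the cross-entropy loss. What type is y when you train on Flowers? How should I solve this problem?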
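For class-index targets, `nn.CrossEntropyLoss` expects an integer (`torch.long`) tensor of shape `(N,)` with values in `[0, num_classes)`. A Python list has no `.cuda()` method, and PyTorch's default collate function leaves string labels as a list, so a custom Dataset that returns class names would produce exactly this symptom. The question does not say what y contained; the sketch assumes string class names, and `encode_labels` is a hypothetical helper.

```python
def encode_labels(labels, class_names=None):
    """Map raw labels (e.g. class-name strings) to integer class indices.

    Hypothetical helper: returns (indices, class_names), where indices[i]
    is the position of labels[i] in class_names. If class_names is not
    given, it is built from the sorted unique labels.
    """
    if class_names is None:
        class_names = sorted(set(labels))
    index = {name: i for i, name in enumerate(class_names)}
    return [index[label] for label in labels], list(class_names)
```

The indices can then become a target tensor with `torch.tensor(indices, dtype=torch.long)` before calling `.cuda()`; alternatively, the Dataset's `__getitem__` can return the integer index directly so the default collate builds the tensor.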

Model checkpoints

Hi Qin,

Thank you for the wonderful work! I'm trying to do some prediction task for birds and flowers, and I found your work. May I have your trained model weights for CUB and FLO datasets?

Cheers,
Zhi
