peitong-li / aet-net
This repository accompanies a paper proposing AET-Net, a transformer-based network for occluded person re-identification (ReID). Two attention modules, SAEM and CAEM, are designed to enhance the local feature extraction ability of the transformer model, and an attention-guided model optimization strategy mitigates the model's over-bias toward the attention features. A novel evaluation metric, MAS, is also proposed to better evaluate ReID models. Experiments on several benchmark datasets show that AET-Net achieves high performance on the occluded ReID dataset and offers a novel approach to transformer-based ReID tasks.


CNN Attention Enhanced Transformer for Occluded Person Re-Identification

We propose AET-Net, a CNN Attention Enhanced Transformer Network for ReID that addresses the occlusion problem, achieving state-of-the-art performance on the occluded dataset Occluded-Duke.

News

  • 2022.x: Released the code of AET-Net.

Pipeline

(Figures: AET-Net framework overview; MAS metric illustration.)

Comparison results between AET-Net and the state-of-the-art methods

(Tables: comparison results on Occluded-Duke, Market-1501, and DukeMTMC-reID.)

Ablation Study of AET-Net

(Tables: ablation study; portable ablation; inference costs.)

Requirements

Installation

conda create -n AET-NET python=3.8 -y
conda activate AET-NET
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
(We use torch 1.8.1, torchvision 0.9.1, timm 0.5.4, and CUDA 11.1 for training and evaluation.)

Prepare Datasets

mkdir ../../datasets

Download the person datasets Market-1501, MSMT17, DukeMTMC-reID, and Occluded-Duke, then unzip and rename them under the directory as follows:

datasets
├── market1501
│   └── contents ..
├── MSMT17
│   └── contents ..
├── dukemtmcreid
│   └── contents ..
├── Occluded_Duke
│   └── contents ..

Prepare ViT Pre-trained Models

You need to download the ImageNet pretrained transformer model : ViT-Base, ViT-Small

Training

We utilize one 3090 GPU for training.

python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/baseline.yml --tag ${TAG} MODEL.NAME ${0} MODEL.DEVICE_ID "('your device id')" MODEL.STRIDE_SIZE ${1} MODEL.Attention_type ${2} MODEL.SEM ${3} MODEL.SEM_W ${4} MODEL.SEM_P ${5}  MODEL.CEM ${6} MODEL.CEM_W ${7} MODEL.CEM_P ${8} MODEL.AGS ${9} OUTPUT_DIR ${OUTPUT_DIR} DATASETS.NAMES "('your dataset name')" 

Arguments

  • ${TAG}: the tag of the run, e.g. 'test' or 'train' (default='default')
  • ${0}: the type of model to build, e.g. 'AET-Net' or 'TransReID'
  • ${1}: stride size for the pure transformer, e.g. [16, 16], [14, 14], [12, 12]
  • ${2}: the type of attention module to build, e.g. 'RGA' or 'CBAM'
  • ${3}: whether to use SEM, True or False
  • ${4}: the weight of SEM, in [0, 1]
  • ${5}: the location of the SEM: 'before', 'after', or 'all' (default='before')
  • ${6}: whether to use CEM, True or False
  • ${7}: the weight of CEM, in [0, 1]
  • ${8}: the location of the CEM: 'before' or 'after' (default='after')
  • ${9}: whether to use AGS, True or False
  • ${OUTPUT_DIR}: folder for saving logs and checkpoints, e.g. baseline; results are written to './logs/{datasets}/{model.NAMES}/baseline/TAG'
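The options above are passed as space-separated KEY VALUE pairs after the named arguments. As an illustration only (a hypothetical helper, not part of the repo), the full command line can be assembled from a plain dict:

```python
# Hypothetical helper (not in the repo) that assembles the training command
# from the config-override options described above.
def build_train_cmd(config, tag, opts):
    """opts maps config keys (e.g. 'MODEL.NAME') to already-quoted values;
    they are appended as KEY VALUE pairs, as torch.distributed.launch expects."""
    base = (
        "python -m torch.distributed.launch --nproc_per_node=1 "
        f"--master_port 12345 main.py --config_file {config} --tag {tag}"
    )
    extra = " ".join(f"{k} {v}" for k, v in opts.items())
    return f"{base} {extra}" if extra else base

cmd = build_train_cmd(
    "configs/OCC_Duke/baseline.yml", "'train'",
    {"MODEL.NAME": "'AET-Net'", "MODEL.DEVICE_ID": "\"('0')\""},
)
print(cmd)
```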

Alternatively, you can train directly with the following YAML configs and commands:

# OCC_DukeMTMC AET-NET baseline
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/baseline.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
# OCC_DukeMTMC AET-NET (baseline + SEM(RGA))
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/SEM.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')" 
# OCC_DukeMTMC AET-NET (baseline + CEM(RGA))
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/CEM.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
# OCC_DukeMTMC AET-NET (baseline + SEM + CEM (RGA))
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/SC.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
# OCC_DukeMTMC AET-NET (baseline + SEM + CEM + AGS (RGA))
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/SC_AGS.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"

# DukeMTMC baseline
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/Duke/baseline.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
# Market baseline
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/Market/baseline.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"

Evaluation

python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file 'choose which config to test' --tag 'test' MODEL.DEVICE_ID "('your device id')" TEST.WEIGHT "('your path of trained checkpoints')" TEST.MAS (whether to use the MAS evaluation metric) HEATMAP.SAVE (whether to save the heat map) HEATMAP.ROOT "(path to save the heat map)"

Examples:

# OCC_Duke
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/SC_AGS.yml --tag 'test' MODEL.DEVICE_ID "('0')" TEST.WEIGHT '../logs/occ_duke/AET-NET/RGA/SC_AGS/AET-NET_120.pth'
# Market
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/Market/AET-NET/RGA/SC_AGS.yml --tag 'test' MODEL.DEVICE_ID "('0')" TEST.WEIGHT '../logs/market/AET-NET/RGA/SC_AGS/AET-NET_120.pth'

Visualization


python Visualization.py --config_file "(your config path)" --use_cuda (whether to use CUDA) --image_path "(the image path to be visualized)" --OUTPUT_DIR "(path for the heat-map output)" --Model_Type "(the tag of the model)" --show (whether to show the image result) TEST.WEIGHT "(path to your eval model)"

Example:

# OCC_Duke
python Visualization.py --config_file configs/OCC_Duke/AET-NET/RGA/SC_AGS.yml --use_cuda True --image_path ./demo/1.jpg  --OUTPUT_DIR ./demo/results --Model_Type "AET_SC_AGS" TEST.WEIGHT '../logs/occ_duke/AET-NET/RGA/SC_AGS/AET-NET_120.pth'

Inference Costs

python calc_parms_and_flops.py --config_file "(your config path)" TEST.WEIGHT "(path to your eval model)"
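As a rough cross-check of the script's output, the parameter count of a plain ViT backbone can be estimated directly from its hyperparameters. The sketch below assumes standard ViT-Base settings (embed dim 768, depth 12, MLP ratio 4, 16x16 patches, 224x224 input) and excludes the classification head and any AET-Net-specific modules:

```python
def vit_param_count(embed_dim=768, depth=12, mlp_ratio=4,
                    patch=16, img=224, in_ch=3):
    """Approximate parameter count of a plain ViT backbone (no head)."""
    d = embed_dim
    n_patches = (img // patch) ** 2
    # Patch embedding (conv weight + bias), cls token, position embedding.
    params = in_ch * patch * patch * d + d        # patch projection
    params += d                                   # cls token
    params += (n_patches + 1) * d                 # position embedding
    per_block = (
        3 * d * d + 3 * d                         # qkv projection
        + d * d + d                               # attention output projection
        + d * mlp_ratio * d + mlp_ratio * d       # MLP fc1
        + mlp_ratio * d * d + d                   # MLP fc2
        + 2 * 2 * d                               # two LayerNorms (weight + bias)
    )
    params += depth * per_block
    params += 2 * d                               # final LayerNorm
    return params

print(f"ViT-Base ~{vit_param_count() / 1e6:.1f}M parameters")
# → ViT-Base ~85.8M parameters
```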

Trained Models and logs

Note: The code will be released after the paper is accepted.

Acknowledgement

Code base from reid-strong-baseline, pytorch-image-models, TransReID, and vit-explain.

Citation

Contact
