We propose AET-Net, a CNN Attention Enhanced Transformer Network for person re-identification (ReID) that addresses the occlusion problem. AET-Net achieves state-of-the-art performance on the occluded dataset Occluded-Duke.
- 2022.x Release the code of AET-Net.
conda create -n AET-NET python=3.8 -y
conda activate AET-NET
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
(We use torch 1.8.1, torchvision 0.9.1, timm 0.5.4, and CUDA 11.1 for training and evaluation.)
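To confirm the environment matches the pins above, a quick version check helps. The helper below is a minimal, hypothetical sketch (not part of the repo): it compares an installed version string such as `torch.__version__` against a pinned base version, ignoring the local CUDA suffix.

```python
# Hypothetical helper: check an installed version string against a pin.
def meets_pin(installed: str, pin: str) -> bool:
    """Return True if `installed` matches the pinned base version,
    e.g. '1.8.1+cu111' satisfies the pin '1.8.1'."""
    base = installed.split("+")[0]  # drop local build suffix like '+cu111'
    return base == pin

# Usage: after `import torch`, pass torch.__version__ as `installed`.
print(meets_pin("1.8.1+cu111", "1.8.1"))  # → True (the torch pin)
print(meets_pin("0.9.1+cu111", "0.9.1"))  # → True (the torchvision pin)
```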
mkdir ../../datasets
Download the person ReID datasets Market-1501, MSMT17, DukeMTMC-reID, and Occluded-Duke, then unzip and rename them under the directory as follows:
datasets
├── market1501
│ └── contents ..
├── MSMT17
│ └── contents ..
├── dukemtmcreid
│ └── contents ..
├── Occluded_Duke
│ └── contents ..
You need to download the ImageNet-pretrained transformer models: ViT-Base, ViT-Small.
We use a single NVIDIA RTX 3090 GPU for training.
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/baseline.yml --tag ${TAG} MODEL.NAME ${0} MODEL.DEVICE_ID "('your device id')" MODEL.STRIDE_SIZE ${1} MODEL.Attention_type ${2} MODEL.SEM ${3} MODEL.SEM_W ${4} MODEL.SEM_P ${5} MODEL.CEM ${6} MODEL.CEM_W ${7} MODEL.CEM_P ${8} MODEL.AGS ${9} OUTPUT_DIR ${OUTPUT_DIR} DATASETS.NAMES "('your dataset name')"
- ${TAG}: the tag of the run, e.g. 'test' or 'train' (default='default')
- ${0}: the type of model to build, e.g. 'AET-Net' or 'TransReID'
- ${1}: the stride size for the pure transformer, e.g. [16, 16], [14, 14], [12, 12]
- ${2}: the type of attention module to build, e.g. 'RGA' or 'CBAM'
- ${3}: whether to use SEM, True or False
- ${4}: the weight of SEM, in [0, 1]
- ${5}: the location of SEM: 'before', 'after', or 'all' (default='before')
- ${6}: whether to use CEM, True or False
- ${7}: the weight of CEM, in [0, 1]
- ${8}: the location of CEM: 'before' or 'after' (default='after')
- ${9}: whether to use AGS, True or False
- ${OUTPUT_DIR}: the folder for saving logs and checkpoints, e.g. 'baseline'; results are written to './logs/{datasets}/{model.NAMES}/baseline/{TAG}'
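The output-path pattern above can be sketched as follows (a minimal illustration of where logs and checkpoints land; the argument values are examples, not repo defaults):

```python
def log_path(dataset: str, model_name: str, output_dir: str, tag: str) -> str:
    """Build the log path following './logs/{datasets}/{model.NAMES}/{OUTPUT_DIR}/{TAG}'."""
    return f"./logs/{dataset}/{model_name}/{output_dir}/{tag}"

print(log_path("occ_duke", "AET-NET", "baseline", "train"))
# → ./logs/occ_duke/AET-NET/baseline/train
```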
Or you can train directly with the following yml configs and commands:
# OCC_DukeMTMC AET-NET baseline
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/baseline.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
# OCC_DukeMTMC AET-NET (baseline + SEM(RGA))
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/SEM.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
# OCC_DukeMTMC AET-NET (baseline + CEM(RGA))
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/CEM.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
# OCC_DukeMTMC AET-NET (baseline + SEM + CEM (RGA))
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/SC.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
# OCC_DukeMTMC AET-NET (baseline + SEM + CEM + AGS (RGA))
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/SC_AGS.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
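The four RGA-variant runs above differ only in the config file, so they can be generated with a small loop. This is a dry-run sketch that only prints each command (remove the leading `echo` to actually launch training):

```shell
# Dry run: print the training command for each RGA-variant config above.
for cfg in SEM CEM SC SC_AGS; do
  echo python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 \
    main.py --config_file "configs/OCC_Duke/AET-NET/RGA/${cfg}.yml" \
    --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
done
```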
# DukeMTMC baseline
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/Duke/baseline.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
# Market baseline
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/Market/baseline.yml --tag 'train' MODEL.NAME 'AET-Net' MODEL.DEVICE_ID "('0')"
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file 'choose which config to test' --tag 'test' MODEL.DEVICE_ID "('your device id')" TEST.WEIGHT "('your path of trained checkpoints')" TEST.MAS (whether to use the MAS evaluation metric) HEATMAP.SAVE (whether to save the heat map) HEATMAP.ROOT "(path to save the heat map)"
Examples:
# OCC_Duke
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/OCC_Duke/AET-NET/RGA/SC_AGS.yml --tag 'test' MODEL.DEVICE_ID "('0')" TEST.WEIGHT '../logs/occ_duke/AET-NET/RGA/SC_AGS/AET-NET_120.pth'
# Market
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 main.py --config_file configs/Market/AET-NET/RGA/SC_AGS.yml --tag 'test' MODEL.DEVICE_ID "('0')" TEST.WEIGHT '../logs/market/AET-NET/RGA/SC_AGS/AET-NET_120.pth'
python Visualization.py --config_file "(your config path)" --use_cuda (whether to use CUDA) --image_path "(the image path to be visualized)" --OUTPUT_DIR "(path for the heat map output)" --Model_Type "(the tag of the model)" --show (whether to show the image result) TEST.WEIGHT "(path to your eval model)"
Example:
# OCC_Duke
python Visualization.py --config_file configs/OCC_Duke/AET-NET/RGA/SC_AGS.yml --use_cuda True --image_path ./demo/1.jpg --OUTPUT_DIR ./demo/results --Model_Type "AET_SC_AGS" TEST.WEIGHT '../logs/occ_duke/AET-NET/RGA/SC_AGS/AET-NET_120.pth'
python calc_parms_and_flops.py --config_file "(your config path)" TEST.WEIGHT "(path to your eval model)"
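The core of a parameter count is just summing element counts over the model's weight shapes. The stand-in below illustrates that step without torch (the shapes are illustrative, sized like a ViT-Base patch-embedding layer; it is not the repo's `calc_parms_and_flops.py`):

```python
from functools import reduce
from operator import mul

def param_count(shapes):
    """Sum element counts over a list of tensor shapes
    (with torch this would be sum(p.numel() for p in model.parameters()))."""
    return sum(reduce(mul, shape, 1) for shape in shapes)

# Example: a ViT-Base-like patch embedding weight (768 x 3 x 16 x 16) plus its bias.
print(param_count([(768, 3, 16, 16), (768,)]))  # → 590592
```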
Note: The code will be released after the paper is accepted.
Codebase from reid-strong-baseline, pytorch-image-models, TransReID, and vit-explain.