
SatMAE++: Rethinking Transformers Pre-training for Multi-Spectral Satellite Imagery (CVPR 2024)

Updates

  • March 11, 2024: SatMAE++ paper is released [arXiv] [PDF]
  • March 13, 2024: Codebase is released.
  • March 26, 2024: Pre-trained and finetuned ViT-Large model weights are released. JSON data split files for FMoW-RGB are uploaded.

Overview

Unlike standard natural image datasets, remote sensing data is acquired from a variety of sensor technologies and exhibits a wide range of scale variations as well as modalities. Existing satellite image pre-training methods either ignore the scale information present in remote sensing imagery or restrict themselves to a single data modality. In contrast, SatMAE++ performs multi-scale pre-training that is equally effective for optical and multi-spectral imagery: it utilizes convolution-based upsampling blocks to reconstruct the image at higher scales, making the approach readily extensible to additional scales.

Method

SatMAE++ incorporates multi-scale information by reconstructing the image at multiple scale levels, thereby improving performance on various downstream scene classification datasets.

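The upsampling blocks can be pictured with a minimal PyTorch sketch. The module below is an illustrative assumption in the spirit of the paper's convolution-based upsampling; names and layer choices are not the repository's exact implementation:

import torch
import torch.nn as nn

class UpsampleBlock(nn.Module):
    """Doubles the spatial resolution of a reconstructed feature map.
    Hypothetical module for illustration; not the repo's exact code."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),  # 2x upsampling
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),  # refine upsampled features
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)

# Reconstruct at 2x and 4x of the base-scale decoder output.
decoded = torch.randn(2, 64, 96, 96)    # (B, C, H, W) base-scale reconstruction features
up2x = UpsampleBlock(64, 32)(decoded)   # -> (2, 32, 192, 192)
up4x = UpsampleBlock(32, 16)(up2x)      # -> (2, 16, 384, 384)

Chaining further blocks extends the reconstruction to additional scales, which is what makes the scheme extensible.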


FMoW-Sentinel

You can download the dataset and the corresponding train/val CSV files from these links: [satmae github] [fmow-sentinel]

The directory structure of the dataset should be as below (a quick layout check is sketched after the tree):

[Root folder]
____ train.csv
____ val.csv
____ [images folder]
________ train
____________ airport
____________ airport_hangar
____________ .......
________ val
____________ airport
____________ airport_hangar
____________ .......
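A quick way to sanity-check this layout before training is sketched below; the root path and the images folder name are placeholders to adjust for your setup:

from pathlib import Path
import pandas as pd

root = Path("/home/fmow-sentinel")  # placeholder root folder
for split in ("train", "val"):
    csv_path = root / f"{split}.csv"
    assert csv_path.exists(), f"missing {csv_path}"
    df = pd.read_csv(csv_path)
    print(split, "->", len(df), "rows; columns:", list(df.columns))

images = root / "images"  # the [images folder]; adjust if named differently
train_dir = images / "train"
if train_dir.exists():
    print("class folders in train:", len(list(train_dir.iterdir())))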

Pretraining

To pretrain the ViT model (default is ViT-L) with the SatMAE++ approach on the FMoW-Sentinel dataset, use the following command:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=29201 main_pretrain.py \
--batch_size 16 --accum_iter 16 \
--epochs 50 --warmup_epochs 20 \
--input_size 96 --patch_size 8 \
--mask_ratio 0.75 \
--model_type group_c \
--dataset_type sentinel --dropped_bands 0 9 10 \
--grouped_bands 0 1 2 6 --grouped_bands 3 4 5 7 --grouped_bands 8 9 \
--blr 0.0001 --num_workers 16 \
--train_path /home/fmow-sentinel/train.csv \
--output_dir ./output_dir \
--log_dir ./output_dir
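For reference, the effective batch size with these settings is batch_size x accum_iter x num_gpus = 16 x 16 x 8 = 2048. Assuming the codebase follows the MAE-style learning-rate scaling rule (which SatMAE inherits), the absolute learning rate is derived from --blr as follows:

# MAE-style LR scaling -- an assumption carried over from the MAE/SatMAE codebases.
eff_batch_size = 16 * 16 * 8             # batch_size * accum_iter * num_gpus = 2048
abs_lr = 0.0001 * eff_batch_size / 256   # blr * eff_batch_size / 256 = 0.0008
print(eff_batch_size, abs_lr)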

Finetuning

To finetune the ViT model (default is ViT-L), use the following command:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=29202 main_finetune.py \
--batch_size 8 --accum_iter 16 \
--epochs 30 --warmup_epochs 5 \
--input_size 96 --patch_size 8 \
--model_type group_c \
--dataset_type sentinel --dropped_bands 0 9 10 \
--grouped_bands 0 1 2 6 --grouped_bands 3 4 5 7 --grouped_bands 8 9 \
--weight_decay 0.05 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
--blr 0.0002 --num_workers 16 \
--train_path /home/fmow-sentinel/train.csv \
--test_path /home/fmow-sentinel/val.csv \
--output_dir ./finetune_dir \
--log_dir ./finetune_dir \
--finetune ./output_dir/checkpoint-49.pth
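Before launching, it can be worth verifying that the pre-trained checkpoint loads. Below is a minimal sketch assuming the MAE-style checkpoint layout (a dict holding a 'model' state dict plus training metadata); that layout is an assumption, not a guarantee about this repository:

import torch

ckpt = torch.load("./output_dir/checkpoint-49.pth", map_location="cpu")
print(list(ckpt.keys()))    # typically ['model', 'optimizer', 'epoch', ...] in MAE-style checkpoints
print(len(ckpt["model"]))   # number of parameter tensors in the saved backbone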

FMoW-RGB

You can download the dataset by following the instructions here: [fmow-github]

Download the train and validation JSON files [data-split]. Alternatively, you can preprocess the data and create your own JSON/CSV files using the script here: [fmow-rgb preprocessing issue] [CSV Files (SatMAE)]

The directory structure of the dataset should be as below:

[Root folder]
____ train_62classes.json
____ val_62classes.json
____ train
________ airport
________ airport_hangar
________ .......
____ val
________ airport
________ airport_hangar
________ .......

Pretraining

Use the following command to pretrain the ViT model (default is ViT-L) on the FMoW-RGB dataset:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=29201 main_pretrain.py \
--batch_size 64 --accum_iter 32 \
--epochs 800 --warmup_epochs 20 \
--input_size 224 --patch_size 16 \
--mask_ratio 0.75 \
--model_type vanilla \
--dataset_type rgb \
--weight_decay 0.3 \
--lr 0.0007 --num_workers 16 \
--train_path /home/fmow-rgb/train_62classes.json \
--output_dir ./output_dir \
--log_dir ./output_dir

Finetuning

Use the following command to finetune the ViT model (default is ViT-L):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=29202 main_finetune.py \
--batch_size 8 --accum_iter 16 \
--epochs 50 --warmup_epochs 5 \
--input_size 224 --patch_size 16 \
--model_type vanilla \
--dataset_type rgb \
--weight_decay 0.05 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
--lr 0.001 --num_workers 16 \
--train_path /home/fmow-rgb/train_62classes.json \
--test_path /home/fmow-rgb/val_62classes.json \
--output_dir ./finetune_dir \
--log_dir ./finetune_dir \
--finetune ./output_dir/checkpoint-799.pth

Downstream Datasets

Data splits for EuroSAT, UCMerced and RESISC-45 are available at [google-research].


Model Weights

Model   Dataset         Top1 Acc (%)   Pretrain   Finetune
ViT-L   FMoW-Sentinel   63.23          download   download
ViT-L   FMoW-RGB        78.14          download   download
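To evaluate a released finetuned checkpoint, main_finetune.py can be run with --eval and --resume, mirroring the finetuning flags. The command below is a sketch; the checkpoint and data paths are placeholders:

CUDA_VISIBLE_DEVICES=0 python main_finetune.py \
--batch_size 8 --input_size 224 --patch_size 16 \
--model_type vanilla --dataset_type rgb \
--eval --resume ./weights/checkpoint_ViT-L_finetune_fmow_rgb.pth \
--train_path /home/fmow-rgb/train_62classes.json \
--test_path /home/fmow-rgb/val_62classes.json \
--log_dir ./finetune_dir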

Acknowledgements

The codebase is inspired by the SatMAE repository. We thank its authors for releasing their valuable codebase.

Citation

@inproceedings{satmaepp2024rethinking,
      title={Rethinking Transformers Pre-training for Multi-Spectral Satellite Imagery}, 
      author={Mubashir Noman and Muzammal Naseer and Hisham Cholakkal and Rao Muhammad Anwar and Salman Khan and Fahad Shahbaz Khan},
      year={2024},
      booktitle={CVPR}
}


Issues

Evaluation on fmow-RGB

Thank you for your paper and code!

The results in your paper are excellent. I tried to run the evaluation on my machine:

CUDA_VISIBLE_DEVICES=4 python main_finetune.py \
 --batch_size 8 --input_size 224 --patch_size 16 --model_type vanilla --dataset_type rgb --weight_decay 0.05 \
 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 --lr 0.001 --num_workers 2 \
 --resume ./weights/checkpoint_ViT-L_finetune_fmow_rgb.pth \
 --eval \
 --train_path /data_dir/train_62classes.json \
 --test_path /data_dir/val_62classes.json \
 --log_dir ./finetune_dir

But the result is:

[21:55:17.993992] Test: Total time: 1:28:20 (0.7993 s / it)                                                         
[21:55:17.994328] * Acc@1 51.411 Acc@5 78.786 loss 2.040  
[21:55:17.995148] Evaluation on 53041 test images- acc1: 51.41%, acc5: 78.79%

Perhaps I am running the evaluation incorrectly. Could you please provide the correct instructions?

My conda environment is:

_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
absl-py                   2.1.0                    pypi_0    pypi
affine                    2.4.0                    pypi_0    pypi
attrs                     23.2.0                   pypi_0    pypi
blas                      1.0                         mkl  
brotli-python             1.0.9            py38h6a678d5_8  
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2024.3.11            h06a4308_0  
cachetools                5.3.3                    pypi_0    pypi
certifi                   2024.2.2         py38h06a4308_0  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
click                     8.1.7                    pypi_0    pypi
click-plugins             1.1.1                    pypi_0    pypi
cligj                     0.7.2                    pypi_0    pypi
cudatoolkit               11.3.1               h2bc3f7f_2  
ffmpeg                    4.3                  hf484d3e_0    pytorch
freetype                  2.12.1               h4a9f257_0  
gmp                       6.2.1                h295c915_3  
gnutls                    3.6.15               he1e5248_0  
google-auth               2.29.0                   pypi_0    pypi
google-auth-oauthlib      1.0.0                    pypi_0    pypi
grpcio                    1.64.0                   pypi_0    pypi
idna                      3.7              py38h06a4308_0  
importlib-metadata        7.1.0                    pypi_0    pypi
intel-openmp              2023.1.0         hdb19cb5_46306  
jpeg                      9e                   h5eee18b_1  
lame                      3.100                h7b6447c_0  
lcms2                     2.12                 h3be6417_0  
ld_impl_linux-64          2.38                 h1181459_1  
lerc                      3.0                  h295c915_0  
libdeflate                1.17                 h5eee18b_1  
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libiconv                  1.16                 h5eee18b_3  
libidn2                   2.3.4                h5eee18b_0  
libpng                    1.6.39               h5eee18b_0  
libstdcxx-ng              11.2.0               h1234567_1  
libtasn1                  4.19.0               h5eee18b_0  
libtiff                   4.5.1                h6a678d5_0  
libunistring              0.9.10               h27cfd23_0  
libuv                     1.44.2               h5eee18b_0  
libwebp-base              1.3.2                h5eee18b_0  
lz4-c                     1.9.4                h6a678d5_1  
markdown                  3.6                      pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
mkl                       2023.1.0         h213fc3f_46344  
mkl-service               2.4.0            py38h5eee18b_1  
mkl_fft                   1.3.8            py38h5eee18b_0  
mkl_random                1.2.4            py38hdb19cb5_0  
ncurses                   6.4                  h6a678d5_0  
nettle                    3.7.3                hbbd107a_1  
numpy                     1.24.3           py38hf6e8229_1  
numpy-base                1.24.3           py38h060ed82_1  
oauthlib                  3.2.2                    pypi_0    pypi
opencv-python             4.9.0.80                 pypi_0    pypi
openh264                  2.1.1                h4ff587b_0  
openjpeg                  2.4.0                h3ad879b_0  
openssl                   3.0.13               h7f8727e_2  
pandas                    2.0.3                    pypi_0    pypi
pillow                    10.3.0           py38h5eee18b_0  
pip                       24.0             py38h06a4308_0  
protobuf                  5.27.0                   pypi_0    pypi
pyasn1                    0.6.0                    pypi_0    pypi
pyasn1-modules            0.4.0                    pypi_0    pypi
pyparsing                 3.1.2                    pypi_0    pypi
pysocks                   1.7.1            py38h06a4308_0  
python                    3.8.19               h955ad1f_0  
python-dateutil           2.9.0.post0              pypi_0    pypi
pytorch                   1.11.0          py3.8_cuda11.3_cudnn8.2.0_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2024.1                   pypi_0    pypi
rasterio                  1.3.10                   pypi_0    pypi
readline                  8.2                  h5eee18b_0  
requests                  2.32.2           py38h06a4308_0  
requests-oauthlib         2.0.0                    pypi_0    pypi
rsa                       4.9                      pypi_0    pypi
setuptools                69.5.1           py38h06a4308_0  
six                       1.16.0                   pypi_0    pypi
snuggs                    1.4.7                    pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0  
tbb                       2021.8.0             hdb19cb5_0  
tensorboard               2.14.0                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
timm                      0.4.12                   pypi_0    pypi
tk                        8.6.14               h39e8969_0  
torchaudio                0.11.0               py38_cu113    pytorch
torchvision               0.12.0               py38_cu113    pytorch
typing_extensions         4.11.0           py38h06a4308_0  
tzdata                    2024.1                   pypi_0    pypi
urllib3                   2.2.1            py38h06a4308_0  
werkzeug                  3.0.3                    pypi_0    pypi
wheel                     0.43.0           py38h06a4308_0  
xz                        5.4.6                h5eee18b_1  
zipp                      3.19.0                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_1  
zstd                      1.5.5                hc292b87_2  

Questions about reconstruction loss

Thank you for sharing your interesting paper.

I understand that in MAE only the masked (reconstructed) image patches are used for the loss calculation.
However, when I checked the paper and the code in this repo, I noticed that the loss for the higher-scale images is computed over all image pixels, not just the reconstructed patch regions.
Q1. Did you make any ablation studies regarding this?

Additionally, the paper mentions applying L1 loss to higher scales, similar to super-resolution, during multi-scale image reconstruction (just above Equation 4 of the paper).
Q2. Could you share any references related to this part?
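For context, here is a minimal sketch contrasting the two loss variants discussed above (masked-patch MSE at the base scale versus full-image L1 at a higher scale, as in super-resolution); all tensor shapes are illustrative assumptions, not the repository's exact code:

import torch
import torch.nn.functional as F

# Base scale: MAE-style loss averaged over masked patches only.
pred = torch.randn(2, 144, 192)              # (B, num_patches, patch_dim), illustrative
target = torch.randn(2, 144, 192)
mask = (torch.rand(2, 144) > 0.25).float()   # 1 = masked patch to be reconstructed
per_patch = ((pred - target) ** 2).mean(dim=-1)
loss_base = (per_patch * mask).sum() / mask.sum()

# Higher scale: L1 over the full reconstructed image, all pixels included.
up_pred = torch.randn(2, 3, 192, 192)
up_target = torch.randn(2, 3, 192, 192)
loss_up = F.l1_loss(up_pred, up_target)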

Environment details?

Thanks very much for making this available - this is excellent.

I am considering adapting it.

Any chance you could share your environment details?

Always tricky getting deep learning and geospatial things working together, so would be nice to know.

Thanks again,

Richard
