[CVPR2023] The official repository for paper "Learning Partial Correlation based Deep Visual Representation for Image Classification" To appear in 2023 The IEEE / CVF Computer Vision and Pattern Recognition Conference (CVPR)

Home Page: https://csiro-robotics.github.io/iSICE/

License: Other

Shell 3.96% Python 96.04%
classification covariance-matrix deep-learning representation-learning pytorch

isice's Introduction

Learning Partial Correlation based Deep Visual Representation for Image Classification
Official implementation with PyTorch

This repository contains the model definitions, training/evaluation code and pre-trained model weights for our paper exploring partial correlation based deep SPD visual representation. More information is available on our project website.

Learning Partial Correlation based Deep Visual Representation for Image Classification
Saimunur Rahman, Piotr Koniusz, Lei Wang, Luping Zhou, Peyman Moghadam, Changming Sun
CSIRO Data61, University of Wollongong, Australian National University, University of Sydney, Queensland University of Technology

Visual representation based on covariance matrix has demonstrated its efficacy for image classification by characterising the pairwise correlation of different channels in convolutional feature maps. However, pairwise correlation becomes misleading once there is another channel correlating with both channels of interest, resulting in the "confounding" effect. In this case, "partial correlation", which removes the confounding effect, should be estimated instead. Nevertheless, reliably estimating partial correlation requires solving a symmetric positive definite matrix optimisation, known as sparse inverse covariance estimation (SICE). How to incorporate this process into a CNN remains an open issue. In this work, we formulate SICE as a novel structured layer of the CNN. To ensure the CNN remains end-to-end trainable, we develop an iterative method based on Newton-Schulz iteration to solve the above matrix optimisation during forward and backward propagation steps. Our work not only obtains a partial correlation based deep visual representation but also mitigates the small sample problem frequently encountered by covariance matrix estimation in CNNs. Computationally, our model can be effectively trained with GPU and works well with a large number of channels in advanced CNN models. Experimental results confirm the efficacy of the proposed deep visual representation and its superior classification performance to that of its covariance matrix based counterparts.
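To make the confounding effect concrete, here is a minimal sketch using the standard three-variable partial correlation identity. It is an illustration only and not the SICE layer from this repository, which operates on full channel covariance matrices. The function name `partial_correlation` and the correlation values are hypothetical.

```python
import math


def partial_correlation(r_xy: float, r_xz: float, r_yz: float) -> float:
    """Correlation between x and y after removing the linear effect of z."""
    denominator = math.sqrt((1.0 - r_xz ** 2) * (1.0 - r_yz ** 2))
    return (r_xy - r_xz * r_yz) / denominator


# Hypothetical numbers: channels x and y each correlate 0.5 with a third
# channel z, and their pairwise correlation of 0.25 is produced entirely by z.
pairwise = 0.25
partial = partial_correlation(r_xy=pairwise, r_xz=0.5, r_yz=0.5)

print(f"pairwise correlation: {pairwise:.2f}")
print(f"partial correlation:  {partial:.2f}")
```

With these made-up values the pairwise correlation suggests a relationship between x and y, while the partial correlation is zero once z is accounted for.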

This repository contains:

✔️ A simple implementation of our method with PyTorch
✔️ A script useful for training/evaluating our method on various datasets
✔️ Pre-trained model weights on several datasets

Repository Setup Guide

To run our code on your machine, first download the repository using the following commands:

cd /the/path/where/you/want/to/copy/the/code
git clone https://github.com/csiro-robotics/iSICE.git
cd iSICE

The second step is to create a conda environment with the necessary Python packages, which can be done using the following commands:

conda create --name iSICE
conda install --name iSICE pytorch torchvision cudatoolkit torchaudio scipy matplotlib -c pytorch

For ease of use, we rely only on common Python packages so that users can run our code with little difficulty. If you do not have Anaconda installed, you can install Anaconda or its lighter version Miniconda, or use a Python virtual environment. With a Python virtual environment, the packages can be installed with pip. Please see here for details. We also provide the `isice.yml` file for creating a conda environment similar to ours.

Note that we have evaluated our code with PyTorch 1.9.0. However, there should be no problem with other versions released after PyTorch 0.4.0. The above command installs PyTorch with GPU support via CUDA; running on CPU is supported by default.

The third step is to activate the above conda environment with the following command:

conda activate iSICE
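Once the environment is active, a quick check such as the following sketch can confirm which PyTorch version is visible and whether CUDA is usable. The helper name `describe_environment` is hypothetical and not part of this repository; the script also runs when PyTorch is missing, in which case it reports `None`.

```python
import importlib.util
import platform


def describe_environment() -> dict:
    """Collect basic facts about the active Python environment."""
    info = {"python": platform.python_version(), "torch": None, "cuda": False}
    if importlib.util.find_spec("torch") is None:
        return info  # PyTorch is not installed in this environment
    import torch

    info["torch"] = torch.__version__
    info["cuda"] = torch.cuda.is_available()
    return info


environment = describe_environment()
for key, value in environment.items():
    print(f"{key}: {value}")
```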

The fourth step is to download the datasets. All datasets should be prepared as follows.

.
├── train
│   ├── class 1
│   │   ├── image_001.format
│   │   ├── image_002.format
│   │   └── ...
│   ├── class 2
│   ├── class 3
│   ├── ...
│   ├── ...
│   └── class N
└── val
    ├── class 1
    │   ├── image_001.format
    │   ├── image_002.format
    │   └── ...
    ├── class 2
    ├── class 3
    ├── ...
    ├── ...
    └── class N
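Before training, it can help to verify that a dataset folder follows the layout above. The sketch below uses only the standard library; the function names `summarise_split` and `check_dataset` are hypothetical helpers and not part of this repository, and the check counts files without validating that they are images.

```python
from pathlib import Path


def summarise_split(split_dir: Path) -> dict:
    """Map each class folder in a split to the number of files it contains."""
    return {
        class_dir.name: sum(1 for item in class_dir.iterdir() if item.is_file())
        for class_dir in sorted(split_dir.iterdir())
        if class_dir.is_dir()
    }


def check_dataset(root: str) -> list:
    """Return a list of problems found in the train/val layout (empty if none)."""
    problems = []
    summaries = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            problems.append(f"missing folder: {split_dir}")
            continue
        summaries[split] = summarise_split(split_dir)
        for name, count in summaries[split].items():
            if count == 0:
                problems.append(f"empty class folder: {split_dir / name}")
    # Both splits are expected to list the same class folders.
    if len(summaries) == 2 and set(summaries["train"]) != set(summaries["val"]):
        problems.append("train and val contain different class folders")
    return problems


# Replace the placeholder with the dataset root; an empty list means no problems.
print(check_dataset("/path/to/CUB"))
```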

Repository Overview

We use a modular design for this repository. In our experience, such a design is easy to manage and extend. Our code repository is organised as follows.

├── main.py
├── imagepreprocess.py
├── functions.py
├── model_init.py
├── src
│   ├── network
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── inception.py
│   │   ├── alexnet.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   └── representation
│       ├── __init__.py
│       ├── SICE.py
│       ├── INVCOV.py
│       ├── COV.py
│       └── GAvP.py
└── train_iSICE_model.sh

How to use our code

Our main.py drives the process of running our code to reproduce the results reported in the paper. Suppose we want to train a partial correlation representation model based on the VGG-16 backbone with the CUB-200 dataset (referred to as the Birds dataset in the paper) and evaluate it on the same dataset. The following command can be used:

python main.py /path/to/CUB --benchmark CUB --pretrained -a vgg16_bn --epochs 100 --lr 1.2e-4 --lr-method step --lr-params 15 30 -j 10 -b 65 --num-classes 200 --representation SICE --freezed-layer 0 --classifier-factor 5 --modeldir /path/to/save/the/model/and/meta/information

As training progresses, the loss, top-1 error and top-5 error for both training and test evaluation are automatically saved in the path specified with the --modeldir parameter above.
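When launching several runs, the same command can be assembled from Python. The sketch below only rearranges the flags shown above; the function name `build_command` and its arguments are hypothetical, and the paths are the placeholders from the command. Each value of `--lr-params` is passed as its own argument.

```python
import shlex


def build_command(data_dir: str, model_dir: str, lr_params=(15, 30)) -> list:
    """Assemble the CUB training command as a list of arguments."""
    return [
        "python", "main.py", data_dir,
        "--benchmark", "CUB",
        "--pretrained",
        "-a", "vgg16_bn",
        "--epochs", "100",
        "--lr", "1.2e-4",
        "--lr-method", "step",
        "--lr-params", *[str(value) for value in lr_params],
        "-j", "10",
        "-b", "65",
        "--num-classes", "200",
        "--representation", "SICE",
        "--freezed-layer", "0",
        "--classifier-factor", "5",
        "--modeldir", model_dir,
    ]


command = build_command("/path/to/CUB", "/path/to/save/the/model/and/meta/information")
print(shlex.join(command))
# To launch training: subprocess.run(command, check=True)
```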

For training on computing clusters such as HPC, please use the train_iSICE_model.sh script, changing its various fields as per the instructions given in the script (we show how to train using the MIT indoor dataset). Our code is compatible with multi-GPU training.

Pre-trained models

For convenience, we provide our VGG-16 and ResNet-50 based partial correlation models trained on fine-grained and scene datasets. They can be downloaded here.

Pairwise correlation based models (computed via iSQRT-COV pooling)

Backbone    MIT                   Airplane              Birds                 Cars
            top1 acc. (%)  Model  top1 acc. (%)  Model  top1 acc. (%)  Model  top1 acc. (%)  Model
VGG-16      76.1           TBA    90.0           TBA    84.5           TBA    91.2           TBA
ResNet-50   78.8           TBA    90.9           TBA    84.3           TBA    92.1           TBA

Partial correlation based models (computed via Precision Matrix described in Algorithm 1 of the paper)

Backbone    MIT                   Airplane              Birds                 Cars
            top1 acc. (%)  Model  top1 acc. (%)  Model  top1 acc. (%)  Model  top1 acc. (%)  Model
VGG-16      80.2           TBA    89.4           TBA    83.4           TBA    92.0           TBA
ResNet-50   80.8           TBA    91.2           TBA    84.7           TBA    92.0           TBA

Partial correlation based models (computed via iSICE described in Algorithm 2 of the paper)

Backbone    MIT                   Airplane              Birds                 Cars
            top1 acc. (%)  Model  top1 acc. (%)  Model  top1 acc. (%)  Model  top1 acc. (%)  Model
VGG-16      78.7           TBA    92.2           TBA    86.5           TBA    94.0           TBA
ResNet-50   80.5           TBA    92.7           TBA    85.9           TBA    93.5           TBA
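For a quick comparison, the following sketch computes the top-1 accuracy difference between the iSICE models and the pairwise correlation (iSQRT-COV) models. The numbers are copied from the tables above; the variable and function names are hypothetical.

```python
DATASETS = ("MIT", "Airplane", "Birds", "Cars")

# Top-1 accuracies (%) copied from the iSQRT-COV and iSICE tables.
ISQRT_COV = {
    "VGG-16": (76.1, 90.0, 84.5, 91.2),
    "ResNet-50": (78.8, 90.9, 84.3, 92.1),
}
ISICE = {
    "VGG-16": (78.7, 92.2, 86.5, 94.0),
    "ResNet-50": (80.5, 92.7, 85.9, 93.5),
}


def gains(baseline, improved):
    """Per-dataset accuracy difference in percentage points, rounded to 0.1."""
    return {
        name: round(new - old, 1)
        for name, old, new in zip(DATASETS, baseline, improved)
    }


for backbone in ISQRT_COV:
    print(backbone, gains(ISQRT_COV[backbone], ISICE[backbone]))
```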

Pre-trained models can be used as checkpoints for further training/evaluation using the following command:

python main.py /path/to/CUB --benchmark CUB --pretrained -a vgg16_bn --epochs 100 --lr 1.2e-4 --lr-method step --lr-params 15 30 -j 10 -b 65 --num-classes 200 --representation SICE --freezed-layer 0 --classifier-factor 5 --resume /path/to/downloaded/model

How to cite our paper

Please use the following bibtex reference to cite our paper.

@InProceedings{isice_cvpr,
    author = {Rahman, Saimunur and Koniusz, Piotr and Wang, Lei and Zhou, Luping and Moghadam, Peyman and Sun, Changming},
    title = {Learning Partial Correlation based Deep Visual Representation for Image Classification},
    booktitle = {IEEE/CVF Int. Conf. on Computer Vision and Pattern Recognition (CVPR)},
    month = {June},
    year = {2023}
}

Acknowledgments

This codebase borrows from the iSQRT-COV repository. We thank the authors for maintaining it.

Contact

If you have any questions or suggestions, please contact [email protected].

isice's Issues

What should the --lr-params be?

Thanks a lot for your great work!
But when I run the command 'python main.py data/ --benchmark dataset --epochs 100 --lr 1.2e-4 --lr-method step --lr-params 15\ 30 -j 10 -b 65 --num-classes 200 --representation SICE --freezed-layer 0 --classifier-factor 5 --modeldir checkpoints/', I get error message: 'main.py: error: argument --lr-params: invalid float value: '15\''
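A likely cause is how the shell hands `15\ 30` to the argument parser. The sketch below assumes that `--lr-params` is declared as a list of floats, which is consistent with the "invalid float value" error but is not confirmed here against main.py; the parser definition and the `try_parse` helper are hypothetical. It suggests that the values should be passed as two separate arguments, `--lr-params 15 30`.

```python
import argparse

parser = argparse.ArgumentParser()
# Assumption: one or more floats, consistent with the "invalid float value" error.
parser.add_argument("--lr-params", dest="lr_params", nargs="+", type=float)


def try_parse(argv):
    """Return the parsed list of floats, or None if argparse rejects the input."""
    try:
        return parser.parse_args(argv).lr_params
    except SystemExit:  # argparse prints the error to stderr and exits
        return None


# Two separate tokens, as produced by `--lr-params 15 30`.
print(try_parse(["--lr-params", "15", "30"]))

# A shell that treats the backslash as an escape passes one token "15 30".
print(try_parse(["--lr-params", "15 30"]))

# A shell that keeps the backslash literally passes "15\" and "30".
print(try_parse(["--lr-params", "15\\", "30"]))
```

Under this assumption, only the first call yields a list of floats; the other two are rejected because neither `15 30` nor `15\` can be converted to a float.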
