Single Layer Predictive Normalized Maximum Likelihood for Out-of-Distribution Detection

https://arxiv.org/abs/2110.09246

This is a fast and scalable approach for detecting out-of-distribution test samples. It can be applied to any pretrained model.

Pseudocode

# Assuming to have: trainloader, testloader, and a model with model.backbone() and model.classifier() methods
import math

import torch

model.eval()
with torch.no_grad():
    # Extract the training set features: a 2D matrix of size training_size x num_features
    features = torch.vstack([model.backbone(images) for images, labels in trainloader])
    norm = torch.linalg.norm(features, dim=-1, keepdim=True)
    features = features / norm

    # Compute the inverse of the training data empirical correlation matrix (num_features x num_features)
    x_t_x_inv = torch.linalg.inv(features.T @ features)

    # Calculate the regret: a large regret means an out-of-distribution sample
    regrets = []
    for images, labels in testloader:
        features = model.backbone(images)

        # Normalize
        norm = torch.linalg.norm(features, dim=-1, keepdim=True)
        features = features / norm

        # Get the probabilities of the normalized features
        probs = torch.softmax(model.classifier(features), dim=-1)

        # Calc the projection x^T (X^T X)^-1 x per test sample (the diagonal only): batch_size x 1
        x_proj = torch.sum((features @ x_t_x_inv) * features, dim=-1, keepdim=True)
        x_t_g = x_proj / (1 + x_proj)

        # Equation 20
        n_classes = probs.shape[-1]
        nf = torch.sum(probs / (probs + (1 - probs) * (probs ** x_t_g)), dim=-1)
        regrets.append(torch.log(nf) / math.log(n_classes))

regrets = torch.cat(regrets)
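To check the computation without a trained network, the same steps can be transcribed to NumPy and run on random features. This is a sketch, not the repository's implementation: `normalize_rows`, `softmax`, `projection_diag`, `regret_from_probs`, `pnml_regret`, and the random linear head in the usage lines are hypothetical names chosen here. The projection is computed per test row rather than as a full batch x batch matrix, and the logits passed in are assumed to come from the normalized test features, as in the pseudocode.

```python
import numpy as np


def normalize_rows(features):
    """Scale every row (one sample's feature vector) to unit L2 norm."""
    return features / np.linalg.norm(features, axis=-1, keepdims=True)


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def projection_diag(test, x_t_x_inv):
    """x^T (X^T X)^-1 x for every test row, shape (batch, 1).

    Equals the diagonal of test @ x_t_x_inv @ test.T without building
    the batch x batch matrix.
    """
    return np.sum((test @ x_t_x_inv) * test, axis=-1, keepdims=True)


def regret_from_probs(probs, x_t_g):
    """Normalized regret log(nf) / log(C), following Equation 20 in the pseudocode."""
    n_classes = probs.shape[-1]
    nf = np.sum(probs / (probs + (1 - probs) * probs ** x_t_g), axis=-1)
    return np.log(nf) / np.log(n_classes)


def pnml_regret(train_features, test_features, test_logits):
    """Per-sample regret; a larger value suggests an out-of-distribution sample.

    test_logits are assumed to be computed from the normalized test features.
    """
    train = normalize_rows(train_features)
    test = normalize_rows(test_features)
    x_t_x_inv = np.linalg.inv(train.T @ train)  # num_features x num_features
    x_proj = projection_diag(test, x_t_x_inv)
    x_t_g = x_proj / (1 + x_proj)
    return regret_from_probs(softmax(test_logits), x_t_g)


# Usage on random data with a hypothetical linear classifier head
rng = np.random.default_rng(0)
train_features = rng.normal(size=(1000, 16))
test_features = rng.normal(size=(8, 16))
head = rng.normal(size=(16, 10))
test_logits = normalize_rows(test_features) @ head
print(pnml_regret(train_features, test_features, test_logits))
```

Since 0 <= p_i^{x_t_g} <= 1, each summand of nf lies between p_i and 1, so nf lies between 1 and the number of classes C, and dividing log(nf) by log(C) keeps the regret in [0, 1].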

Paper results

Figure: Regret for low-dimensional data

Figure: OOD detection results

Run the code

Install requirements

# Create env
conda create -n pnml_ood python=3.8.0 --yes
conda activate pnml_ood

# Install pip as a fallback
conda install --yes pip

# Pytorch with GPU
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch --yes

# All other packages: install with conda. If a package fails to install, fall back to pip.
while read requirement; do conda install --yes $requirement || pip install $requirement; done < requirements.txt 

Download data and models

# Download OOD data
cd bash_scripts
chmod +x ./download_data.sh
./download_data.sh

# Download pretrained models
chmod +x ./download_models.sh
./download_models.sh

Optional: Download imagenet30

Follow the instructions at https://github.com/alinlab/CSI

Imagenet30 training set: https://drive.google.com/file/d/1B5c39Fc3haOPzlehzmpTLz6xLtGyKEy4/view

Imagenet30 testing set: https://drive.google.com/file/d/13xzVuQMEhSnBRZr-YaaO08coLU2dxAUq/view

Place both archives under ./data/Imagenet30 and untar them there:

.
├── README.md
├── data
│   ├── Imagenet30
│   │   ├── one_class_test
│   │   ├── one_class_test.tar
│   │   ├── one_class_train
│   │   └── one_class_train.tar
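The extraction step can be scripted so it is skipped when the folders already exist. This is a minimal sketch, assuming each archive unpacks into a top-level folder with the same name as the tar file, matching the layout shown above; `untar_imagenet30` is a hypothetical helper, not part of the repository.

```python
import tarfile
from pathlib import Path


def untar_imagenet30(root="./data/Imagenet30", names=("one_class_train", "one_class_test")):
    """Extract <name>.tar inside root unless the <name> folder already exists.

    Assumes every archive unpacks into a top-level folder of the same name.
    Returns the names that were extracted on this call.
    """
    root = Path(root)
    extracted = []
    for name in names:
        archive = root / f"{name}.tar"
        target = root / name
        if target.is_dir():
            continue  # already extracted, nothing to do
        if not archive.is_file():
            raise FileNotFoundError(f"Missing {archive}; download it first")
        with tarfile.open(archive) as tar:
            if hasattr(tarfile, "data_filter"):
                # Newer Python versions support a safer extraction filter
                tar.extractall(path=root, filter="data")
            else:
                tar.extractall(path=root)
        extracted.append(name)
    return extracted
```

Calling `untar_imagenet30()` from the repository root a second time returns an empty list, since both folders are then present.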

Execute methods

Execute a single method. Our pNML method runs on top of the executed method.

cd src
python main_execute_method.py model=densenet trainset=cifar100 method=baseline

Execute all methods

cd bash_scripts
chmod +x ./execute_methods.sh
./execute_methods.sh
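To run only a custom subset of experiments, the single-method command can be generated for every combination of settings. This is a hypothetical sketch, not the contents of execute_methods.sh: `build_commands` and `run_all` are illustrative names, only the values `densenet`, `cifar100`, and `baseline` come from this README, and any other value must be taken from the repository's own configuration. It assumes it is launched from the repository root so that `cwd="src"` points at the scripts.

```python
import itertools
import subprocess


def build_commands(models, trainsets, methods):
    """Return one main_execute_method.py argument list per combination of settings."""
    return [
        ["python", "main_execute_method.py", f"model={model}", f"trainset={trainset}", f"method={method}"]
        for model, trainset, method in itertools.product(models, trainsets, methods)
    ]


def run_all(commands, cwd="src", dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, cwd=cwd, check=True)


# Dry run with the settings from the single-method example above
run_all(build_commands(["densenet"], ["cifar100"], ["baseline"]))
```

Keeping `dry_run=True` as the default lets the list of runs be reviewed before any experiment starts.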

Create the paper's tables

cd src
python main_create_tables.py
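The metrics computed by main_create_tables.py are not described in this README. As a hedged illustration of turning regrets into a detection score, the sketch below computes AUROC, a metric commonly reported for OOD detection, treating OOD samples as the positive class and the regret as the score (a large regret means out-of-distribution). `auroc` is a hypothetical helper; it compares every pair of samples, so memory grows with n_ind x n_ood, and sklearn.metrics.roc_auc_score is a standard alternative for large sets.

```python
import numpy as np


def auroc(ind_scores, ood_scores):
    """AUROC for separating OOD (positive) from in-distribution samples by score.

    Equals the probability that a random OOD sample scores higher than a random
    in-distribution sample, with ties counted as one half.
    """
    ind = np.asarray(ind_scores, dtype=float)
    ood = np.asarray(ood_scores, dtype=float)
    greater = (ood[:, None] > ind[None, :]).mean()
    ties = (ood[:, None] == ind[None, :]).mean()
    return float(greater + 0.5 * ties)


# Perfectly separated toy regrets give an AUROC of 1.0
print(auroc([0.1, 0.2, 0.3], [0.7, 0.8]))
```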

Citing

If you use this code in your research or wish to refer to the baseline results, please use the following BibTeX entry.

@inproceedings{bibas2021single,
  title={Single Layer Predictive Normalized Maximum Likelihood for Out-of-Distribution Detection},
  author={Bibas, Koby and Feder, Meir and Hassner, Tal},
  booktitle={Advances in Neural Information Processing Systems},
  year={2021}
}
