Natural Posterior Network

This repository provides the official implementation of the Natural Posterior Network (NatPN) and the Natural Posterior Ensemble (NatPE) as presented in the following paper:

Natural Posterior Network: Deep Bayesian Predictive Uncertainty for Exponential Family Distributions
Bertrand Charpentier*, Oliver Borchert*, Daniel Zügner, Simon Geisler, Stephan Günnemann
International Conference on Learning Representations (ICLR), 2022

Paper | Video

Model Overview

Features

The NatPN implementation in this repository provides the following features:

  • High-level estimator interface that makes NatPN as easy to use as Scikit-learn estimators
  • Simple bash script to train and evaluate NatPN
  • Ready-to-use PyTorch Lightning data modules for 8 of the 9 datasets used in the paper*

In addition, we provide a public Weights & Biases project. This project will be filled with training and evaluation runs that allow you (1) to inspect the performance of different NatPN models and (2) to download the model parameters. See the example notebook for instructions on how to use such a pretrained model.

*The Kin8nm dataset is not included as it has disappeared from the UCI Repository.

Installation

Prior to installation, you may want to install all dependencies (Python, CUDA, Poetry). If you are running on an AWS EC2 instance with Ubuntu 20.04, you can use the provided bash script:

sudo bash bin/setup-ec2.sh

In order to use the code in this repository, you should first clone the repository:

git clone git@github.com:borchero/natural-posterior-network.git natpn

Then, in the root of the repository, you can install all dependencies via Poetry:

poetry install

Quickstart

Shell Script

To simply train and evaluate NatPN on a particular dataset, you can use the train shell script. For example, to train and evaluate NatPN on the Sensorless Drive dataset, you can run the following command in the root of the repository:

poetry run train --dataset sensorless-drive

The dataset is downloaded automatically the first time this command is called. The performance metrics of the trained model are printed to the console, and the trained model is discarded. In order to track both the metrics and the model parameters via Weights & Biases, use the following command:

poetry run train --dataset sensorless-drive --experiment first-steps

To list all options of the shell script, simply run:

poetry run train --help

This command will also provide explanations for all the parameters that can be passed.
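
If you would rather launch the script from Python, for example to loop over several datasets, the commands above can be wrapped in a small helper. The following is only a minimal sketch: `build_train_command` and `run_training` are hypothetical names that are not part of this repository, and only the `--dataset` and `--experiment` flags shown above are used.

```python
import shlex
import subprocess
from typing import List, Optional


def build_train_command(dataset: str, experiment: Optional[str] = None) -> List[str]:
    """Assemble the `poetry run train` invocation as an argument list."""
    command = ["poetry", "run", "train", "--dataset", dataset]
    if experiment is not None:
        # Passing --experiment enables tracking via Weights & Biases.
        command += ["--experiment", experiment]
    return command


def run_training(dataset: str, experiment: Optional[str] = None, dry_run: bool = True) -> str:
    """Print the command and, unless dry_run is True, execute it.

    The command must be executed from the root of the repository.
    """
    command = build_train_command(dataset, experiment)
    printable = shlex.join(command)
    print(printable)
    if not dry_run:
        # check=True raises CalledProcessError if training exits with an error.
        subprocess.run(command, check=True)
    return printable
```

Calling `run_training("sensorless-drive", experiment="first-steps")` only prints the command; pass `dry_run=False` to actually start training.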

Estimator

If you want to use NatPN from your code, the easiest way to get started is to use the Scikit-learn-like estimator:

from natpn import NaturalPosteriorNetwork

The documentation of the estimator's __init__ method provides a comprehensive overview of all the configuration options. For a simple example of using the estimator, refer to the example notebook.
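
To read that documentation without opening the source, you can print the signature and docstring of `__init__` with the standard `inspect` module. This is a minimal sketch: `init_overview` is a hypothetical helper that is not part of `natpn`, and the snippet makes no assumptions about which parameters the estimator accepts.

```python
import inspect


def init_overview(cls: type) -> str:
    """Return the signature and docstring of ``cls.__init__`` as one string."""
    signature = inspect.signature(cls.__init__)
    doc = inspect.getdoc(cls.__init__) or "(no docstring found)"
    return f"{cls.__name__}.__init__{signature}\n\n{doc}"


if __name__ == "__main__":
    try:
        # Only available after `poetry install` in the repository root.
        from natpn import NaturalPosteriorNetwork
    except ImportError:
        print("natpn is not installed in this environment.")
    else:
        print(init_overview(NaturalPosteriorNetwork))
```

In an interactive session, `help(NaturalPosteriorNetwork)` shows the same information together with the documentation of the remaining methods.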

Module

If you need even more customization, you can use natpn.nn.NaturalPosteriorNetworkModel directly. The natpn.nn package is extensively documented and allows you to configure your NatPN model in detail.

Further, the natpn.model package provides PyTorch Lightning modules which allow you to train, evaluate, and fine-tune models.

Running Hyperparameter Searches

If you want to run hyperparameter searches on a local Slurm cluster, you can use the files provided in the sweeps directory. To run the grid search, simply execute the file:

poetry run python sweeps/<file>

To make sure that your experiment is tracked correctly, you should also set the WANDB_PROJECT environment variable in a place that is read by the Slurm script (found in sweeps/slurm).
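
One way to set the variable is to launch the sweep from a small Python wrapper that adds WANDB_PROJECT to the environment of the child process. This is a sketch under assumptions: `sweep_environment`, `sweep_command`, and `launch_sweep` are hypothetical helpers, and whether the variable actually reaches the Slurm jobs depends on how the scripts in sweeps/slurm and your cluster propagate the environment. If it does not, set the variable in the Slurm script itself.

```python
import os
import subprocess
from typing import Dict, List


def sweep_environment(project: str) -> Dict[str, str]:
    """Copy the current environment and add WANDB_PROJECT to the copy."""
    env = dict(os.environ)
    env["WANDB_PROJECT"] = project
    return env


def sweep_command(sweep_file: str) -> List[str]:
    """Build the `poetry run python <file>` invocation for a sweep file."""
    return ["poetry", "run", "python", sweep_file]


def launch_sweep(sweep_file: str, project: str) -> None:
    """Run the sweep file with WANDB_PROJECT set for the child process only."""
    # Assumption: the child process passes its environment on to the Slurm jobs.
    subprocess.run(sweep_command(sweep_file), env=sweep_environment(project), check=True)
```

Because the variable is added to a copy of the environment, the parent shell is left unchanged.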

Feel free to adapt the scripts to your liking to run your own hyperparameter searches.

Citation

If you are using the model or the code in this repository, please cite the following paper:

@inproceedings{natpn,
    title={{Natural} {Posterior} {Network}: {Deep} {Bayesian} {Predictive} {Uncertainty} for {Exponential} {Family} {Distributions}},
    author={Charpentier, Bertrand and Borchert, Oliver and Z\"{u}gner, Daniel and Geisler, Simon and G\"{u}nnemann, Stephan},
    booktitle={International Conference on Learning Representations},
    year={2022}
}

Contact Us

If you have any questions regarding the code, please contact us via mail.

License

The code in this repository is licensed under the MIT License.
