Subspace Inference for Bayesian Deep Learning

This repository contains PyTorch code for the subspace inference method introduced in the paper

Subspace Inference for Bayesian Deep Learning

by Pavel Izmailov, Wesley Maddox, Polina Kirichenko, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson

Introduction

For deep neural network models, exact Bayesian inference is intractable, and existing approximate inference methods suffer from many limitations, largely due to the high dimensionality of the parameter space. In subspace inference we address this problem by performing inference in low-dimensional subspaces of the parameter space. In the paper, we show how to construct very low-dimensional subspaces that can contain diverse high performing models. Bayesian inference within such subspaces is much easier than in the original parameter space and leads to strong results.

At a high level, our method proceeds as follows:

  1. We construct a low-dimensional subspace of the parameter space.
  2. We approximate the posterior distribution over the parameters in this subspace.
  3. We sample from the approximate posterior in the subspace and perform Bayesian model averaging.
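
The three steps above can be sketched on a toy model. This is a minimal illustration in NumPy, not the repository's PyTorch code: the helper names `subspace_to_weights` and `bayesian_model_average`, the random basis `P`, the Gaussian draws standing in for posterior samples, and the linear model are all assumptions made for the example.

```python
import numpy as np

def subspace_to_weights(w_hat, P, z):
    """Map subspace coordinates z (K,) to full weights w = w_hat + P z (D,)."""
    return w_hat + P @ z

def bayesian_model_average(predict, w_hat, P, z_samples, x):
    """Average predictions over weights generated from subspace samples."""
    preds = [predict(subspace_to_weights(w_hat, P, z), x) for z in z_samples]
    return np.mean(preds, axis=0)

rng = np.random.default_rng(0)
D, K, S = 1000, 5, 20                      # full dimension, subspace rank, number of samples

w_hat = rng.normal(size=D)                 # shift vector (e.g. an averaged solution)
P = rng.normal(size=(D, K)) / np.sqrt(D)   # step 1: placeholder random basis (assumption)
z_samples = rng.normal(size=(S, K))        # step 2: stand-in for approximate posterior samples
x = rng.normal(size=(8, D))                # toy inputs

def linear_model(w, inputs):
    """Toy predictor used only for illustration."""
    return inputs @ w

# Step 3: Bayesian model averaging over the sampled weights.
bma_prediction = bayesian_model_average(linear_model, w_hat, P, z_samples, x)
```

Only the K-dimensional vector `z` is ever sampled; the full D-dimensional weights are reconstructed on demand, which is what keeps inference cheap even when D is very large.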

For example, we can achieve state-of-the-art results for a WideResNet with 36 million parameters by performing approximate inference in only a 5-dimensional subspace.

If you use our work, please cite the following paper:

@article{izmailov_subspace_2019,
  title={Subspace Inference for Bayesian Deep Learning},
  author={Izmailov, Pavel and Maddox, Wesley and Kirichenko, Polina and Garipov, Timur and Vetrov, Dmitry and Wilson, Andrew Gordon},
  journal={Uncertainty in Artificial Intelligence (UAI)},
  year={2019}
}

Installation

You can install the package by running the following command:

python setup.py install

Usage

We provide the scripts and example commands to reproduce the experiments from the paper.

Subspace Construction and Inference Procedures

In the paper we show how to construct random, PCA, and mode-connected subspaces, over which we can perform Bayesian inference. The PCA subspace will often be most practical, in terms of a runtime-accuracy trade-off. Once the subspace is constructed, we consider various approaches to inference within the subspace, such as MCMC and variational methods. The following examples show how to call subspace inference with your choice of subspace and inference procedure.
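
As a rough illustration of two of these constructions, the sketch below builds a PCA subspace and a random subspace from a matrix of weight snapshots. It is written in NumPy under stated assumptions and is not the repository's implementation: the function names, the use of an orthonormal (unscaled) PCA basis, and the synthetic `snapshots` array are hypothetical choices made for the example.

```python
import numpy as np

def pca_subspace(snapshots, rank):
    """Build a PCA subspace from weight snapshots.

    snapshots: (T, D) array, one flattened weight vector per row
               (e.g. collected along the SGD trajectory).
    Returns the mean weights (D,) and an orthonormal basis (D, rank).
    """
    mean = snapshots.mean(axis=0)
    deviations = snapshots - mean
    # Rows of vt are the principal directions of the deviations.
    _, _, vt = np.linalg.svd(deviations, full_matrices=False)
    return mean, vt[:rank].T

def random_subspace(dim, rank, rng):
    """Draw `rank` Gaussian directions and normalize each to unit length."""
    basis = rng.normal(size=(dim, rank))
    return basis / np.linalg.norm(basis, axis=0, keepdims=True)

rng = np.random.default_rng(0)
snapshots = rng.normal(size=(20, 50))      # synthetic stand-in for saved weights
mean, pca_basis = pca_subspace(snapshots, rank=3)
rand_basis = random_subspace(dim=50, rank=3, rng=rng)
```

Whether each PCA direction should additionally be scaled, for instance by its singular value, is a design choice this sketch leaves out.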

Visualizing Regression Uncertainty

We provide a Jupyter notebook in which we show how to apply subspace inference to a regression problem, using different types of subspaces and approximate inference methods.

Image Classification

The folder experiments/cifar_exps contains the implementation of subspace inference for the CIFAR-10 and CIFAR-100 datasets. To run subspace inference, you first need to pre-train the subspace, and then run either variational inference (VI) or elliptical slice sampling (ESS) in the subspace. For example, you can pre-train a PCA subspace for a VGG-16 by running:

python3 experiments/cifar_exps/swag.py --data_path=data --epochs=300 --dataset=CIFAR100 --save_freq=300 \
      --model=VGG16 --lr_init=0.05 --wd=5e-4 --swag --swag_start=161 --swag_lr=0.01 --cov_mat --use_test \
      --dir=ckpts/vgg16_run1

Then, you can run ESS in the constructed subspace using the following command:

python3 experiments/cifar_exps/subspace_ess.py --dir=ckpts/vgg16_run1/ess/ --dataset=CIFAR100 \
      --data_path=~/datasets/ --model=VGG16 --use_test --rank=5 --checkpoint=ckpts/vgg16_run1/swag-300.pt \
      --temperature=5000 --prior_std=2
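
For readers unfamiliar with ESS, the following is a minimal NumPy sketch of elliptical slice sampling over subspace coordinates. It is not the code in subspace_ess.py. It assumes, as the flag names suggest, that `prior_std` sets the standard deviation of an isotropic zero-mean Gaussian prior and that `temperature` divides the log-likelihood; the function names and the toy likelihood are hypothetical.

```python
import numpy as np

def ess_step(z, log_lik, prior_std, rng):
    """One elliptical slice sampling update for a N(0, prior_std^2 I) prior."""
    nu = prior_std * rng.normal(size=z.shape)           # auxiliary draw from the prior
    log_y = log_lik(z) + np.log(1.0 - rng.uniform())    # slice height
    theta = rng.uniform(0.0, 2.0 * np.pi)
    theta_min, theta_max = theta - 2.0 * np.pi, theta
    while True:
        proposal = z * np.cos(theta) + nu * np.sin(theta)
        if log_lik(proposal) > log_y:
            return proposal
        # Shrink the bracket towards theta = 0 (the current state) and retry.
        if theta < 0.0:
            theta_min = theta
        else:
            theta_max = theta
        theta = rng.uniform(theta_min, theta_max)

def run_ess(log_lik, dim, prior_std, temperature, num_samples, seed=0):
    """Run ESS on the tempered log-likelihood log_lik(z) / temperature."""
    rng = np.random.default_rng(seed)

    def tempered(z):
        return log_lik(z) / temperature

    z = np.zeros(dim)
    samples = np.empty((num_samples, dim))
    for i in range(num_samples):
        z = ess_step(z, tempered, prior_std, rng)
        samples[i] = z
    return samples

def toy_log_lik(z):
    """Toy Gaussian log-likelihood centred at 1 with variance 0.25 (illustrative)."""
    return -0.5 * np.sum((z - 1.0) ** 2) / 0.25

samples = run_ess(toy_log_lik, dim=5, prior_std=2.0, temperature=1.0, num_samples=2000)
```

In the real setting, `log_lik` would evaluate the network on the training data at the weights reconstructed from `z`. ESS has no step size to tune, which makes it convenient in a subspace of only a few dimensions.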

Please refer to the README for more detailed information.

UCI Regression

The folder experiments/uci_exps contains implementations of the subspace inference procedure for UCI regression problems.

Please refer to the README for more detailed information.

References for Code Base

Repositories:

  • Stochastic weight averaging (SWA): PyTorch repo. Many of the base methods and model definitions are built off of this repo.
  • SWA-Gaussian (SWAG): PyTorch repo. Our codebase is based on the SWAG repo, which also contains implementations of various baselines for image classification as well as the experiments on Hessian eigenvalues.
  • Mode-Connectivity: PyTorch repo. We use the scripts from this repository to construct our curve subspaces.
  • Bayesian benchmarks: Repo; the experiments/uci_exps folder is a clone of that repo.

Model implementations: