Scalable Inference in SDEs by Direct Matching of the Fokker–Planck–Kolmogorov Equation

This repository is the official implementation of the methods in the publication

  • Solin, A., Tamir, E., and Verma, P. (2021). Scalable Inference in SDEs by Direct Matching of the Fokker–Planck–Kolmogorov Equation. In Advances in Neural Information Processing Systems 34 (NeurIPS). [arXiv]

In the paper, we advocate alternative solution concepts for stochastic differential equation (SDE) models in machine learning, where simulation-based techniques such as variants of stochastic Runge–Kutta are currently the de facto approach. These methods are convenient and general-purpose, and they are used with parametric and non-parametric models as well as neural SDEs. Yet stochastic Runge–Kutta relies on sampling schemes that can be inefficient in high dimensions. We address this issue by revisiting the classical SDE literature and deriving direct approximations to the (typically intractable) Fokker–Planck–Kolmogorov equation by matching moments. The codebase in this repository includes the building blocks for the figures and the code for the experiments in the paper.
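To give a rough feel for what matching moments means, here is a minimal sketch that is not taken from this repository. It assumes a scalar Ornstein–Uhlenbeck SDE, dx = -theta * x dt + dW with diffusion spectral density q, for which the mean and variance ODEs are exact and a closed-form solution is available for comparison. The function names and parameter values are illustrative assumptions only.

```python
import math


def ou_moment_odes(m, P, theta, q):
    """Right-hand sides of the mean and variance ODEs for the scalar
    Ornstein-Uhlenbeck SDE dx = -theta * x dt + dW (spectral density q)."""
    dm = -theta * m
    dP = -2.0 * theta * P + q
    return dm, dP


def integrate_moments(m0, P0, theta, q, t_end, n_steps=10000):
    """Propagate the mean m and variance P with plain Euler steps,
    without drawing any sample paths."""
    dt = t_end / n_steps
    m, P = m0, P0
    for _ in range(n_steps):
        dm, dP = ou_moment_odes(m, P, theta, q)
        m += dt * dm
        P += dt * dP
    return m, P


def exact_moments(m0, P0, theta, q, t):
    """Closed-form mean and variance of the same OU process at time t."""
    decay = math.exp(-theta * t)
    m = m0 * decay
    P = P0 * decay ** 2 + q / (2.0 * theta) * (1.0 - decay ** 2)
    return m, P


if __name__ == "__main__":
    # Hypothetical parameter values chosen only for illustration.
    print(integrate_moments(1.0, 0.5, theta=0.7, q=0.3, t_end=2.0))
    print(exact_moments(1.0, 0.5, theta=0.7, q=0.3, t=2.0))
```

For nonlinear drifts the expectations in the moment ODEs are generally not available in closed form and have to be approximated; the linear case above only illustrates the idea of propagating moments deterministically instead of simulating trajectories.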

Python environment

The code should be run using Python 3.6. If you are already using Python 3.6, the dependencies can be installed from the requirements file:

pip install -r requirements.txt

Alternatively, a conda virtual environment can be created using the environment.yml file:

conda env create -f environment.yml
conda activate scalable-sde

MOCAP experiment

The MOCAP-specific code is in experiments/mocap. To prepare the data for the experiment, place the MOCAP MATLAB data file mocap35.mat in the folder [base_folder]/data/mocap_data, where base_folder is passed as an input to the training and test scripts. We use the preprocessed MOCAP data from https://github.com/cagatayyildiz/ODE2VAE.
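The folder layout described above can be set up by hand. As a small convenience sketch, the hypothetical helpers below (not part of this repository) build the expected path [base_folder]/data/mocap_data/mocap35.mat and copy a downloaded file there.

```python
import shutil
from pathlib import Path


def mocap_data_path(base_folder):
    """Return the location where the scripts expect mocap35.mat,
    following the layout described in this README."""
    return Path(base_folder) / "data" / "mocap_data" / "mocap35.mat"


def place_mocap_file(source_file, base_folder):
    """Copy a downloaded mocap35.mat into [base_folder]/data/mocap_data,
    creating the folders if they do not exist yet."""
    target = mocap_data_path(base_folder)
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(source_file, target)
    return target
```

The same base_folder value would then be passed to the training and test scripts.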

To train the MOCAP model, run

python experiments/mocap/walking_tf_train.py [-base_folder BASE_FOLDER] [-task TASK] [-decoder_dist DECODER_DIST]
                 [-model_name MODEL_NAME] [-prior_model_name PRIOR_MODEL_NAME] [-vae_name VAE_NAME]
                 [--dt DT] [--latent_dim LATENT_DIM] [--context_dim CONTEXT_DIM]
                 [--epochs EPOCHS] [--start_len START_LEN]

To test a trained MOCAP model, run

python experiments/mocap/walking_tf_test.py [-base_folder BASE_FOLDER] [-task TASK] [-decoder_dist DECODER_DIST]
                 [-model_name MODEL_NAME] [-vae_name VAE_NAME] [--dt DT]
                 [--latent_dim LATENT_DIM] [--context_dim CONTEXT_DIM] [--start_len START_LEN]

See the train and test scripts for further documentation of their input arguments. To adapt the codebase to another flat dataset (the VAE implementation does not support image data), modify the utility function get_data in experiments/mocap/walking_tf_functions.py to return another dataset class.

Alternative SDE Approximations

You can run the MOCAP experiment with any new SDE approximator, as long as it inherits from the class SDEApprox in src/sde_tf/sde_approx/sde_approx.py.

Rotating MNIST

The code for the rotating MNIST experiment is available in experiments/mnist. In order to run the experiment:

cd experiments/mnist/
python main.py

All experiment-related parameters are defined in config.py, where they can be modified. By default, the trained models and inference plots are saved to the output folder experiments/mnist/output.

Notebooks

The code used to generate Figure 3 and Figure 4 of the paper is available as Jupyter notebooks in /notebooks/.

Citation

If you use the code in this repository for your research, please cite the paper as follows:

@inproceedings{solin2021,
  title={Scalable Inference in SDEs by Direct Matching of the {F}okker--{P}lanck--{K}olmogorov Equation},
  author={Solin, Arno and Tamir, Ella and Verma, Prakhar},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year={2021}
}

Contributing

For all correspondence, please contact [email protected], [email protected], or [email protected].

License

This software is provided under the MIT license.

scalable-inference-in-sdes's Issues

Reproducing Mocap Experiments and possible bug in code

Hi!

Thanks a lot for releasing a clean implementation of your model.

I am trying to reproduce the results reported in the paper for the Mocap dataset. Can I confirm that the default arguments in this file correspond to the best results reported in the paper (most of the things seem to match except the number of epochs)? If not, could you please share the best config?

There also appears to be a bug in the following line:

loss = recon_nll + encoder_kl*self.gamma + weight**model_kl + vae_init_loss

Should it be weight*model_kl instead of weight**model_kl?

I tried training the model both before and after fixing this potential bug and I got the following test MSE values:

weight**model_kl, decoder_dist: True, epochs: 1500 = 19.74 +/- 0.04
weight*model_kl, decoder_dist: True, epochs: 1500 = 17.32 +/- 0.03
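To make the difference between the two operators concrete, here is a minimal standalone sketch with made-up numbers. It does not reproduce the repository's training code: the function total_loss, its buggy flag, and all values are hypothetical, and only the shape of the loss expression follows the line quoted above.

```python
def total_loss(recon_nll, encoder_kl, model_kl, vae_init_loss,
               gamma, weight, buggy=False):
    """Combine the loss terms as in the quoted line.

    With buggy=True the KL term is weight ** model_kl (exponentiation);
    otherwise it is weight * model_kl (scaling).
    """
    kl_term = weight ** model_kl if buggy else weight * model_kl
    return recon_nll + encoder_kl * gamma + kl_term + vae_init_loss


if __name__ == "__main__":
    # Hypothetical values: only model_kl changes between the two calls.
    for kl in (2.0, 4.0):
        scaled = total_loss(0.0, 0.0, kl, 0.0, gamma=1.0, weight=0.5)
        powered = total_loss(0.0, 0.0, kl, 0.0, gamma=1.0, weight=0.5,
                             buggy=True)
        print(kl, scaled, powered)
```

Under the assumption that weight lies between 0 and 1, the scaled term grows with model_kl while the exponentiated term shrinks, so the exponentiated form would reward a larger KL rather than penalize it.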
