

SE(3) Equivariant Augmented Coupling Flows

Code for the paper https://arxiv.org/abs/2308.10364. Results can be obtained by running the commands in the Experiments section.

Install

JAX needs to be installed separately, following the instructions on the JAX homepage. At the time of publishing we used JAX 0.4.13 with Python 3.10. This repo also depends on PyTorch (NB: use the CPU version so that it doesn't clash with JAX), which may be installed with:

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

For the alanine dipeptide problem, openmm and openmmtools must be installed with conda:

conda install -c conda-forge openmm openmmtools

Finally, install this package:

pip install -e .
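Because the README pins its results to JAX 0.4.13 and Python 3.10, a quick environment check can save debugging later. The following is a minimal sketch using only the standard library; `installed_version` and `report_versions` are hypothetical helper names, and the distribution names checked (jax, jaxlib, torch, openmm, openmmtools) are assumptions about how these packages register themselves. Conda-installed packages may not always expose Python metadata, in which case they will show as not installed even if they import fine.

```python
"""Minimal environment check (a sketch, not part of the repo)."""
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional

# Versions the README reports using at publication time; others may also work.
EXPECTED_PYTHON = (3, 10)
EXPECTED_JAX = "0.4.13"


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def report_versions(
    packages: Iterable[str] = ("jax", "jaxlib", "torch", "openmm", "openmmtools"),
) -> Dict[str, Optional[str]]:
    """Map each (assumed) distribution name to its installed version, or None."""
    return {name: installed_version(name) for name in packages}


if __name__ == "__main__":
    if sys.version_info[:2] != EXPECTED_PYTHON:
        print(f"Note: running Python {sys.version_info.major}.{sys.version_info.minor}; "
              f"the paper used {EXPECTED_PYTHON[0]}.{EXPECTED_PYTHON[1]}.")
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
    jax_version = installed_version("jax")
    if jax_version is not None and jax_version != EXPECTED_JAX:
        print(f"Note: JAX {jax_version} differs from the {EXPECTED_JAX} used at publication.")
```

A version mismatch is only a warning here, since newer JAX releases may well work; the check just makes differences from the published setup visible.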

Experiments

Experiments may be run with the following commands. We use Hydra to configure all experiments; the flow type may be set as shown in the first line. For alanine dipeptide, the data must first be downloaded from Zenodo, which may be done with the script eacf/targets/aldp.sh. Before running an experiment, make sure to configure the WANDB logger in the config file to match your WANDB account (e.g. for dw4 see experiments/config/dw4.yaml).

python examples/dw4.py flow.type=spherical # Flow types: spherical,proj,along_vector,non_equivariant
python examples/lj13.py
python examples/qm9.py
python examples/aldp.py
python examples/dw4_fab.py
python examples/lj13_fab.py
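To compare all flow types on one target, the first command can be repeated with each `flow.type` override. The sketch below assumes it is run from the repository root so that `examples/dw4.py` resolves; the flow type names come from the comment above, while `build_command` and `run_sweep` are hypothetical helpers, not part of the repo. By default it only prints the commands (`dry_run=True`).

```python
"""Sketch: run the dw4 example once per flow type via Hydra overrides."""
import subprocess
import sys
from typing import List

# Flow types listed in the README's dw4 example.
FLOW_TYPES = ["spherical", "proj", "along_vector", "non_equivariant"]


def build_command(script: str, flow_type: str) -> List[str]:
    """Build the argv for one run, passing `flow.type=<flow_type>` as a Hydra override."""
    if flow_type not in FLOW_TYPES:
        raise ValueError(f"Unknown flow type {flow_type!r}; expected one of {FLOW_TYPES}")
    return [sys.executable, script, f"flow.type={flow_type}"]


def run_sweep(script: str = "examples/dw4.py", dry_run: bool = True) -> List[List[str]]:
    """Print (and, unless dry_run is True, execute) one command per flow type."""
    commands = [build_command(script, ft) for ft in FLOW_TYPES]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep if any run exits with an error.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    run_sweep(dry_run=True)
```

Running the runs sequentially keeps GPU memory usage predictable. If the example scripts use the standard `@hydra.main` entry point (an assumption), Hydra's built-in multirun mode, e.g. `python examples/dw4.py -m flow.type=spherical,proj`, is an alternative.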

The code for the equivariant CNF baseline can be found in the ecnf-baseline-neurips-2023 repo.

Upcoming additions

  • Quickstart notebook with inference using model weights

Citation

If you use this code in your research, please cite it as:

@inproceedings{
midgley2023eacf,
title={{SE}(3) Equivariant Augmented Coupling Flows},
author={Laurence Illing Midgley and Vincent Stimper and Javier Antoran and Emile Mathieu and Bernhard Sch{\"o}lkopf and Jos{\'e} Miguel Hern{\'a}ndez-Lobato},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=KKxO6wwx8p}
}


Contributors

javierantoran, lollcat, vincentstimper



Issues

feat: plotting

  • Visualise samples from the target after they have been passed through the bijector, and compare them with the base distribution
  • Visualise (1) the marginal augmented distribution from the flow, (2) the marginal augmented distribution from the target, and (3) a ~ p(a | x), where x is sampled from the flow and p is the target.

Stable MLP

  • Try to stabilise the MLP with residual connections and layer norm
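The idea in this issue can be sketched as a pre-norm residual block, x + W2 relu(W1 LayerNorm(x)). This layout is one common choice and an assumption here, since the issue only names "residual connections and layer norm"; it is written in plain Python so it runs without dependencies, and a JAX version would follow the same structure. All function names are hypothetical.

```python
"""Sketch of a pre-norm residual MLP block (pure Python, illustrative only)."""
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalise x to zero mean and unit variance (no learned scale or offset)."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Dense matrix-vector product."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def relu(x: Vector) -> Vector:
    return [max(0.0, xi) for xi in x]


def residual_block(x: Vector, w1: Matrix, w2: Matrix) -> Vector:
    """Return x + W2 relu(W1 LayerNorm(x)); the skip path gives gradients an identity route."""
    h = relu(matvec(w1, layer_norm(x)))
    return [xi + di for xi, di in zip(x, matvec(w2, h))]


def init_matrix(rows: int, cols: int, scale: float, rng: random.Random) -> Matrix:
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    dim, hidden = 4, 8
    w1 = init_matrix(hidden, dim, 1.0 / math.sqrt(dim), rng)
    # Zero-initialising the output layer makes the block start as the identity map,
    # a common trick for stabilising deep residual networks.
    w2 = [[0.0] * hidden for _ in range(dim)]
    x = [1.0, -2.0, 0.5, 3.0]
    print(residual_block(x, w1, w2))  # identity at init: [1.0, -2.0, 0.5, 3.0]
```

Normalising only the input to the residual branch leaves the skip path untouched, so stacking many blocks does not repeatedly rescale the signal.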

Training Run Example

Thank you for the great work! This research direction is very promising. Is there any chance you could include how to set up a training run for one of the example systems in the README.md?

feat: train via data augmentation

  • Use a non-rotation-equivariant augmented normalizing flow, but train it with data augmentation
  • This flow will still use the zero-mean trick for translation equivariance
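The augmentation proposed in this issue can be sketched as follows: each training configuration is first centred (the zero-mean trick, which removes the translational degrees of freedom) and then rotated by a uniformly random rotation, so a non-equivariant flow sees all orientations during training. This is an illustrative pure-Python sketch, not the repo's code; it draws rotations uniformly from SO(3) by normalising a 4D Gaussian vector into a unit quaternion, and the helper names are hypothetical.

```python
"""Sketch: random-rotation data augmentation plus the zero-mean (centre-of-mass) trick."""
import math
import random
from typing import List, Tuple

Point = Tuple[float, float, float]


def random_rotation(rng: random.Random) -> List[List[float]]:
    """Sample a uniformly random 3x3 rotation matrix from a random unit quaternion."""
    q = [rng.gauss(0.0, 1.0) for _ in range(4)]
    norm = math.sqrt(sum(c * c for c in q))
    w, x, y, z = (c / norm for c in q)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def remove_mean(points: List[Point]) -> List[Point]:
    """Subtract the centre of mass so the configuration lies in the zero-mean subspace."""
    n = len(points)
    c = [sum(p[i] for p in points) / n for i in range(3)]
    return [tuple(p[i] - c[i] for i in range(3)) for p in points]


def augment(points: List[Point], rng: random.Random) -> List[Point]:
    """Centre the configuration, then apply one random rotation to every point."""
    r = random_rotation(rng)
    centred = remove_mean(points)
    # Rotation is linear, so the rotated configuration stays zero-mean.
    return [tuple(sum(r[i][j] * p[j] for j in range(3)) for i in range(3)) for p in centred]
```

Because rotations preserve pairwise distances, the augmented samples have the same energy under any rotation-invariant target, which is what makes this augmentation valid.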
