Amortized Causal Discovery

This repo contains the official PyTorch implementation of:

Sindy Löwe*, David Madras*, Richard Zemel, Max Welling - Amortized Causal Discovery: Learning to Infer Causal Graphs from Time-Series Data

With Amortized Causal Discovery, we learn to infer causal relations from samples with different underlying causal graphs but shared dynamics. This enables us to generalize across samples and thus to improve performance as the training set grows.

*equal contribution

What is Amortized Causal Discovery?

With Amortized Causal Discovery, we separate causal relation prediction from dynamics modelling. Our amortized encoder learns to infer causal relations across samples with different underlying graphs. Our decoder learns to model the shared dynamics of the predicted relations.

This separation allows us to train a joint model for samples with different underlying causal graphs. This is in contrast to previous approaches, which need to refit a new model whenever they encounter samples with a different underlying causal graph.
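To make the encoder/decoder split concrete, here is a toy, dependency-free Python sketch. It is not the repo's PyTorch code: the linear dynamics, the DECAY and COUPLING constants and all function names are hypothetical. The decoder is one dynamics model shared by every sample, while each sample carries its own graph. In ACD the encoder is a neural network trained to output the graph directly; here an exhaustive search for the graph under which the shared decoder predicts best stands in for it, which is only feasible for a handful of variables.

```python
from itertools import product

DECAY, COUPLING = 0.5, 0.4  # hypothetical parameters of the shared dynamics


def decoder(x, graph):
    """Shared dynamics model: predict the next state from the current state and
    a per-sample causal graph, where graph[i][j] == 1 means i -> j."""
    n = len(x)
    return [
        DECAY * x[j] + COUPLING * sum(graph[i][j] * x[i] for i in range(n))
        for j in range(n)
    ]


def simulate(graph, x0, steps):
    """Generate a noiseless time series (steps + 1 states) from the shared dynamics."""
    traj = [list(x0)]
    for _ in range(steps):
        traj.append(decoder(traj[-1], graph))
    return traj


def prediction_error(traj, graph):
    """Sum of squared one-step prediction errors of the decoder under `graph`."""
    return sum(
        (p - y) ** 2
        for x, nxt in zip(traj, traj[1:])
        for p, y in zip(decoder(x, graph), nxt)
    )


def all_graphs(n):
    """Enumerate every directed graph on n nodes without self-loops."""
    off_diag = [(i, j) for i in range(n) for j in range(n) if i != j]
    for bits in product((0, 1), repeat=len(off_diag)):
        g = [[0] * n for _ in range(n)]
        for (i, j), b in zip(off_diag, bits):
            g[i][j] = b
        yield g


def infer_graph(traj):
    """Stand-in for the amortized encoder: brute-force search (not learned) for
    the graph under which the shared decoder explains this sample best."""
    n = len(traj[0])
    return min(all_graphs(n), key=lambda g: prediction_error(traj, g))


# Two samples with different graphs but the same dynamics.
chain = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]        # 0 -> 1 -> 2
single_edge = [[0, 0, 0], [0, 0, 0], [1, 0, 0]]  # 2 -> 0
sample_a = simulate(chain, [1.0, -1.0, 0.5], steps=10)
sample_b = simulate(single_edge, [0.3, 0.8, -0.6], steps=10)
print(infer_graph(sample_a) == chain, infer_graph(sample_b) == single_edge)  # → True True
```

Because the decoder is shared, one function explains both samples even though their graphs differ. Amortization replaces the per-sample search in infer_graph with a single learned mapping from time series to graph.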

What we found exciting is that this lets causal discovery performance improve substantially as the training set grows. Amortized Causal Discovery (ACD) outperforms previous causal discovery approaches with as few as 50 training samples; with 50,000 samples, it outperforms them by more than 30 percentage points.

How to run the code

Dependencies

  • Python and Conda

  • Set up the conda environment ACD by running:

    bash setup_dependencies.sh
  • Before working with the code, activate the environment and cd into the codebase directory:

    source activate ACD
    cd codebase

Datasets

  • To generate the particles with springs dataset from our paper, run

    python -m data.generate_dataset
  • To generate a particles dataset with varying latent temperature, run

    python -m data.generate_dataset --temperature_dist --temperature_alpha 2 --temperature_num_cats 3
  • To generate the Kuramoto dataset from our paper, run

    python -m data.generate_ODE_dataset
  • The Netsim dataset is available here
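For intuition about what the springs data look like, here is a hypothetical, dependency-free sketch; it is not the code behind data.generate_dataset. Particles of unit mass move in 2D, and each pair joined by an edge of the graph interacts through a Hooke's-law spring. The initial-state distribution, the integrator, and the assumption that a latent temperature acts by scaling every spring constant (the strength parameter) are all ours.

```python
import random


def simulate_springs(graph, steps=100, dt=0.01, strength=1.0, seed=0):
    """Toy 2D particles-with-springs simulator (hypothetical, not the repo's).

    graph[i][j] == 1 means particles i and j are joined by a spring; springs
    are undirected here, so graph must be symmetric. `strength` scales every
    spring constant and stands in for a latent temperature. Particles have
    unit mass. Returns `steps` frames (the initial state is not included),
    each a list of [x, y] positions.
    """
    n = len(graph)
    assert all(graph[i][j] == graph[j][i] for i in range(n) for j in range(n))
    rng = random.Random(seed)
    pos = [[rng.gauss(0, 0.5), rng.gauss(0, 0.5)] for _ in range(n)]
    vel = [[rng.gauss(0, 0.5), rng.gauss(0, 0.5)] for _ in range(n)]
    traj = []
    for _ in range(steps):
        # Hooke's law: F_i = -strength * sum_j graph[i][j] * (pos_i - pos_j)
        force = [
            [-strength * sum(graph[i][j] * (pos[i][d] - pos[j][d]) for j in range(n))
             for d in range(2)]
            for i in range(n)
        ]
        # Semi-implicit Euler: update velocities first, then positions.
        for i in range(n):
            for d in range(2):
                vel[i][d] += dt * force[i][d]
                pos[i][d] += dt * vel[i][d]
        traj.append([p[:] for p in pos])
    return traj


# Same graph, two latent strengths: the observed trajectories differ even
# though the causal graph does not.
ring = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]
weak = simulate_springs(ring, steps=50, strength=0.5)
strong = simulate_springs(ring, steps=50, strength=2.0)
print(weak[-1][0], strong[-1][0])
```

This is why an unobserved temperature acts as a confounder: two samples with identical graphs can produce different dynamics, so a model that ignores the temperature has to absorb that variation elsewhere.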

Experiments

  • Run the Springs experiment with

    python -m train --suffix _springs5

    the Kuramoto experiment with

    python -m train --suffix _kuramoto5 --encoder cnn

    and the Netsim experiment with

    python -m train --suffix netsim
  • To run the experiment with an unobserved temperature variable, run

    python -m train --suffix _springs5 --encoder cnn --decoder sim --global_temp --load_temperatures
  • To run the experiment with an unobserved time-series, run

    python -m train --suffix _springs5 --unobserved 1
  • View all possible command-line options by running

    python -m train --help

Cite

Please cite our paper if you use this code in your own work:

@article{lowe2020amortized,
  title={Amortized Causal Discovery: Learning to Infer Causal Graphs from Time-Series Data},
  author={L{\"o}we, Sindy and Madras, David and Zemel, Richard and Welling, Max},
  journal={arXiv preprint},
  year={2020}
}

Acknowledgements

The Robert Bosch GmbH is acknowledged for financial support.
