Disentanglement via Mechanism Sparsity Regularization: A New Principle for Nonlinear ICA.

By Sébastien Lachapelle, Pau Rodríguez López, Yash Sharma, Katie Everett, Rémi Le Priol, Alexandre Lacoste, Simon Lacoste-Julien

This repository contains the code used to run the experiments in the paper "Disentanglement via Mechanism Sparsity Regularization: A New Principle for Nonlinear ICA".

Environment:

Tested on Python 3.7.

See requirements.txt for the dependencies.

Time-sparsity experiment

OUTPUT_DIR=<where to save experiment>
DATAROOT=<where data is located>
python disentanglement_via_mechanism_sparsity/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --mode vae --dataset toy-nn/temporal_sparsity_non_trivial --freeze_g --freeze_gc --z_dim 10 --gt_z_dim 10 --gt_x_dim 20 --n_lag 1 --full_seq --time_limit 3

Action-sparsity experiment

OUTPUT_DIR=<where to save experiment>
DATAROOT=<where data is located>
python disentanglement_via_mechanism_sparsity/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --mode vae --dataset toy-nn/action_sparsity_non_trivial --freeze_g --freeze_gc --z_dim 10 --gt_z_dim 10 --gt_x_dim 20 --n_lag 0 --time_limit 3

Regularization

In the minimal commands above, all regularization is deactivated via the --freeze_g and --freeze_gc flags. To activate regularization for, say, the temporal mask G^z, replace --freeze_g with --g_reg_coeff COEFF_VALUE. The same syntax applies to the action mask G^a (named gc in the code). The mask names in the code (left) correspond to the names in the paper (right) as follows:

g = Mask G^z (Time sparsity)

gc = Mask G^a (Action sparsity)
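
As a minimal sketch of the substitution described above, the following takes the time-sparsity command and swaps --freeze_g for --g_reg_coeff, keeping --freeze_gc so the action mask stays unregularized. The paths and the coefficient value 0.01 are placeholder assumptions for illustration, not values taken from the paper. The command is printed rather than executed so it can be inspected first.

```shell
# Placeholder paths and coefficient: hypothetical values, replace with your own.
OUTPUT_DIR="./experiments/time_sparsity_reg"
DATAROOT="./data"
G_REG_COEFF="0.01"  # illustrative only; not a value taken from the paper

# Same command as the time-sparsity experiment, with --freeze_g replaced by
# --g_reg_coeff. --freeze_gc is kept, so the action mask is not regularized.
CMD="python disentanglement_via_mechanism_sparsity/train.py \
--dataroot $DATAROOT --output_dir $OUTPUT_DIR --mode vae \
--dataset toy-nn/temporal_sparsity_non_trivial \
--g_reg_coeff $G_REG_COEFF --freeze_gc \
--z_dim 10 --gt_z_dim 10 --gt_x_dim 20 --n_lag 1 --full_seq --time_limit 3"

# Print the command for inspection; run it with: eval "$CMD"
echo "$CMD"
```

A suitable coefficient likely depends on the dataset, so it is worth trying several values.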

Synthetic datasets

For synthetic data (--dataset toy-*), the data is generated before training, so there is nothing to download. The following datasets are used in the paper:

  • toy-nn/temporal_sparsity_trivial
  • toy-nn/temporal_sparsity_non_trivial
  • toy-nn/temporal_sparsity_non_trivial_no_graph_crit
  • toy-nn/temporal_sparsity_non_trivial_no_suff_var
  • toy-nn/action_sparsity_trivial
  • toy-nn/action_sparsity_non_trivial
  • toy-nn/action_sparsity_non_trivial_no_graph_crit
  • toy-nn/action_sparsity_non_trivial_no_suff_var
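
To cover several of these datasets in one go, a loop along the following lines may help. It is only a sketch: it assumes the four temporal datasets accept the same flags as the time-sparsity command above (in particular the same dimensions), which is not stated here, and the paths are hypothetical placeholders. Each command is printed rather than executed.

```shell
# Placeholder locations: hypothetical values, replace with your own.
DATAROOT="./data"
OUTPUT_ROOT="./experiments"

N_CMDS=0
for NAME in temporal_sparsity_trivial \
            temporal_sparsity_non_trivial \
            temporal_sparsity_non_trivial_no_graph_crit \
            temporal_sparsity_non_trivial_no_suff_var; do
  # Assumption: every temporal dataset takes the same flags as the
  # time-sparsity command; adjust the dimensions if that is not the case.
  CMD="python disentanglement_via_mechanism_sparsity/train.py \
--dataroot $DATAROOT --output_dir $OUTPUT_ROOT/$NAME --mode vae \
--dataset toy-nn/$NAME --freeze_g --freeze_gc \
--z_dim 10 --gt_z_dim 10 --gt_x_dim 20 --n_lag 1 --full_seq --time_limit 3"

  # Print the command for inspection; run it with: eval "$CMD"
  echo "$CMD"
  N_CMDS=$((N_CMDS + 1))
done
```

Each run writes to its own subdirectory of OUTPUT_ROOT so that results from different datasets do not overwrite each other.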

Baselines

TCVAE

Code adapted from: https://github.com/rtqichen/beta-tcvae

OUTPUT_DIR=<where to save experiment>
DATAROOT=<where data is located>
python disentanglement_via_mechanism_sparsity/baseline_models/beta-tcvae/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/action_sparsity_non_trivial --tcvae --beta 1 --gt_z_dim 10 --gt_x_dim 20 --time_limit 3

iVAE

Code adapted from: https://github.com/ilkhem/icebeem

OUTPUT_DIR=<where to save experiment>
DATAROOT=<where data is located>
python disentanglement_via_mechanism_sparsity/baseline_models/icebeem/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/action_sparsity_non_trivial --method ivae --gt_z_dim 10 --gt_x_dim 20 --time_limit 3

SlowVAE

Code adapted from: https://github.com/bethgelab/slow_disentanglement

OUTPUT_DIR=<where to save experiment>
DATAROOT=<where data is located>
python disentanglement_via_mechanism_sparsity/baseline_models/slowvae_pcl/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/temporal_sparsity_non_trivial --gt_z_dim 10 --gt_x_dim 20 --time_limit 3

PCL

Code adapted from: https://github.com/bethgelab/slow_disentanglement/tree/baselines

OUTPUT_DIR=<where to save experiment>
DATAROOT=<where data is located>
python disentanglement_via_mechanism_sparsity/baseline_models/slowvae_pcl/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/temporal_sparsity_non_trivial --pcl --r_func mlp --gt_z_dim 10 --gt_x_dim 20 --time_limit 3
