Contrastive Planning

Code for the paper Inference via Interpolation: Contrastive Representations Provably Enable Planning and Inference.

Installation

  1. Check that conda is installed and the mujoco200 binaries are on the library path (the commands below install whichever is missing):
test -z "$CONDA_PREFIX" && wget -O Miniforge3.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" && bash Miniforge3.sh -b -p "$HOME/conda" && rm Miniforge3.sh 
"$HOME/conda/bin/conda" init "$(basename $SHELL)"
test ! -e ~/.mujoco/mujoco200 && mkdir -p ~/.mujoco && wget -O mujoco200.zip https://www.roboti.us/download/mujoco200_$(test $(uname) = "Linux" && echo linux || echo macos).zip && unzip mujoco200.zip && rm mujoco200.zip && mv mujoco200_* ~/.mujoco/mujoco200
echo "$LD_LIBRARY_PATH" | grep -q mujoco200 || echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco200/bin' >> ~/.bashrc
  2. Clone the repo and build the environment:
git clone https://github.com/vivekmyers/contrastive_planning.git
cd contrastive_planning
conda env create
conda activate contrastive_planning
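
The installation steps above assume MuJoCo 2.0 ends up under ~/.mujoco/mujoco200 with its bin directory on LD_LIBRARY_PATH. A minimal Python sketch for checking that layout is shown below; the helper name check_mujoco is hypothetical and not part of this repository.

```python
import os
from pathlib import Path


def check_mujoco(home, ld_library_path):
    """Report whether the MuJoCo 2.0 layout assumed by the install commands exists.

    `check_mujoco` is a hypothetical helper, not part of the repository.
    """
    mujoco_dir = Path(home) / ".mujoco" / "mujoco200"
    bin_dir = mujoco_dir / "bin"
    # LD_LIBRARY_PATH is a separator-delimited list; empty entries are ignored.
    entries = [Path(p) for p in ld_library_path.split(os.pathsep) if p]
    return {
        "mujoco_dir_exists": mujoco_dir.is_dir(),
        "bin_on_ld_library_path": bin_dir in entries,
    }


if __name__ == "__main__":
    status = check_mujoco(Path.home(), os.environ.get("LD_LIBRARY_PATH", ""))
    for name, ok in status.items():
        print(f"{name}: {'ok' if ok else 'MISSING'}")
```

If either check reports MISSING in a fresh shell, re-running step 1 and opening a new terminal (so that ~/.bashrc is re-read) is a reasonable first thing to try.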

Running experiments

Run the following commands to train the method and baselines discussed in the paper for a single initialization seed and dataset shuffle:

python run_train.py --env maze2d-large-v1
python run_train.py --env door-human-v0

To evaluate success rates by distance and planning MSE for our method and baselines, run the following.

python run_eval.py --env maze2d-large-v1 --n_wypt 20 --rollout
python run_eval.py --env door-human-v0 --mse --n_wypt 5
python run_eval.py --env door-human-v0 --mse --n_wypt 1

The following commands will plot the results as shown in the paper.

python run_plot.py --env maze2d-large-v1 --rollout 
python run_plot.py --env door-human-v0 --waypoint_mse --n_wypt 5
python run_plot.py --env door-human-v0 --barplot --n_wypt 1 
python run_plot.py --env door-human-v0 --plot_plan 
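
The train, evaluate, and plot commands above are meant to be run in that order. As a rough sketch, the maze2d-large-v1 sequence could be chained from Python as follows; the command strings are copied from this README, while the names MAZE_PIPELINE and run_pipeline are hypothetical and not part of the repository.

```python
import shlex
import subprocess

# Commands copied verbatim from this README for the maze2d-large-v1 environment.
MAZE_PIPELINE = [
    "python run_train.py --env maze2d-large-v1",
    "python run_eval.py --env maze2d-large-v1 --n_wypt 20 --rollout",
    "python run_plot.py --env maze2d-large-v1 --rollout",
]


def run_pipeline(commands, dry_run=True):
    """Parse each command and, unless dry_run is set, run them in order.

    Hypothetical helper: with dry_run=True nothing is executed and the
    parsed argument lists are simply returned.
    """
    parsed = [shlex.split(cmd) for cmd in commands]
    if not dry_run:
        for argv in parsed:
            # check=True raises CalledProcessError if a stage exits non-zero,
            # so later stages never run on top of a failed one.
            subprocess.run(argv, check=True)
    return parsed


if __name__ == "__main__":
    for argv in run_pipeline(MAZE_PIPELINE):
        print(" ".join(argv))
```

Calling run_pipeline(MAZE_PIPELINE, dry_run=False) from the repository root, with the conda environment active, would execute the three stages in sequence.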

Reproducing Results

All quantitative results from the paper can be reproduced by running make all. This will train and evaluate 100 seeds by default.

contrastive_planning's Issues

Implementation of Uniformity Loss in Symmetrized InfoNCE?

Hi there, thanks for sharing the code from this very interesting paper!

In the uniformity term of the symmetrized InfoNCE loss

jax.nn.logsumexp(-(pdist * (1 - I)), axis=1)
+ jax.nn.logsumexp(-(pdist.T * (1 - I)), axis=1)

could I ask why you've multiplied the logits by (1 - I)?

Shouldn't the InfoNCE loss contain the positive term in both the numerator and denominator, yielding a uniformity term of

jax.nn.logsumexp(-pdist, axis=0) + jax.nn.logsumexp(-pdist, axis=1) 

?
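
For anyone comparing the two expressions, a small JAX sketch may help. It assumes pdist is a square matrix of non-negative pairwise distances between two batches of embeddings, which the snippet above does not state; the toy data and the function names are hypothetical, and this is not a statement of the authors' intent.

```python
import jax
import jax.numpy as jnp


def uniformity_masked(pdist):
    """Uniformity term as quoted in the question: diagonal logits are zeroed."""
    I = jnp.eye(pdist.shape[0])
    return (jax.nn.logsumexp(-(pdist * (1 - I)), axis=1)
            + jax.nn.logsumexp(-(pdist.T * (1 - I)), axis=1))


def uniformity_full(pdist):
    """Alternative proposed in the question: positives kept as-is in the sum."""
    return jax.nn.logsumexp(-pdist, axis=0) + jax.nn.logsumexp(-pdist, axis=1)


# Hypothetical toy batch: squared distances between two sets of 4 embeddings.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
phi = jax.random.normal(key_a, (4, 8))
psi = jax.random.normal(key_b, (4, 8))
pdist = jnp.sum((phi[:, None, :] - psi[None, :, :]) ** 2, axis=-1)

masked = uniformity_masked(pdist)
full = uniformity_full(pdist)
print("masked:", masked)
print("full:  ", full)

# With a zero diagonal the mask changes nothing, so the two terms coincide.
zero_diag = pdist * (1 - jnp.eye(4))
print(jnp.allclose(uniformity_masked(zero_diag), uniformity_full(zero_diag)))
```

Two things follow directly from the expressions themselves. Since logsumexp over axis=1 of the transpose equals logsumexp over axis=0, the two versions differ only on the diagonal. And multiplying by (1 - I) sets the diagonal logits to 0 rather than removing them, so each row and column sum still contains a constant exp(0) = 1 term in place of the positive pair's exp(-pdist[i, i]).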
