

Predictive Information Accelerates Learning in RL

Kuang-Huei Lee, Ian Fischer, Anthony Liu, Yijie Guo, Honglak Lee, John Canny, Sergio Guadarrama

NeurIPS 2020

[Demo videos: cheetah, walker, ball in cup, cartpole, finger]

This repository hosts the open source implementation of PI-SAC, the reinforcement learning agent introduced in Predictive Information Accelerates Learning in RL. PI-SAC combines the Soft Actor-Critic Agent with an additional objective that learns compressive representations of predictive information. PI-SAC agents can substantially improve sample efficiency and returns over challenging baselines on tasks from the DeepMind Control Suite of vision-based continuous control environments, where observations are pixels.

If you find this useful for your research, please cite it as follows:

@article{lee2020predictive,
  title={Predictive Information Accelerates Learning in RL},
  author={Lee, Kuang-Huei and Fischer, Ian and Liu, Anthony and Guo, Yijie and Lee, Honglak and Canny, John and Guadarrama, Sergio},
  journal={arXiv preprint arXiv:2007.12401},
  year={2020}
}

Methods


In addition to actor and critic learning, PI-SAC learns compact representations of the predictive information I(X_past;Y_future), which captures the environment's transition dynamics. We capture the predictive information in a representation Z by maximizing I(Y_future;Z) and minimizing I(X_past;Z|Y_future). Minimizing the second term compresses out the non-predictive part for better generalization, which is reflected in better sample efficiency, returns, and transferability. When interacting with the environment, the agent simply executes the actor model.
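The two terms above can be combined into a single objective. The following is a sketch in the form of the Conditional Entropy Bottleneck; the trade-off weight \beta is not defined in this README and is introduced here as an assumption, so consult the paper for the exact formulation and the bounds used in practice.

```latex
\min_{Z} \;\; \beta \, I(X_{\text{past}}; Z \mid Y_{\text{future}}) \;-\; I(Y_{\text{future}}; Z)
```

A larger \beta puts more weight on compressing away information in Z that does not help predict the future; a smaller \beta puts more weight on retaining predictive information.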


Training and Evaluation

To train the model(s) in the paper with periodic evaluation, run this command:

python -m pisac.run --root_dir=/tmp/pisac_cartpole_swingup \
--gin_file=pisac/config/pisac.gin \
--gin_bindings=train_pisac.train_eval.domain_name=\'cartpole\' \
--gin_bindings=train_pisac.train_eval.task_name=\'swingup\' \
--gin_bindings=train_pisac.train_eval.action_repeat=4 \
--gin_bindings=train_pisac.train_eval.initial_collect_steps=1000 \
--gin_bindings=train_pisac.train_eval.initial_feature_step=5000

We use gin to configure hyperparameters. The default configs are specified in pisac/config/pisac.gin. To reproduce the main DM-Control experiments, set domain_name, task_name, action_repeat, initial_collect_steps, and initial_feature_step for each environment as listed below.

domain_name  task_name       action_repeat  initial_collect_steps  initial_feature_step
cartpole     swingup         4              1000                   5000
cartpole     balance_sparse  2              1000                   5000
reacher      easy            4              1000                   5000
ball_in_cup  catch           4              1000                   5000
finger       spin            1              10000                  0
cheetah      run             4              10000                  10000
walker       walk            2              10000                  10000
walker       stand           2              10000                  10000
hopper       stand           2              10000                  10000
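
Typing the five bindings by hand for every environment is error-prone, so the table can be turned into a small command builder. The following is a minimal sketch, not part of the repository: the names ENV_SETTINGS and build_command are hypothetical, and the root_dir pattern for environments other than cartpole swingup is an assumption extrapolated from the example command above.

```python
import shlex

# Settings copied from the table above, keyed by (domain_name, task_name):
# (action_repeat, initial_collect_steps, initial_feature_step)
ENV_SETTINGS = {
    ("cartpole", "swingup"): (4, 1000, 5000),
    ("cartpole", "balance_sparse"): (2, 1000, 5000),
    ("reacher", "easy"): (4, 1000, 5000),
    ("ball_in_cup", "catch"): (4, 1000, 5000),
    ("finger", "spin"): (1, 10000, 0),
    ("cheetah", "run"): (4, 10000, 10000),
    ("walker", "walk"): (2, 10000, 10000),
    ("walker", "stand"): (2, 10000, 10000),
    ("hopper", "stand"): (2, 10000, 10000),
}

PREFIX = "train_pisac.train_eval"


def build_command(domain_name, task_name, root_dir=None):
    """Return the training command as an argv list for one environment."""
    action_repeat, collect_steps, feature_step = ENV_SETTINGS[(domain_name, task_name)]
    if root_dir is None:
        # Assumed naming pattern, mirroring /tmp/pisac_cartpole_swingup.
        root_dir = f"/tmp/pisac_{domain_name}_{task_name}"
    bindings = [
        # String values must reach gin wrapped in quotes.
        f"{PREFIX}.domain_name='{domain_name}'",
        f"{PREFIX}.task_name='{task_name}'",
        f"{PREFIX}.action_repeat={action_repeat}",
        f"{PREFIX}.initial_collect_steps={collect_steps}",
        f"{PREFIX}.initial_feature_step={feature_step}",
    ]
    argv = [
        "python", "-m", "pisac.run",
        f"--root_dir={root_dir}",
        "--gin_file=pisac/config/pisac.gin",
    ]
    argv += [f"--gin_bindings={binding}" for binding in bindings]
    return argv


if __name__ == "__main__":
    # Print one shell-quoted command per environment in the table.
    for domain, task in ENV_SETTINGS:
        print(shlex.join(build_command(domain, task)))
```

The function returns an argv list rather than a string so it can be passed directly to subprocess.run without shell quoting; shlex.join is used only to print a copy-pasteable command.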

To use multiple gradient steps per environment step, change train_pisac.train_eval.collect_every to a number larger than 1.

Results

DeepMind Control Suite

[Figure: pisac_full2]

*gs: number of gradient steps per environment step

Requirements

The PI-SAC code uses Python 3 and these packages:

  • tensorflow-gpu==2.3.0
  • tf_agents==0.6.0
  • tensorflow_probability
  • dm_control (egl rendering option recommended)
  • gym
  • imageio
  • matplotlib
  • scikit-image
  • scipy
  • gin
  • pstar
  • qj
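
The EGL rendering option recommended for dm_control above is selected through the MUJOCO_GL environment variable, which must be set before dm_control is imported. The following is a minimal sketch; the helper name configure_rendering is hypothetical and not part of this repository.

```python
import os

# Rendering backends that dm_control accepts via MUJOCO_GL.
ALLOWED_BACKENDS = {"egl", "glfw", "osmesa"}


def configure_rendering(backend="egl"):
    """Set MUJOCO_GL unless the user already chose a backend.

    Call this before importing dm_control, because the backend is
    picked up when the rendering module is first loaded.
    """
    if backend not in ALLOWED_BACKENDS:
        raise ValueError(f"unknown rendering backend: {backend!r}")
    # setdefault keeps any value already exported in the shell.
    os.environ.setdefault("MUJOCO_GL", backend)
    return os.environ["MUJOCO_GL"]
```

Exporting the variable in the shell (MUJOCO_GL=egl) before running the training command has the same effect.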

If dm_control reports threading issues, try adding --gin_bindings=train_pisac.train_eval.drivers_in_graph=False to run the dm_control environment outside of the TensorFlow graph.

Disclaimer: This is not an official Google product.
