Deep RL with Resets for Addressing the Primacy Bias

This repository contains a JAX implementation of the resetting mechanism from the paper

The Primacy Bias in Deep Reinforcement Learning

by Evgenii Nikishin*, Max Schwarzer*, Pierluca D'Oro*, Pierre-Luc Bacon, and Aaron Courville.

Summary

The paper identifies a common flaw of deep RL algorithms called the primacy bias: a tendency to overfit early experiences that damages the rest of the learning process. An agent affected by the primacy bias tends to be incapable of leveraging subsequent data because of the accumulated effect of this initial overfitting. As a remedy, we propose a resetting mechanism that lets an agent forget part of its knowledge by periodically re-initializing the last few layers of its networks while preserving the replay buffer. Applying resets to the SAC, DrQ, and SPR algorithms on DeepMind Control tasks and the Atari 100k benchmark alleviates the effects of the primacy bias and consistently improves the agents' performance.
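
To make the mechanism concrete, the snippet below sketches what re-initializing the last layers of a network looks like in Flax while the replay buffer is left untouched. It is an illustration under assumed names (the MLP module, reset_last_layers, and the layer names are placeholders), not the repository's actual reset code, which differs per algorithm.

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    # A toy network trunk; layer names are placeholders for illustration.
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(256, name='hidden_0')(x))
        x = nn.relu(nn.Dense(256, name='hidden_1')(x))
        return nn.Dense(1, name='output')(x)

def reset_last_layers(params, fresh_params, layers=('hidden_1', 'output')):
    # Keep the early layers, swap in freshly initialized parameters for the last ones.
    new_params = dict(params)
    for name in layers:
        new_params[name] = fresh_params[name]
    return new_params

model = MLP()
dummy_input = jnp.zeros((1, 8))
params = model.init(jax.random.PRNGKey(0), dummy_input)['params']
fresh = model.init(jax.random.PRNGKey(1), dummy_input)['params']   # new random init
params = reset_last_layers(params, fresh)
# The replay buffer and all collected experience are kept as-is; only the network
# parameters (and, in practice, the corresponding optimizer state) are re-initialized.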

Please cite our work if you find it useful in your research:

@inproceedings{nikishin2022primacy,
  title={The Primacy Bias in Deep Reinforcement Learning},
  author={Nikishin, Evgenii and Schwarzer, Max and D'Oro, Pierluca and Bacon, Pierre-Luc and Courville, Aaron},
  booktitle={International Conference on Machine Learning},
  year={2022},
  organization={PMLR}
}

Instructions

Discrete and continuous control experiments use two different codebases.

DeepMind Control Suite

Install the necessary dependencies for the SAC and DrQ algorithms using continuous_control_requirements.txt. To train a continuous control agent with resets on a DMC task, use one of the following example commands:

python train_dense.py --env_name quadruped-run --max_steps 2_000_000 --config.updates_per_step 3 --resets --reset_interval 200_000
MUJOCO_GL=egl python train_pixels.py --env_name quadruped-run --max_steps 2_000_000 --resets --reset_interval 100_000  # with action repeat, the effective interval in environment steps is higher

Note that lines 107-111 in train_dense.py and 141-175 in train_pixels.py are the only code modifications needed to equip the SAC and DrQ agents with resets.
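
In essence, those modifications add a periodic reset to the training loop. The sketch below conveys the idea only; the Learner and ReplayBuffer stubs, their arguments, and the hyperparameter values are assumed placeholders rather than the repository's exact code.

import numpy as np

class Learner:
    # Hypothetical stand-in for the SAC/DrQ agent; a real learner holds networks
    # and optimizer state that are re-created on reset.
    def __init__(self, seed, observations, actions):
        self.rng = np.random.default_rng(seed)

class ReplayBuffer(list):
    # Hypothetical stand-in; the buffer is never cleared on reset.
    pass

max_steps, reset_interval, resets, seed = 1_000_000, 200_000, True, 42
agent = Learner(seed, observations=None, actions=None)
replay_buffer = ReplayBuffer()

for step in range(1, max_steps + 1):
    # ... interact with the environment, store the transition, update the agent ...
    if resets and step % reset_interval == 0:
        # Re-initialize the agent from a fresh seed; replay_buffer is deliberately kept.
        agent = Learner(seed + step, observations=None, actions=None)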

Atari 100k

To set up discrete control experiments, first create a Python 3.9 environment and run the following command to install the dependencies:

# Install the dependencies; the -f flag lets pip resolve jax/jaxlib wheels from the JAX releases index
pip install --no-cache-dir -f https://storage.googleapis.com/jax-releases/jax_releases.html -r ./discrete_control_requirements.txt

To train an SPR agent without resets, run:

python -m discrete_control.train --base_dir ./test_dir/ \
 --gin_files discrete_control/configs/SPR.gin \
 --gin_bindings='atari_lib.create_atari_environment.game_name = "Pong"' \
 --run_number 1

To train an SPR agent with default reset hyperparameters, run:

python -m discrete_control.train --base_dir ./test_dir/ \
 --gin_files discrete_control/configs/SPR_with_resets.gin \
 --gin_bindings='atari_lib.create_atari_environment.game_name = "Pong"' \
 --run_number 1

To train an SPR agent with fully customized reset hyperparameters, the following template may be used (the game name and reset settings are supplied through shell variables):

python -m discrete_control.train --run_number ${seed} --base_dir ${BASE_DIR}_${seed} "$@" \
     --gin_files discrete_control/configs/SPR_with_resets.gin \
     --gin_bindings='atari_lib.create_atari_environment.game_name = '"\"${map[${f}]}\"" \
     --tag "SPR_resets_${reset_every}_${reset_updates}_${reset_offset}_${resets}_n${nstep}_rr${replay_ratio}" \
     --gin_bindings='JaxSPRAgent.reset_every = '"\"${reset_every}\"" \
     --gin_bindings='JaxSPRAgent.updates_on_reset = '"\"${reset_updates}\"" \
     --gin_bindings='JaxSPRAgent.total_resets = '"\"${resets}\"" \
     --gin_bindings='JaxSPRAgent.reset_offset = '"\"${reset_offset}\"" \
     --gin_bindings='JaxSPRAgent.reset_projection = '"\"${reset_proj}\"" \
     --gin_bindings='JaxSPRAgent.reset_noise = '"\"${reset_noise}\"" \
     --gin_bindings='JaxSPRAgent.reset_encoder = '"\"${reset_encoder}\"" \
     --gin_bindings='JaxDQNAgent.update_horizon = '"${nstep}" \
     --gin_bindings='JaxSPRAgent.replay_ratio = '"${replay_ratio}"  

Results

For SAC and DrQ, the numbers are aggregate episode returns on the DeepMind Control Suite; for SPR, they are human-normalized scores on Atari 100k. Parentheses show bootstrap confidence intervals computed with rliable (see Acknowledgements).

| Method | IQM | Median | Mean |
| --- | --- | --- | --- |
| SAC + resets | 656 (549, 753) | 617 (538, 681) | 607 (547, 667) |
| SAC | 501 (389, 609) | 475 (407, 563) | 484 (420, 548) |
| DrQ + resets | 762 (704, 815) | 680 (625, 731) | 677 (632, 720) |
| DrQ | 569 (475, 662) | 521 (470, 600) | 535 (481, 589) |
| SPR + resets | 0.48 (0.46, 0.51) | 0.51 (0.42, 0.57) | 0.91 (0.84, 1.00) |
| SPR | 0.38 (0.36, 0.39) | 0.43 (0.38, 0.48) | 0.57 (0.56, 0.60) |

Training curves for all agents and environments with and without resets are available in the curves folder.

Acknowledgements

  • Our code for continuous control experiments is based on the JAXRL implementation of SAC and DrQ
  • The implementation of the SPR algorithm uses Dopamine
  • We aggregate scores across tasks following the rliable recommendations for evaluating RL algorithms (a minimal usage sketch is shown below)
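
For reference, the aggregate metrics in the Results table can be computed along these lines with rliable; the score matrices below are random placeholders, not our results.

import numpy as np
from rliable import library as rly
from rliable import metrics

# Placeholder scores: one (num_runs x num_tasks) matrix per method.
score_dict = {
    'SAC + resets': np.random.rand(10, 15),
    'SAC': np.random.rand(10, 15),
}

def aggregate(scores):
    # IQM, median, and mean across runs and tasks, as reported in the Results table.
    return np.array([metrics.aggregate_iqm(scores),
                     metrics.aggregate_median(scores),
                     metrics.aggregate_mean(scores)])

point_estimates, interval_estimates = rly.get_interval_estimates(
    score_dict, aggregate, reps=2000)  # stratified bootstrap confidence intervals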
