Distributional-RL

This repository is a comprehensive implementation of state-of-the-art deep distributional reinforcement learning algorithms. These algorithms were introduced as a series of successive improvements on one another and are mainly evaluated on the Atari benchmark.

Included algorithms, all available through a single API:

  • C51 (Categorical DQN)
  • IQN (Implicit Quantile Networks)
  • QRDQN (Quantile Regression DQN)
  • FQF (Fully parameterized Quantile Function)
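For orientation, the two update rules these algorithms build on can be sketched in NumPy: the categorical projection used by C51, and the quantile Huber loss used to train the quantile values in QR-DQN, IQN and FQF. This is a minimal, unbatched sketch, not the code in this repository (which uses PyTorch); the function names and the default `gamma`, `v_min`, `v_max` and `kappa` values are illustrative assumptions, not this repository's settings.

```python
import numpy as np


def categorical_projection(next_probs, reward, done, gamma=0.99, v_min=-10.0, v_max=10.0):
    """C51: project r + gamma * Z(s', a*) back onto the fixed support of n_atoms atoms.

    next_probs: (n_atoms,) probabilities of the target distribution at (s', a*).
    """
    n_atoms = next_probs.shape[0]
    support = np.linspace(v_min, v_max, n_atoms)
    delta_z = (v_max - v_min) / (n_atoms - 1)
    # Bellman-update every atom (no bootstrapping on terminal transitions), then clip to the range.
    tz = np.clip(reward + (1.0 - done) * gamma * support, v_min, v_max)
    # Fractional index of each updated atom; the clip guards against float round-off at the edges.
    b = np.clip((tz - v_min) / delta_z, 0, n_atoms - 1)
    lower = np.floor(b).astype(int)
    upper = np.ceil(b).astype(int)
    projected = np.zeros(n_atoms)
    # Split each atom's mass between its two neighbours in proportion to distance.
    # If b is an exact integer, lower == upper and the whole mass goes to that atom.
    np.add.at(projected, lower, next_probs * (upper - b + (lower == upper)))
    np.add.at(projected, upper, next_probs * (b - lower))
    return projected


def quantile_huber_loss(pred_quantiles, target_quantiles, taus, kappa=1.0):
    """Quantile Huber loss between N predicted quantiles and M target samples.

    pred_quantiles: (N,) values theta_i at quantile fractions taus (N,),
    e.g. the QR-DQN midpoints taus = (2 * np.arange(N) + 1) / (2 * N).
    target_quantiles: (M,) samples of r + gamma * Z(s', a*).
    """
    u = target_quantiles[None, :] - pred_quantiles[:, None]  # (N, M) pairwise TD errors
    abs_u = np.abs(u)
    huber = np.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))
    # The asymmetric weight |tau - 1{u < 0}| turns the Huber loss into a quantile regression loss.
    weight = np.abs(taus[:, None] - (u < 0).astype(float))
    # Mean over target samples, sum over predicted quantiles.
    return (weight * huber / kappa).mean(axis=1).sum()
```

The projection conserves probability mass, so for example `categorical_projection(np.full(51, 1 / 51), reward=1.0, done=False)` still sums to 1. IQN and FQF differ from QR-DQN mainly in how the fractions `taus` are chosen (sampled, or proposed by a network), not in the form of this loss.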

Demo

IQN Agent: [demo animation]

Results

Environment: AssaultNoFrameskip-v4

  • As highlighted in Section 4 of the FQF paper, the FQF algorithm is roughly 20% slower in runtime than the other methods. Thus, to compare each method's performance fairly against the others, I have plotted the performance metrics against processing (wall-clock) time rather than against the number of steps the agent takes in the environment.
  • The runtime was capped at 6 hours, the limit for free GPU-equipped machines on paperspace.com, which hosted the runs of the current code.
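Plotting against processing time means each run's metrics, logged at irregular wall-clock moments, have to be placed on a common time axis before runs can be compared or averaged. A minimal sketch using linear interpolation; the function name and its inputs are hypothetical and not part of this repository's logger.

```python
import numpy as np


def resample_on_time_grid(timestamps_s, values, grid_s):
    """Linearly interpolate a metric logged at irregular wall-clock times onto a shared grid.

    timestamps_s: elapsed seconds at which each value was logged.
    grid_s: common time points (seconds) shared by all runs being compared.
    """
    order = np.argsort(timestamps_s)
    t = np.asarray(timestamps_s, dtype=float)[order]
    v = np.asarray(values, dtype=float)[order]
    # Outside the logged range, return NaN rather than extrapolating.
    return np.interp(grid_s, t, v, left=np.nan, right=np.nan)
```

For the 6-hour budget, a grid such as `np.linspace(0, 6 * 3600, 361)` gives one point per minute; resampling every method onto that same grid lets a faster method's extra environment steps show up as a genuine advantage rather than being hidden by a step-based x-axis.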

Dependencies

  • gym == 0.19.0
  • numpy == 1.22.0
  • opencv_python == 4.5.5.62
  • psutil == 5.8.0
  • torch == 1.13
  • wandb == 0.12.9
  • protobuf == 3.20
  • gym[atari]
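Since training runs are long, it can be worth checking an environment against these pins before starting. A minimal sketch using the standard library's `importlib.metadata` (Python 3.8+); the helper names are hypothetical, and the prefix comparison is deliberately looser than pip's `==` (for example, the pin `1.13` accepts an installed `1.13.1+cu117`).

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the Dependencies list (gym[atari] is omitted because it has no pin).
PINNED = {
    "gym": "0.19.0",
    "numpy": "1.22.0",
    "opencv-python": "4.5.5.62",
    "psutil": "5.8.0",
    "torch": "1.13",
    "wandb": "0.12.9",
    "protobuf": "3.20",
}


def matches_pin(installed, pinned):
    """Loose check: each component of the pin must equal the matching installed component.

    Local version tags such as "+cu117" are ignored.
    """
    inst = installed.split("+")[0].split(".")
    pin = pinned.split(".")
    return inst[:len(pin)] == pin


def check_dependencies(pins=PINNED):
    """Return {package: problem} for every pinned package that is missing or mismatched."""
    problems = {}
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if not matches_pin(installed, pinned):
            problems[name] = f"installed {installed}, expected {pinned}"
    return problems
```

An empty dictionary from `check_dependencies()` means every pinned package is present with a compatible version.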

Usage

python main.py --agent_name="C51" --online_wandb --interval=100 --env_name="BreakoutNoFrameskip-v4"
usage: Choose your desired parameters [-h] [--agent_name AGENT_NAME]
                                      [--env_name ENV_NAME]
                                      [--mem_size MEM_SIZE] [--seed SEED]
                                      [--interval INTERVAL] [--do_test]
                                      [--online_wandb]

optional arguments:
  -h, --help            show this help message and exit
  --agent_name AGENT_NAME
                        Distributional method name
  --env_name ENV_NAME   Name of the environment.
  --mem_size MEM_SIZE   The memory size.
  --seed SEED           The random seed.
  --interval INTERVAL   The interval specifies how often different parameters
                        should be saved and printed, counted by the number of
                        episodes.
  --do_test             The flag determines whether to train the agent or play
                        with it.
  --online_wandb        Run wandb in online mode.
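The options above map naturally onto Python's `argparse`. The sketch below mirrors the listed flags; it is not the repository's own parser. The default values are assumptions, and restricting `--agent_name` with `choices` (using the accepted values listed under Considerations) is an addition; `metavar` keeps the help text showing `AGENT_NAME`.

```python
import argparse

AGENTS = ("C51", "IQN", "QRDQN", "FQF")


def build_parser():
    # Passing a string as the first positional argument sets `prog`, which is why
    # the usage line reads "usage: Choose your desired parameters ...".
    parser = argparse.ArgumentParser("Choose your desired parameters")
    parser.add_argument("--agent_name", type=str, choices=AGENTS, default="C51",
                        metavar="AGENT_NAME", help="Distributional method name")
    parser.add_argument("--env_name", type=str, default="BreakoutNoFrameskip-v4",
                        help="Name of the environment.")
    # Hypothetical defaults: the real values used by this repository are not stated here.
    parser.add_argument("--mem_size", type=int, default=100_000, help="The memory size.")
    parser.add_argument("--seed", type=int, default=123, help="The random seed.")
    parser.add_argument("--interval", type=int, default=100,
                        help="How often (in episodes) parameters are saved and printed.")
    # store_true flags are False unless given on the command line.
    parser.add_argument("--do_test", action="store_true",
                        help="Play with a trained agent instead of training it.")
    parser.add_argument("--online_wandb", action="store_true",
                        help="Run wandb in online mode.")
    return parser
```

For example, `build_parser().parse_args(["--agent_name", "IQN", "--do_test"])` yields `agent_name="IQN"` with testing enabled, while an unknown agent name makes argparse print an error and exit.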

Additional Details

  • Training takes a significant amount of time, so running it on GPU-equipped machines is recommended.
  • With access to wandb, your run information (including graphs) is saved online.
  • Use nohup (or something similar) to execute it in the background.
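If you prefer to start runs from Python rather than with `nohup`, a roughly equivalent launcher can be sketched with `subprocess`. This is a minimal POSIX sketch; the function name and log file name are hypothetical.

```python
import subprocess
import sys


def launch_in_background(args, script="main.py", log_path="train.log"):
    """Start `python <script> <args>` detached from the terminal, similar to
    `nohup python main.py ... > train.log 2>&1 &`. POSIX only (start_new_session)."""
    log = open(log_path, "ab")
    proc = subprocess.Popen(
        [sys.executable, script, *args],
        stdout=log,
        stderr=subprocess.STDOUT,  # merge stderr into the same log, like 2>&1
        stdin=subprocess.DEVNULL,
        start_new_session=True,  # setsid(): closing the terminal does not send SIGHUP to the run
    )
    log.close()  # the child process keeps its own copy of the file descriptor
    return proc
```

For example, `launch_in_background(["--agent_name=C51", "--online_wandb"])` returns immediately, while the training output accumulates in `train.log`.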

Considerations

  • Accepted values for agent_name: {"C51", "IQN", "QRDQN", "FQF"}.
  • When testing (--do_test), the code by default loads the weights of the latest run in the weights directory, so make sure to put your desired weights in the appropriate folder inside the weights directory! 👇

common/logger.py:

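import glob
import os
import torch

def load_weights(self):
    # Each training run stores its checkpoint in its own subfolder of weights/.
    model_dir = sorted(glob.glob(os.path.join("weights", "*")))
    if not model_dir:
        raise FileNotFoundError("No runs found in the weights directory.")
    # model_dir[-1] -> the latest run (last in sorted order)!
    checkpoint = torch.load(os.path.join(model_dir[-1], "params.pth"))
    self.log_dir = os.path.basename(model_dir[-1])
    return checkpoint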
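Rather than moving folders around so that the desired run sorts last, the lookup could take an explicit run name. A minimal sketch of such a helper; `find_checkpoint` and its `run_name` parameter are hypothetical, and, like `load_weights`, it assumes run folders are named so that sorted order matches chronological order (for example, timestamps).

```python
import glob
import os


def find_checkpoint(weights_dir="weights", run_name=None):
    """Return the path to params.pth for `run_name`, or for the latest run if None."""
    if run_name is not None:
        path = os.path.join(weights_dir, run_name, "params.pth")
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        return path
    # Same rule as load_weights: the last folder in sorted order is the latest run.
    runs = sorted(glob.glob(os.path.join(weights_dir, "*")))
    if not runs:
        raise FileNotFoundError(f"no runs under {weights_dir!r}")
    return os.path.join(runs[-1], "params.pth")
```

The returned path can be passed to `torch.load(path, map_location="cpu")`; `map_location="cpu"` also lets weights trained on a GPU be loaded on a CPU-only machine.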

References

  1. A Distributional Perspective on Reinforcement Learning, Bellemare et al., 2017
  2. Implicit Quantile Networks for Distributional Reinforcement Learning, Dabney et al., 2018
  3. Distributional Reinforcement Learning with Quantile Regression, Dabney et al., 2017
  4. Fully Parameterized Quantile Function for Distributional Reinforcement Learning, Yang et al., 2019
  5. Distributional Reinforcement Learning (draft), Bellemare et al., 2022

Acknowledgement

The following repositories were great guides for implementing distributional RL ideas. Big thanks to their authors for their work:

  1. DeepRL_PyTorch by @Kchu
  2. FQF by @microsoft
  3. fqf-iqn-qrdqn.pytorch by @ku2482


Contributors

alirezakazemipour, adi-vc
