d3q :: Q-Learning Showcase

This project contains a demonstration framework of Q-learning (reinforcement learning) written in Python. It takes advantage of TensorFlow to address one of the simplest OpenAI Gym problems: the Cart Pole.

⚠ Warning: The foundations of this project were laid for a hackathon, to demonstrate an ML system utilizing all 8 hardware accelerators and all CPU cores of a modern, ML-dedicated server. If you are looking for a Reinforcement Learning framework, this is not the right project: you are probably looking for TensorFlow Agents or Ray RLlib.

To start, just clone the sources and run:

# one-time setup
python --version  # should be 3.7-3.10
pip install virtualenv
python -m virtualenv venv
source venv/bin/activate  # on Windows: .\venv\Scripts\activate
pip install -r requirements.txt
pip install -r requirements-test.txt
python -m pytest -vs .

# run the training
python -m d3q.apps.train

# run concurrently to see how the training advances
python -m d3q.apps.preview

Design

At the core of the solution lies the separation of the Q-learning process into three asynchronous entities:

  • Q-Trainer - the main training process performing the actual gradient-descent steps using TensorFlow.
  • Simulators - multiple independent environments exploring the state space, gathering SARS' experiences, and evaluating new versions of the model.
  • Replay Memory - a prioritized replay buffer of SARS' samples, which stores new entries, samples them randomly according to their priorities, and then updates the sampled entries' priorities.

[Figure: dataflow diagram]

Each of them runs in a separate Python process, and they communicate through inter-process queues.
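The sketch below illustrates this queue-based layout with several simulators, one replay memory and a trainer. It is not the project's code. To stay self-contained and runnable anywhere, it uses threads with `queue.Queue` instead of processes and fakes the SARS' tuples. It also samples uniformly instead of by priority and omits the periodic parameter updates. The function names, the `STOP` sentinel and all numeric settings are hypothetical. The workers communicate only through queues, so the same functions could be started with `multiprocessing.Process` and `multiprocessing.Queue` to get one process per entity.

```python
import queue
import random
import threading

STOP = None  # hypothetical sentinel telling the receiving worker to shut down


def simulator(params_q, experience_q, episodes, steps_per_episode=5):
    """Wait for model parameters, then play episodes and ship SARS' batches."""
    params = params_q.get()  # a real simulator would load these into its policy
    for _ in range(episodes):
        # Fake (state, action, reward, next_state, done) tuples stand in for
        # transitions collected from a real CartPole environment.
        batch = [(random.random(), random.randrange(2), 1.0, random.random(), False)
                 for _ in range(steps_per_episode)]
        experience_q.put(batch)
    experience_q.put(STOP)


def replay_memory(experience_q, sample_q, n_sims, warmup, batch_size):
    """Store incoming experience; once warmed up, feed random batches to the trainer."""
    store, finished = [], 0
    while finished < n_sims:
        batch = experience_q.get()
        if batch is STOP:
            finished += 1
            continue
        store.extend(batch)
        if len(store) >= warmup:  # sample only after enough records are stored
            sample_q.put(random.sample(store, batch_size))  # uniform, not prioritized
    sample_q.put(STOP)


def trainer(params_q, sample_q, n_sims):
    """Publish initial parameters, then consume sampled batches until STOP."""
    for _ in range(n_sims):
        params_q.put({"version": 0})  # hypothetical parameter payload
    steps = 0
    while sample_q.get() is not STOP:
        steps += 1  # a real trainer would run a gradient-descent step here
    return steps


def run_demo(n_sims=3, episodes=4):
    """Wire the entities together; threads stand in for the project's processes."""
    params_q, experience_q, sample_q = queue.Queue(), queue.Queue(), queue.Queue()
    workers = [threading.Thread(target=simulator, args=(params_q, experience_q, episodes))
               for _ in range(n_sims)]
    workers.append(threading.Thread(
        target=replay_memory, args=(experience_q, sample_q, n_sims, 10, 8)))
    for worker in workers:
        worker.start()
    steps = trainer(params_q, sample_q, n_sims)  # the trainer runs in the main thread
    for worker in workers:
        worker.join()
    return steps


print(run_demo())  # prints 11
```

With 3 simulators × 4 episodes × 5 transitions, the memory receives 12 batches and starts sampling once it holds 10 records. The trainer therefore performs 11 steps regardless of how the threads interleave.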

The main process spawns the replay memory and the simulators, all as services. Simulators start playing games once they receive the initial model parameters; the first playthrough is purely for evaluation (its actions are not affected by the random policy). The simulators gather SARS' samples and send them in batches to the replay memory, which stores them in its prioritized buffer. Once the replay memory holds a sufficient number of records, it starts sampling them randomly, forming the batches of memories provided to the Q-trainer. The Q-trainer receives these SARS' memories and performs a Q-learning training step. After a specific number of steps, the model is saved and its parameters are sent to the simulators. Training stops either when the evaluation score reaches the preset goal, or when it fails by hitting an arbitrary time- or sample-based limit.
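At the heart of the Q-learning training step is the standard one-step target, y = r + γ · max_a' Q(s', a'). The bootstrap term is dropped when the transition ends an episode. The NumPy sketch below computes these targets and the resulting TD errors for a batch of SARS' memories. The function names and the default γ = 0.99 are illustrative assumptions. The sketch does not reproduce d3q's actual network, loss, or any target-network variant. In a typical setup, the trainer would minimize a squared or Huber loss on these errors with TensorFlow. Prioritized replay commonly uses their magnitude as the new sampling priority.

```python
import numpy as np


def q_learning_targets(q_next, rewards, dones, gamma=0.99):
    """One-step targets r + gamma * max_a' Q(s', a'); no bootstrap past terminal states."""
    return rewards + gamma * (1.0 - dones) * q_next.max(axis=1)


def td_errors(q_current, actions, targets):
    """Target minus Q(s, a) for the action actually taken in each transition."""
    q_taken = q_current[np.arange(len(actions)), actions]
    return targets - q_taken


# A toy batch of two transitions with two actions each (CartPole has two actions).
q_next = np.array([[1.0, 2.0], [0.5, 0.0]])     # Q(s', .) predicted by the network
q_current = np.array([[0.0, 1.0], [2.0, 3.0]])  # Q(s, .) predicted by the network
rewards = np.array([1.0, 1.0])
dones = np.array([0.0, 1.0])                     # the second transition ends its episode
actions = np.array([1, 0])

targets = q_learning_targets(q_next, rewards, dones, gamma=0.9)
errors = td_errors(q_current, actions, targets)
new_priorities = np.abs(errors)  # assumption: priority = |TD error|
```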

During the training you can see the learning curve using TensorBoard:

python -m tensorboard.main --logdir tblog

[Figure: learning curve in TensorBoard]

The X axis shows the number of samples (SARS' memories) processed by the Q-trainer. Both the model and its hyperparameters are tuned to minimize this value, i.e. the number of samples that must be processed to reach a high average evaluation score.
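The optimized metric can be read as "samples processed until the moving average of evaluation scores first reaches the goal". The sketch below computes it from a log of `(samples_processed, eval_score)` pairs. The log format, the `goal` value and the `window` size are hypothetical and are not taken from d3q's configuration.

```python
from collections import deque


def samples_to_goal(eval_log, goal, window=10):
    """Return the sample count at which the moving-average score first reaches goal.

    eval_log is an iterable of (samples_processed, eval_score) pairs in training
    order; returns None if the goal is never reached.
    """
    recent = deque(maxlen=window)  # keeps only the last `window` scores
    for samples, score in eval_log:
        recent.append(score)
        if len(recent) == window and sum(recent) / window >= goal:
            return samples
    return None


# Toy log: (samples processed by the Q-trainer, evaluation score).
log = [(1_000, 10.0), (2_000, 50.0), (3_000, 190.0), (4_000, 200.0), (5_000, 200.0)]
print(samples_to_goal(log, goal=195.0, window=2))  # prints 4000
```

A lower result means the agent reached the goal with fewer processed samples, which is the direction in which the model and hyperparameters are tuned.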

In ReplayMemory, SARS' records are indexed with a priority tree data structure, which allows storing large amounts of data while manipulating the corresponding sampling priorities.
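One common priority-tree design for prioritized replay is the sum tree. Its leaves hold per-record priorities, and each inner node holds the sum of its two children. As a result, drawing a record with probability proportional to its priority and updating a priority both take O(log n) steps. The sketch below is a generic sum tree under that assumption, not d3q's actual ReplayMemory. The class and method names are illustrative.

```python
import random


class SumTree:
    """Binary tree whose leaves store priorities and whose inner nodes store sums."""

    def __init__(self, capacity):
        self.capacity = capacity
        # Node i has children 2i and 2i+1; the root is node 1 and the leaves
        # occupy nodes capacity .. 2*capacity-1 (node 0 is unused).
        self.tree = [0.0] * (2 * capacity)
        self.data = [None] * capacity
        self.cursor = 0  # ring-buffer write position

    def add(self, priority, record):
        """Store a record, overwriting the oldest one when the buffer is full."""
        self.data[self.cursor] = record
        self.update(self.cursor, priority)
        self.cursor = (self.cursor + 1) % self.capacity

    def update(self, index, priority):
        """Set a leaf's priority and propagate the change up to the root."""
        node = index + self.capacity
        delta = priority - self.tree[node]
        while node >= 1:
            self.tree[node] += delta
            node //= 2

    def total(self):
        return self.tree[1]

    def find(self, value):
        """Return the index of the leaf whose cumulative priority range contains value."""
        node = 1
        while node < self.capacity:  # descend until a leaf is reached
            left = 2 * node
            if value < self.tree[left]:
                node = left
            else:
                value -= self.tree[left]
                node = left + 1
        return node - self.capacity

    def sample(self, k, rng=random):
        """Draw k indices with probability proportional to their priority."""
        return [self.find(rng.random() * self.total()) for _ in range(k)]
```

A replay memory built on this structure would call `sample`, look up `tree.data[i]` for each returned index, and, after the training step, call `update(i, new_priority)`. This mirrors the store, sample and update cycle of the Replay Memory.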

Contributors: isameru
