RL for SED

This repository holds the implementation code for the paper Optimizing Sequential Experimental Design with Deep Reinforcement Learning, published at ICML 2022.

The code is based on the reinforcement learning framework Garage and on the Pyro implementation of stochastic gradient BOED by Adam Foster et al.

Installation

This code has been tested with Python 3.7.11. We offer no guarantee that it will work with other versions. Because it depends on the rpy2 package, you will need a local installation of R (see the rpy2 installation instructions). If you do not want to install rpy2 and are not interested in running the SMC experiment, simply remove this package from the requirements.txt file.

Installation is done using venv:

python -m venv boed
source boed/bin/activate
pip install -r requirements.txt
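
After installing, a quick sanity check of the environment can save a failed run later. The snippet below is a minimal sketch, not part of the repository: the function name check_environment is hypothetical, and it only reports whether the interpreter matches the tested minor version (3.7) and whether rpy2 can be found.

```python
import importlib.util
import sys


def check_environment(tested=(3, 7)):
    """Return a list of warnings about the current Python environment."""
    warnings = []
    # Compare only major.minor against the version the code was tested on.
    if sys.version_info[:2] != tested:
        warnings.append(
            "Python %d.%d detected; the code was tested on %d.%d"
            % (sys.version_info[0], sys.version_info[1], tested[0], tested[1])
        )
    # rpy2 is only required for the SMC experiment.
    if importlib.util.find_spec("rpy2") is None:
        warnings.append("rpy2 not found; the SMC experiment will not run")
    return warnings


if __name__ == "__main__":
    for message in check_environment():
        print("warning:", message)
```

An empty list means the interpreter matches the tested version and rpy2 is importable.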

Running Experiments

To run the experiments in the paper we must use different utilities for the different algorithms, since they come from different sources. The probabilistic models themselves are implemented in boed/pyro/models/adaptive_experiment_model.py and are transformed into MDPs by the AdaptiveDesignEnv class implemented in boed/pyro/envs/adaptive_design_env.py.
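
To give a feel for what "transforming a probabilistic model into an MDP" can look like, here is a toy sketch. It is not the repository's AdaptiveDesignEnv: the class ToyDesignEnv, the 1-D Gaussian model, and the reward definition are all assumptions made for illustration. The state is the history of (design, outcome) pairs, the action is the next design, and the reward is the per-step change in a contrastive bound computed from one data-generating parameter and a set of contrastive samples.

```python
import math
import random


class ToyDesignEnv:
    """Toy MDP for the model theta ~ N(0, 1), y ~ N(theta * design, 1)."""

    def __init__(self, budget=5, n_contrastive=10, seed=0):
        self.budget = budget                # number of experiments per episode
        self.n_contrastive = n_contrastive  # number of contrastive samples
        self.rng = random.Random(seed)

    def reset(self):
        # thetas[0] generates the data; the others are contrastive samples.
        self.thetas = [self.rng.gauss(0.0, 1.0)
                       for _ in range(self.n_contrastive + 1)]
        self.log_liks = [0.0] * len(self.thetas)  # log p(history | theta_l)
        self.history = []
        self.bound = 0.0  # with an empty history the bound is exactly 0
        return tuple(self.history)

    def _bound(self):
        # log p(h | theta_0) - log mean_l p(h | theta_l), using log-sum-exp.
        m = max(self.log_liks)
        mean = sum(math.exp(v - m) for v in self.log_liks) / len(self.log_liks)
        return self.log_liks[0] - (m + math.log(mean))

    def step(self, design):
        # Simulate an outcome from the data-generating parameter.
        y = self.rng.gauss(self.thetas[0] * design, 1.0)
        self.history.append((design, y))
        # Update the running log-likelihood of the history under each theta.
        for i, theta in enumerate(self.thetas):
            self.log_liks[i] += (-0.5 * (y - theta * design) ** 2
                                 - 0.5 * math.log(2.0 * math.pi))
        new_bound = self._bound()
        reward = new_bound - self.bound  # rewards telescope to the final bound
        self.bound = new_bound
        done = len(self.history) >= self.budget
        return tuple(self.history), reward, done


def run_episode(env, design):
    """Play one episode with a fixed design and return the total reward."""
    env.reset()
    total, done = 0.0, False
    while not done:
        _, reward, done = env.step(design)
        total += reward
    return total
```

Because the rewards are differences of consecutive bound values, the undiscounted return of an episode equals the bound evaluated on the full history, and that bound can never exceed log(n_contrastive + 1).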

RL Experiments

The RL experiments can be run by executing the Python modules in the Experiments folder. These take command-line arguments, some of which have default values. To reproduce the exact settings in the paper:

For Source location

python -m Experiments.Adaptive_Source_SAC --n-contr-samples=100000 --n-rl-itr=20001 --log-dir=<log_dir>/boed_results/source --bound-type=lower --id=1 --budget=30 --discount=0.9 --buffer-capacity=1000000 --tau=0.001 --pi-lr=0.0001 --qf-lr=0.0003 --ens-size=2

For CES

python -m Experiments.Adaptive_CES_SAC --n-contr-samples=100000 --n-rl-itr=20001 --log-dir=<log_dir>/boed_results/ces --bound-type=lower --id=1 --budget=10 --discount=0.9 --buffer-capacity=1000000

For Prey population

python -m Experiments.Adaptive_Prey_SAC --n-contr-samples=10000 --n-rl-itr=40001 --log-dir=<log_dir>/boed_results/prey --bound-type=lower --id=1 --budget=10 --tau=0.01 --pi-lr=0.0001 --qf-lr=0.001 --discount=0.95 --buffer-capacity=1000000 --temp=1 --target-entropy=2.85189124 --ens-size=10

where <log_dir> needs to be replaced with a path to a logging directory of your choice.
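
If you launch runs from a script rather than by hand, the placeholder substitution can be automated. The following is a small sketch under stated assumptions: the helper fill_log_dir and the constant SOURCE_SAC are hypothetical names, and the command string is copied from the source location settings above.

```python
import shlex

# Command template copied from the source location settings; <log_dir> is
# the placeholder to be filled in.
SOURCE_SAC = (
    "python -m Experiments.Adaptive_Source_SAC --n-contr-samples=100000 "
    "--n-rl-itr=20001 --log-dir=<log_dir>/boed_results/source "
    "--bound-type=lower --id=1 --budget=30 --discount=0.9 "
    "--buffer-capacity=1000000 --tau=0.001 --pi-lr=0.0001 --qf-lr=0.0003 "
    "--ens-size=2"
)


def fill_log_dir(template, log_root):
    """Substitute <log_dir> in a command template and return an argv list."""
    if "<log_dir>" not in template:
        raise ValueError("template has no <log_dir> placeholder")
    # Quote the path so that spaces survive the shell-style split below.
    filled = template.replace("<log_dir>", shlex.quote(str(log_root)))
    return shlex.split(filled)


if __name__ == "__main__":
    argv = fill_log_dir(SOURCE_SAC, "/tmp/logs")
    print(argv)
    # To actually launch the run from the repository root:
    # import subprocess; subprocess.run(argv, check=True)
```

Returning a list rather than a string lets the result be passed to subprocess.run without invoking a shell.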

PCE Experiments

The PCE experiments are run by executing scripts in the root folder of the repository. These scripts are based on the ones provided by Foster et al.

For Source location

python source.py --num-steps=30 --num-parallel=100 --name=pce --typs=pce --num-gradient-steps=2500

For CES

python ces.py --num-steps=10 --num-parallel=100 --name=pce --typs=pce --num-acquisition=1 --num-gradient-steps=2500

For CES with PCE-BO

python ces.py --num-steps=10 --num-parallel=100 --name=bo --typs=bo --num-gradient-steps=2500

For Prey population

python prey.py --num-steps=10 --num-parallel=100 --name=pce --typs=pce

Random Experiments

Random experiments can be executed just like the PCE experiments, but replace the --typs=pce flag with --typs=rand and change the --name flag appropriately.
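
The flag substitution can also be done programmatically. This is an illustrative sketch only: the helper to_random_baseline is hypothetical, and the default name "rand" is an assumed choice for the --name value. All other flags are passed through unchanged; whether flags such as --num-gradient-steps have any effect for random designs is not stated here.

```python
def to_random_baseline(argv, name="rand"):
    """Return a copy of a PCE argv with --typs=rand and --name=<name>."""
    out = []
    for arg in argv:
        if arg.startswith("--typs="):
            out.append("--typs=rand")
        elif arg.startswith("--name="):
            out.append("--name=" + name)
        else:
            out.append(arg)  # every other flag is left untouched
    return out


# PCE command for source location, split into arguments.
PCE_SOURCE = [
    "python", "source.py", "--num-steps=30", "--num-parallel=100",
    "--name=pce", "--typs=pce", "--num-gradient-steps=2500",
]

if __name__ == "__main__":
    print(" ".join(to_random_baseline(PCE_SOURCE)))
```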

SMC Experiments

The SMC experiment can be run by executing the file SMC_prey.py in the root folder.

DAD Experiments

DAD experiments are not included in this repository and must be run using the code from the original DAD paper's repository.
