
bear's Introduction

BEAR (Bootstrapping Error Accumulation Reduction)

Update (07/17): We have released a cleaner implementation of BEAR on top of rlkit at: https://github.com/rail-berkeley/d4rl_evaluations, which accompanies the latest version of the D4RL paper. We encourage all users to use this new implementation instead of this repo. We made the hyperparameters consistent across environments and retuned the algorithm for the new D4RL datasets.

Update (05/04): Added support for D4RL environments, https://github.com/rail-berkeley/d4rl.

This is the code for the NeurIPS 2019 paper Stabilizing Off-Policy Q-Learning via Bootstrapping Error Reduction. Please refer to the project page: https://sites.google.com/view/bear-off-policyrl for details and slides explaining the algorithm.

Our code is built on top of the BCQ repository [https://github.com/sfujim/BCQ] and uses many similar components. To run BEAR, use a command like the following:

python main.py --buffer_name=buffer_walker_300_curr_action.pkl --eval_freq=1000 --algo_name=BEAR \
    --env_name=Walker2d-v2 --log_dir=data_walker_BEAR/ --lagrange_thresh=10.0 \
    --distance_type=MMD --mode=auto --num_samples_match=5 --lamda=0.0 --version=0 \
    --mmd_sigma=20.0 --kernel_type=gaussian --use_ensemble_variance="False"

python main.py --buffer_name=buffer_hopper_300_curr_action.pkl --eval_freq=1000 --algo_name=BEAR \
    --env_name=Hopper-v2 --log_dir=data_hopper_BEAR/ --lagrange_thresh=10.0 --distance_type=MMD \
    --mode=auto --num_samples_match=5 --lamda=0.0 --version=0 --mmd_sigma=10.0 \
    --kernel_type=laplacian --use_ensemble_variance="False"

Installation Instructions: Download rlkit [https://github.com/vitchyr/rlkit] and follow its instructions for installing the rlkit environment on your machine. Make sure to use mujoco_py==1.50.1.56 and mjpro150 for the MuJoCo installation, then run one of the commands above. Any PyTorch version >= 1.1.0 is supported (note: the default rlkit PyTorch version is 0.4.1, but this codebase needs PyTorch >= 1.1.0; you may also need to update NumPy to the latest version). For easy visualization, we recommend installing viskit [https://github.com/vitchyr/viskit] and using it for visualization. This repository is configured to write log files that are compatible with viskit.

Algorithms Supported:

  1. BCQ (algo_name=BCQ) [Fujimoto et al., ICML 2019]
  2. TD3 (algo_name=TD3) [Fujimoto et al., ICML 2018]
  3. Behavior Cloning (algo_name=BC)
  4. KL Control (algo_name=KLControl) [Jaques et al., arXiv 2019]
  5. Deep Q-learning from Demonstrations (algo_name=DQfD) [Hester et al., 2017]

Hyperparameter definitions:

  1. mmd_sigma: Standard deviation of the kernel used for MMD computation (see the sketch after this list)
  2. kernel_type: (gaussian|laplacian) Kernel type used for computing the MMD
  3. num_samples_match: Number of samples used for computing the sampled MMD
  4. version: (0|1|2) Whether to use the min (0), max (1), or mean (2) of the Q-values from the ensemble for policy improvement
  5. buffer_name: Path to the buffer (.pkl files preferred; other options are available in utils.py)
  6. use_ensemble_variance: Whether to use ensemble variance in the policy improvement step (set this to False, otherwise it can result in NaNs)
  7. lagrange_thresh: Threshold for the log of the Lagrange multiplier
  8. cloning: Set this flag to run behavior cloning
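
For reference, the sampled MMD that mmd_sigma, kernel_type, and num_samples_match control looks roughly like the sketch below. This is a minimal standalone illustration, not the exact implementation in algos.py (the kernel scaling there may differ slightly):

import torch

def sampled_mmd(x, y, sigma=20.0, kernel_type="gaussian"):
    # x: (batch, n, action_dim) actions sampled from the behavior policy / VAE
    # y: (batch, m, action_dim) actions sampled from the current actor
    # Returns a per-state estimate of the MMD between the two sample sets.
    diff_xx = x.unsqueeze(2) - x.unsqueeze(1)   # (batch, n, n, action_dim)
    diff_xy = x.unsqueeze(2) - y.unsqueeze(1)   # (batch, n, m, action_dim)
    diff_yy = y.unsqueeze(2) - y.unsqueeze(1)   # (batch, m, m, action_dim)
    if kernel_type == "gaussian":               # squared-L2 distances
        d_xx, d_xy, d_yy = (d.pow(2).sum(-1) for d in (diff_xx, diff_xy, diff_yy))
    else:                                       # laplacian: L1 distances
        d_xx, d_xy, d_yy = (d.abs().sum(-1) for d in (diff_xx, diff_xy, diff_yy))
    k_xx = torch.exp(-d_xx / (2.0 * sigma))
    k_xy = torch.exp(-d_xy / (2.0 * sigma))
    k_yy = torch.exp(-d_yy / (2.0 * sigma))
    mmd_sq = k_xx.mean(dim=(1, 2)) - 2.0 * k_xy.mean(dim=(1, 2)) + k_yy.mean(dim=(1, 2))
    return mmd_sq.clamp(min=1e-6).sqrt()

With num_samples_match=5, x and y would each hold 5 sampled actions per state, and the resulting per-state MMD is the quantity that the Lagrange multiplier (whose log is capped at lagrange_thresh) penalizes when it exceeds the constraint threshold.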

Hyperparameters that generally work well (for BEAR, across environments):

  1. mmd_sigma=10.0, kernel_type=laplacian, num_samples_match=5, version=0 or 2, lagrange_thresh=10.0, mode=auto
  2. mmd_sigma=20.0, kernel_type=gaussian, num_samples_match=5, version=0 or 2, lagrange_thresh=10.0, mode=auto

We have removed ensembles from this version; we just use the minimum or average over 2 Q-functions, without an ensemble-based conservative estimate based on sample variance. This is because we did not find ensemble variance to provide benefits in general, although it does not hurt either. However, the code for ensembles is still present in EnsembleCritic in algos.py; please set use_ensemble_variance=True to use ensembles in the BEAR algorithm.
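
As an illustration of the version flag described in the hyperparameter list (a minimal sketch, not the exact code in algos.py):

import torch

def combine_q_values(q_values, version=0):
    # q_values: (num_critics, batch, 1) Q-estimates from the critic ensemble
    # version=0 -> elementwise minimum (most conservative)
    # version=1 -> elementwise maximum
    # version=2 -> elementwise mean
    if version == 0:
        return q_values.min(dim=0).values
    if version == 1:
        return q_values.max(dim=0).values
    return q_values.mean(dim=0)

# With the two Q-functions kept in this version of the code:
qs = torch.stack([torch.randn(4, 1), torch.randn(4, 1)])  # shape (2, batch=4, 1)
conservative_q = combine_q_values(qs, version=0)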

If you use this code in your research, please cite our paper:

@inproceedings{kumar19bear,
  author    = {Aviral Kumar and Justin Fu and George Tucker and Sergey Levine},
  title     = {Stabilizing Off-Policy Q-Learning via Bootstrapping Error Reduction},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year      = {2019},
  url       = {http://arxiv.org/abs/1906.00949},
}

For any questions/issues please contact Aviral Kumar at [email protected].

bear's People

Contributors

aviralkumar2907

bear's Issues

KeyError: 'data_policy_mean' when running BEAR_IS

Traceback (most recent call last):
File "/home/hq/code/remotepycharmfolder/BEAR-master/main.py", line 209, in
pol_vals = policy.train(replay_buffer, iterations=int(args.eval_freq))
File "/home/hq/code/remotepycharmfolder/BEAR-master/algos.py", line 655, in train
state_np, next_state_np, action, reward, done, mask, data_mean, data_cov = replay_buffer.sample(batch_size, with_data_policy=True)
File "/home/hq/code/remotepycharmfolder/BEAR-master/utils.py", line 40, in sample
data_mean = self.storage['data_policy_mean'][ind]
KeyError: 'data_policy_mean'

This error is raised when running the BEAR_IS algorithm.
Note that your program does not save data_policy_mean and data_policy_logvar in the buffer :)
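
A possible workaround, sketched below under the assumption (taken from the traceback) that the buffer's storage is a plain dict of per-transition arrays: backfill the two missing keys with placeholder statistics so that sample(..., with_data_policy=True) stops raising. For meaningful importance weights, the true behavior-policy mean and log-variance would have to be stored when the buffer is generated; zeros are only a stop-gap. This helper is hypothetical and not part of the repo.

import numpy as np

def backfill_data_policy_stats(storage, num_transitions, action_dim):
    # Hypothetical helper: add zero-mean / zero-log-variance placeholders for
    # the keys that BEAR_IS expects but the saved buffer lacks.
    storage.setdefault('data_policy_mean',
                       np.zeros((num_transitions, action_dim), dtype=np.float32))
    storage.setdefault('data_policy_logvar',
                       np.zeros((num_transitions, action_dim), dtype=np.float32))
    return storage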

In place operations in algos.py

I keep getting this error due to some in-place changes to the variable a in sample_multiple:

[W python_anomaly_mode.cpp:60] Warning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "/home/7331215/wrappers/run_optimizer.py", line 211, in <module>
    main(sys.argv[1:])
  File "/home/7331215/wrappers/run_optimizer.py", line 152, in main
    RewPred.generate_knobs()
  File "/home/7331215//wrappers/../../rewardpredictor/rewardpredictor_base.py", line 431, in generate_knobs
    self.generate_knobs_BEAR()
  File "/home/7331215//wrappers/../../rewardpredictor/rewardpredictor_base.py", line 532, in generate_knobs_BEAR
    pol_vals = policy.train(replay_buffer, iterations = int(5e3))
  File "/home/7331215/wrappers/../../rl/Algos/BEAR/algos.py", line 440, in train
    actor_actions, raw_actor_actions = self.actor.sample_multiple(state, num_samples)# num)
  File "/home/7331215/../../rl/Algos/BEAR/algos.py", line 76, in sample_multiple
    log_std_a = self.log_std(a.clone())
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/nn/modules/linear.py", line 91, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/nn/functional.py", line 1674, in linear
    ret = torch.addmm(bias, input, weight.t())
(function print_stack)
Traceback (most recent call last):
  File "/home/7331215/wrappers/run_optimizer.py", line 211, in <module>
    main(sys.argv[1:])
  File "/home/7331215/wrappers/run_optimizer.py", line 152, in main
    RewPred.generate_knobs()
  File "/home/7331215/wrappers/../../rewardpredictor/rewardpredictor_base.py", line 431, in generate_knobs
    self.generate_knobs_BEAR()
  File "/home/7331215/wrappers/../../rewardpredictor/rewardpredictor_base.py", line 532, in generate_knobs_BEAR
    pol_vals = policy.train(replay_buffer, iterations = int(5e3))
  File "/home/7331215//wrappers/../../rl/Algos/BEAR/algos.py", line 508, in train
    (-lagrange_loss).backward()
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/7331215/virtenvs/lib64/python3.6/site-packages/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [300, 32]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Any guidance for how to fix? I have edited main.py to adapt to my specific problem task but haven't edited algos.py except to try to debug this error.
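
Not an authoritative fix, but the error itself is generic PyTorch behaviour: a tensor that an earlier operation saved for its backward pass (here the input of a Linear layer, saved by AddmmBackward) gets mutated in place before backward() runs, so its version counter no longer matches. The usual remedy is to replace the in-place update with an out-of-place one. A minimal standalone illustration of the pattern, independent of the repo's code:

import torch
import torch.nn as nn

lin = nn.Linear(8, 8)

# Broken pattern: the tensor a is saved by AddmmBackward for the weight
# gradient, then modified in place, which bumps its version counter.
a = torch.randn(4, 8)
y = lin(a)
a += 1.0
try:
    y.sum().backward()
except RuntimeError as err:
    print(err)  # "... modified by an inplace operation ..."

# Fixed pattern: build a new tensor instead of mutating the saved one.
a = torch.randn(4, 8)
y = lin(a)
a = a + 1.0     # out-of-place; the tensor saved for backward is untouched
y.sum().backward()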

Couldn't reproduce the results on the MuJoCo suite

Setting: We run BEAR with the recommended settings: mmd_sigma=20.0, kernel_type=gaussian, num_samples_match=5, version=0 or 2, lagrange_thresh=10.0, mode=auto.
The batch dataset is produced by training a DDPG agent for 1 million time steps. For reproduction, we use the DDPG code in the BCQ repository.

We use the final-buffer setting from the BCQ paper.
Here are the full results.
Note that "behavioral" denotes the evaluation of the DDPG agent during training.

Sweetice 07/20 update: Uploaded the PNG files for easier reading.
Ant-v2
HalfCheetah-v2
Hopper-v2
InvertedDoublePendulum-v2
InvertedPendulum-v2
Reacher-v2
Swimmer-v2
Walker2d-v2

For clearer reading:
Ant-v2.pdf
HalfCheetah-v2.pdf
Hopper-v2.pdf
InvertedDoublePendulum-v2.pdf
InvertedPendulum-v2.pdf
Reacher-v2.pdf
Swimmer-v2.pdf
Walker2d-v2.pdf

Algorithm question

I am confused about the VAE network; could you please explain it? Is it just pre-trained to obtain the distribution of the behavior policy?

BEAR/algos.py

Lines 395 to 398 in f2e31c1

recon, mean, std = self.vae(state, action)
recon_loss = F.mse_loss(recon, action)
KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
vae_loss = recon_loss + 0.5 * KL_loss
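
For what it is worth, the KL term in this snippet is the standard closed-form KL divergence between the diagonal Gaussian N(mean, std^2) produced by the VAE encoder and the standard normal prior N(0, I). A standalone sanity check of that identity (independent of the repo's code):

import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(4, 6)
std = torch.rand(4, 6) + 0.1

closed_form = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
library = kl_divergence(Normal(mean, std),
                        Normal(torch.zeros_like(mean), torch.ones_like(std))).mean()
print(torch.allclose(closed_form, library))  # True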

What is delta_conf?

Thank you for your great work.

BEAR/algos.py

Lines 468 to 470 in f2e31c1

actor_loss = (-critic_qs +\
self._lambda * (np.sqrt((1 - self.delta_conf)/self.delta_conf)) * std_q +\
self.log_lagrange2.exp() * mmd_loss).mean()

The default value of delta_conf in the code is 0.1, but I can't find what this means in the paper.
Can you explain this?

Couldn't reproduce the results reported in D4RL

Hi @aviralkumar2907

Thank you for the code!
I am trying to reproduce the results reported in D4RL for the "walker2d-medium-v0" environment.
I run the code with the following command:
python main.py --eval_freq=1000 --algo_name=BEAR --env_name=walker2d-medium-v0 --log_dir=data_walker_BEAR/ --lagrange_thresh=10.0 --distance_type=MMD --mode=auto --num_samples_match=5 --lamda=0.0 --version=0 --mmd_sigma=20.0 --kernel_type=gaussian --use_ensemble_variance="False"

The results are averaged over four random seeds, which are shown here:
data_walker_BEAR

However, the score is much lower than the score of 1526.7 reported for walker2d-medium-v0 in the D4RL paper.

Do you know how to solve this?

Thank you!

Best,
Rui

Question about using Importance Sampling in BEAR

Hello, I have a question about BEAR_IS in your algos.py file.

As we know, DDPG is effectively one-step Q-learning for continuous tasks, and BEAR uses the same architecture. Given that, it seems to make no sense to use importance sampling in BEAR, because the difference between the current policy and the behavioral policy does not make the Q-value estimate inaccurate.

So can you explain why you wrote an importance sampling version of BEAR in your project?
