Unleashing the Power of Pre-trained Language Models for Offline Reinforcement Learning

This repo is the official code release for the ICLR 2024 conference paper:


Unleashing the Power of Pre-trained Language Models for Offline Reinforcement Learning
Ruizhe Shi*¹, Yuyao Liu*¹, Yanjie Ze², Simon Shaolei Du³, Huazhe Xu¹²⁴
The International Conference on Learning Representations (ICLR) 2024
¹Tsinghua University, IIIS   ²Shanghai Qi Zhi Institute   ³University of Washington   ⁴Shanghai AI Lab
*Equal contribution. Order is decided by coin flip.


🧾 Introduction

We propose LaMo, an offline RL framework that leverages pre-trained Language Models (LMs) for low-level Motion control. On sparse-reward tasks, LaMo achieves strong results and surpasses recent competitive algorithms, including CQL, IQL, TD3+BC, and Decision Transformer (DT); on dense-reward tasks, LaMo significantly improves over DT and closes the gap between value-based methods and DT-based methods. Notably, in low-data scenarios, our method shows strong few-shot learning ability, which can be attributed to the inductive bias of pre-trained LMs.

We investigate how the performance of various algorithms relates to the scale of the data. As the figure shows, LaMo achieves excellent performance even with relatively small datasets. For example, on Hopper, LaMo surpasses CQL and DT when the sample ratio is only 0.5%, and it consistently keeps this advantage as the sample ratio increases.

Below, we visualize 8 tasks across 3 domains that we consider.

  • D4RL
    • MuJoCo: Hopper, Walker2d, HalfCheetah, Reacher2d
    • Kitchen
  • Atari: Breakout, Qbert, Pong

💻 Installation

D4RL

Environment

Install MuJoCo

First, download the MuJoCo archive from this link and extract it into the ~/.mujoco folder with tar -xvf the_file_name. Then run the following commands:

cd experiment-d4rl
conda env create -f env.yml

After that, add the following lines to your ~/.bashrc file:

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/YOUR_PATH_TO_THIS/.mujoco/mujoco210/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia

Remember to source ~/.bashrc to make the changes take effect.
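After sourcing ~/.bashrc, a quick sanity check can confirm that the MuJoCo library directory is actually on LD_LIBRARY_PATH. The script below is a minimal sketch that is not part of this repo; it assumes the default install location ~/.mujoco/mujoco210/bin, so adjust MUJOCO_BIN if you extracted MuJoCo elsewhere.

```python
import os
from pathlib import Path

# Assumed default install location; change this if MuJoCo lives elsewhere.
MUJOCO_BIN = Path.home() / ".mujoco" / "mujoco210" / "bin"


def on_library_path(directory, ld_library_path):
    """Return True if `directory` is one of the entries of an LD_LIBRARY_PATH string."""
    target = os.path.normpath(str(directory))
    # Skip empty entries, e.g. the leading ":" left by "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:..."
    entries = [os.path.normpath(p) for p in ld_library_path.split(":") if p]
    return target in entries


if __name__ == "__main__":
    ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    print("MuJoCo bin exists:", MUJOCO_BIN.is_dir())
    print("MuJoCo bin on LD_LIBRARY_PATH:", on_library_path(MUJOCO_BIN, ld_path))
```

If either line prints False, re-check the tar destination and the export lines in ~/.bashrc.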

Install D4RL

Install D4RL by following the instructions in the D4RL repository.

Dataset

To download the original D4RL data, run

cd data
python download_d4rl_datasets.py

To generate downsampled data, set line 10 of both 'data/mujoco/ratio_dataset.py' and 'data/kitchen/ratio_dataset.py' to

suffix = [your data version name]

and then run

cd data
cd mujoco
python ratio_dataset.py
cd ..
cd kitchen
python ratio_dataset.py
cd ..
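Conceptually, a sample ratio keeps only a fraction of the original dataset. The snippet below is a hypothetical illustration of one way to do this (keeping whole trajectories until roughly sample_ratio of all transitions are covered); it is not the repo's ratio_dataset.py, and the trajectory format (a dict with a "rewards" list) and the function name `downsample_by_transitions` are assumptions.

```python
import random


def downsample_by_transitions(trajectories, sample_ratio, seed=0):
    """Keep whole trajectories until about `sample_ratio` of all transitions are covered.

    Hypothetical helper (not the repo's ratio_dataset.py). Each trajectory is
    assumed to be a dict whose "rewards" list has one entry per transition.
    """
    if not 0 < sample_ratio <= 1:
        raise ValueError("sample_ratio must be in (0, 1]")
    total = sum(len(t["rewards"]) for t in trajectories)
    budget = sample_ratio * total
    order = list(range(len(trajectories)))
    random.Random(seed).shuffle(order)  # local RNG: the global random state is untouched
    kept, covered = [], 0
    for i in order:
        if covered >= budget:
            break
        kept.append(trajectories[i])
        covered += len(trajectories[i]["rewards"])
    return kept


if __name__ == "__main__":
    # 100 toy trajectories of 10 transitions each (1,000 transitions in total)
    data = [{"rewards": [0.0] * 10} for _ in range(100)]
    subset = downsample_by_transitions(data, 0.1)
    print(len(subset), sum(len(t["rewards"]) for t in subset))  # 10 100
```

Keeping whole trajectories avoids cutting episodes in half, at the cost of slightly overshooting the budget when trajectory lengths differ.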

Alternatively, you can directly download our pre-processed data from this link.

You can also generate the data with a PPO agent that you train yourself, using the code provided in 'data/data_generation_PPO'.

Atari

Environment

First, install the dependencies required by the Atari environments:

sudo apt install cmake
sudo apt install zlib1g-dev

Then run the following commands.

cd experiment-atari
conda env create -f env.yml

Dataset

The dataset is downloaded automatically and cached locally by the d4rl-atari package when you launch an experiment. To reproduce our results on downsampled datasets, use the same seeds as ours (0, 1, and 2); our implementation in experiment-atari/buffer.py then guarantees that the downsampled dataset is identical to ours.
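Matching seeds reproduce the same downsampled dataset because the subset is drawn from a random generator seeded with the experiment seed. The sketch below illustrates that idea with Python's standard library; it is not the code in experiment-atari/buffer.py, and the helper name `sample_indices` is hypothetical.

```python
import random


def sample_indices(num_transitions, sample_ratio, seed):
    """Pick a sorted subset of transition indices with a seeded, local RNG.

    Hypothetical helper for illustration: the same (num_transitions,
    sample_ratio, seed) triple always yields the same subset.
    """
    k = max(1, int(num_transitions * sample_ratio))
    rng = random.Random(seed)
    return sorted(rng.sample(range(num_transitions), k))


if __name__ == "__main__":
    for seed in (0, 1, 2):  # the seeds listed above
        first = sample_indices(10_000, 0.5, seed)
        again = sample_indices(10_000, 0.5, seed)
        assert first == again  # same seed -> identical subset
        print(seed, len(first), first[:3])
```

Using a dedicated `random.Random(seed)` instance, rather than the global generator, keeps the subset independent of any other random calls made before it.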

🛠️ Usage

D4RL

After installing the packages and data, you can reproduce our D4RL results by running

cd experiment-d4rl
bash run.sh [env_name] [dataset_name] [sample_ratio] [description] [seed] [gpu]

An example is:

bash run.sh hopper medium 0.1 reproduce 0 0
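To sweep over several seeds or sample ratios, the positional arguments of run.sh can be generated programmatically. The helper below is hypothetical and not shipped with the repo; it only builds command lines in the argument order documented above and leaves execution to you (for example with subprocess.run from inside experiment-d4rl).

```python
import itertools
import shlex


def build_run_commands(env_name, dataset_name, sample_ratios, seeds, description, gpu=0):
    """Return one `bash run.sh ...` argument list per (sample_ratio, seed) pair.

    Argument order follows the documented usage:
    run.sh [env_name] [dataset_name] [sample_ratio] [description] [seed] [gpu]
    """
    commands = []
    for ratio, seed in itertools.product(sample_ratios, seeds):
        commands.append([
            "bash", "run.sh", env_name, dataset_name,
            str(ratio), description, str(seed), str(gpu),
        ])
    return commands


if __name__ == "__main__":
    for cmd in build_run_commands("hopper", "medium", [0.1], [0, 1, 2], "reproduce"):
        print(shlex.join(cmd))
    # To actually launch each run from experiment-d4rl:
    #   subprocess.run(cmd, check=True)
```

The first printed command, `bash run.sh hopper medium 0.1 reproduce 0 0`, is the example shown above.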

If you want to view results on Weights & Biases, modify lines 435 and 436 of '/code/experiment.py' as follows:

entity=[your-group-name],
project=[your-project-name],

Trying more configurations is encouraged! The important arguments are explained below:

-w # enable wandb logging
--sample_ratio your_sample_ratio # fraction of the data to train on, e.g. 0.1
--data_suffix your_data_version_name # name of your own downsampled data version; default is "d1"
--mlp_embedding # use MLPs as embeddings and projections
--adapt_mode # enable adapt mode; without this flag, the model is fully fine-tuned
--adapt_embed # fine-tune embeddings and projections when adapt_mode is on
--lora # fine-tune low-rank matrices of the Transformer when adapt_mode is on
--pretrained_lm language_model_name # e.g. 'gpt2' or 'gpt2-medium'
--co_training # use the language loss as an auxiliary objective
--co_lambda # weight of the language loss, e.g. 0.1
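With --co_training, a language-modeling loss is added as an auxiliary objective weighted by --co_lambda. Assuming the two losses are combined as a simple weighted sum, L = L_action + λ · L_language (a reading of the flag descriptions above; check experiment.py for the exact form), the arithmetic looks like the sketch below. The function names are hypothetical, and MSE stands in for the action loss as is typical for continuous control.

```python
import math


def mse(pred, target):
    """Mean squared error, a typical action-prediction loss for continuous control."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def next_token_nll(token_probs):
    """Average negative log-likelihood of the correct next tokens."""
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


def co_training_loss(action_loss, language_loss, co_lambda=0.1):
    """Assumed form: decision loss plus a weighted auxiliary language loss."""
    return action_loss + co_lambda * language_loss


if __name__ == "__main__":
    a_loss = mse([0.5, -0.2], [0.4, 0.0])  # (0.01 + 0.04) / 2 = 0.025
    l_loss = next_token_nll([0.5, 0.25])   # (ln 2 + ln 4) / 2 = 1.5 * ln 2 ≈ 1.0397
    print(co_training_loss(a_loss, l_loss, co_lambda=0.1))  # ≈ 0.129
```

A small co_lambda such as 0.1 keeps the language term from dominating the decision-making objective while still regularizing the pre-trained weights.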

All of our scripts are provided at this link.

Atari

To reproduce our Breakout results with one click, run the following commands:

cd experiment-atari
bash run.sh 

Since we use Hydra to manage the configuration of the Atari experiments, you can conveniently override hyperparameters. To run experiments on more environments, add a configuration for the corresponding environment under experiment-atari/cfgs/env. Refer to the Hydra documentation for more details. Here are a few important hyperparameters:

env # environment name (breakout, qbert, pong, or any other Atari environment you want to explore)
pretrained_lm # gpt2, gpt2-medium, or none
seed # 0, 1, 2
sample_ratio # the fraction of the dataset you train on
model.random_initialize # whether to randomly initialize the model weights (overriding the pretrained weights)
model.adapt_cfg.use_adapt # whether to use adapt mode (as opposed to full fine-tuning)
model.adapt_cfg.adapt_embed # whether to unfreeze the embeddings
model.lora_cfg.use_lora # whether to use LoRA
model.lora_cfg.lora_attn_dim # the LoRA dimension
model.context_len # the context length of the Transformer model
train.lr # learning rate
train.weight_decay # weight decay
train.batch_size # batch size
nlp_train.co_training # whether to use joint language training
nlp_train.co_lambda # the weight of the joint language training loss
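LoRA freezes a pretrained weight matrix W and trains only a low-rank update BA, so an adapted layer computes y = Wx + (α/r)·BAx with rank r (presumably what model.lora_cfg.lora_attn_dim sets; this is an assumption). The NumPy sketch below illustrates the technique as described in the LoRA paper; it is not the repo's implementation, and the class name, the scaling factor alpha, and the initialization scale are illustrative choices.

```python
import numpy as np


class LoRALinear:
    """Frozen linear layer W plus a trainable low-rank update B @ A (rank r)."""

    def __init__(self, weight, r=8, alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        out_dim, in_dim = weight.shape
        self.weight = weight                              # frozen pretrained weight
        self.A = rng.normal(0.0, 0.01, size=(r, in_dim))  # trainable, small random init
        self.B = np.zeros((out_dim, r))                   # trainable, zero init
        self.scale = alpha / r

    def __call__(self, x):
        # x: (batch, in_dim) -> (batch, out_dim)
        return x @ self.weight.T + self.scale * (x @ self.A.T @ self.B.T)

    def trainable_params(self):
        return self.A.size + self.B.size


if __name__ == "__main__":
    d = 768  # hidden size of GPT-2 small
    layer = LoRALinear(np.random.default_rng(1).normal(size=(d, d)), r=8)
    x = np.ones((2, d))
    # B starts at zero, so the adapted layer initially matches the pretrained one.
    assert np.allclose(layer(x), x @ layer.weight.T)
    print(layer.trainable_params(), "trainable vs", d * d, "frozen")  # 12288 trainable vs 589824 frozen
```

Because B is zero-initialized, fine-tuning starts exactly from the pretrained model, and only 2·r·d parameters per adapted matrix are updated.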

🙏 Acknowledgement

LaMo builds on many open-source projects, including Decision Transformer, Can Wikipedia Help Offline Reinforcement Learning, LoRA, DeFog, and d4rl-atari. We thank all these authors for open-sourcing their code and for their great contributions to the community.

🏷️ License

LaMo is licensed under the MIT license. See the LICENSE file for details.

📝 Citation

If you find our work useful, please consider citing:

@article{Shi2024LaMo,
  title={Unleashing the Power of Pre-trained Language Models for Offline Reinforcement Learning},
  author={Ruizhe Shi and Yuyao Liu and Yanjie Ze and Simon S. Du and Huazhe Xu},
  journal={International Conference on Learning Representations}, 
  year={2024}
}
