
Online RLHF

TL;DR: this repo aligns large language models (LLMs) with online iterative RLHF. Also check out our technical report and Hugging Face repo!

We present the workflow of Online Iterative Reinforcement Learning from Human Feedback (RLHF), which has been widely reported in the recent LLM literature to outperform its offline counterpart by a large margin. However, existing open-source RLHF projects are still largely confined to the offline learning setting. In this repo, we aim to fill this gap and provide a detailed, easy-to-reproduce recipe for online iterative RLHF. In particular, using only open-source data, our recipe achieves results comparable to or even better than LLaMA3-8B-instruct.


Model Releases

Installation instructions

It is recommended to use two separate environments, one for inference and one for training.

Note that NumPy must be pinned to numpy<2.0; NumPy 2.0 causes unexpected errors!

Inference Environment

conda create -n vllm python=3.10.9
conda activate vllm
pip install datasets
# The following code is tested for CUDA12.0-12.2. You may need to update the torch and flash-attention sources according to your own CUDA version
pip3 install torch==2.1.2 torchvision torchaudio
pip install https://github.com/vllm-project/vllm/releases/download/v0.4.0/vllm-0.4.0-cp310-cp310-manylinux1_x86_64.whl 
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.7/flash_attn-2.5.7+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

pip install accelerate==0.27.2
pip install deepspeed

Training Environment

conda create -n rlhflow python=3.10.9
conda activate rlhflow

git clone https://github.com/huggingface/alignment-handbook.git
cd ./alignment-handbook/
git checkout d17fd7cd3b71c6a7bf7af34d8dc73135bb7ea8e9
pip3 install torch==2.1.2 torchvision torchaudio
python -m pip install .
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.7/flash_attn-2.5.7+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install accelerate==0.27.2

You also need to install wandb to log training runs, and log in to your Hugging Face account so that you have access to the LLaMA3 models.

pip install wandb

wandb login
huggingface-cli login

Get Started

This section provides step-by-step guidance.

Step 1 Supervised Fine-tuning

To start with, preprocess your dataset into the standard format. Here is an example of the dataset. You may need to adjust the hyper-parameters (batch size, packing size) according to your computational resources. To run SFT, use the following command.

# You can adjust the training parameters in ./sft/sft.py
accelerate launch ./sft/sft.py

# Train with DeepSpeed ZeRO stage 3
# You may need to adjust ./configs/zero3.yaml, especially the num_processes (the number of GPUs) according to your environment
accelerate launch --config_file ./configs/zero3.yaml ./sft/sft.py
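The exact "standard format" is defined by the example dataset and by ./sft/sft.py. As a hedged illustration only, assuming a chat-style layout in which each record holds a `messages` list of role/content turns (the convention used by the alignment-handbook installed above), raw single-turn pairs could be converted as follows. The helpers `to_messages` and `write_jsonl` and the field name `messages` are assumptions, not part of this repo.

```python
import json
from typing import Dict, List


def to_messages(prompt: str, response: str) -> Dict[str, List[Dict[str, str]]]:
    """Wrap a single-turn (prompt, response) pair as a chat-style record.

    Hypothetical helper: assumes the SFT data uses a `messages` column of
    {"role", "content"} turns; check ./sft/sft.py for the fields it expects.
    """
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ]
    }


def write_jsonl(records: List[Dict], path: str) -> None:
    """Write one JSON object per line, a layout datasets.load_dataset("json", ...) can read."""
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
```

For example, write_jsonl([to_messages("What is 2 + 2?", "4")], "sft_data.jsonl") writes a one-line file that can be inspected before pointing the SFT script at real data.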

Step 2 Reward Modeling

We refer interested readers to this repo for a detailed recipe for training state-of-the-art open-source reward/preference models. We have trained several RMs and released them on Hugging Face, such as sfairXC/FsfairX-LLaMA3-RM-v0.1 and RLHFlow/pair-preference-model-LLaMA3-8B, which are the state-of-the-art open-source RMs as of May 2024.
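As background for this step: a Bradley-Terry (BT) reward model assigns a scalar reward r to each (prompt, response) and models the probability that the chosen response beats the rejected one as sigmoid(r_chosen - r_rejected); training minimizes the negative log of that probability. The following is a minimal sketch of that objective on scalar rewards, not the training code of the referenced repo; it applies to BT-style RMs only, whereas pairwise preference models take both responses as input instead.

```python
import math


def bt_preference_prob(r_chosen: float, r_rejected: float) -> float:
    """P(chosen preferred over rejected) = sigmoid(r_chosen - r_rejected).

    Uses sigmoid(x) = (1 + tanh(x / 2)) / 2, which cannot overflow.
    """
    return 0.5 * (1.0 + math.tanh((r_chosen - r_rejected) / 2.0))


def softplus(x: float) -> float:
    """log(1 + exp(x)), computed without overflow for large |x|."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def bt_loss(r_chosen: float, r_rejected: float) -> float:
    """BT negative log-likelihood: -log sigmoid(r_chosen - r_rejected)."""
    return softplus(-(r_chosen - r_rejected))
```

With equal rewards the loss is log 2 (about 0.693); it shrinks toward 0 as the chosen response's reward pulls ahead of the rejected one.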


Step 3.1 Data Generation

To accelerate data generation, we use vLLM. We provide two ways of running vLLM inference for robustness; you can try both and choose the one that works best in your environment. We use LLaMA3-8B as an example; for other models, you need to adjust eos_ids accordingly.
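--eos_ids lists the token ids at which generation should stop. For Meta-Llama-3-8B-Instruct, 128009 is the id of <|eot_id|>, the token that ends each turn in the LLaMA3 chat template. A hedged sketch for collecting candidate stop ids for another model with a Hugging Face tokenizer follows; `collect_eos_ids` is a hypothetical helper, and which end-of-turn tokens to pass (e.g. <|end|> for Phi-3) is an assumption to verify against each model's chat template.

```python
from typing import Iterable, List


def collect_eos_ids(tokenizer, extra_stop_tokens: Iterable[str] = ()) -> List[int]:
    """Gather token ids that should stop generation (hypothetical helper).

    `tokenizer` is any object exposing `eos_token_id` and
    `convert_tokens_to_ids`, e.g. a transformers AutoTokenizer.
    `extra_stop_tokens` are model-specific end-of-turn tokens such as
    "<|eot_id|>" for LLaMA3-Instruct.
    """
    ids: List[int] = []
    if tokenizer.eos_token_id is not None:
        ids.append(tokenizer.eos_token_id)
    unk_id = getattr(tokenizer, "unk_token_id", None)
    for tok in extra_stop_tokens:
        tok_id = tokenizer.convert_tokens_to_ids(tok)
        # Tokens missing from the vocab map to None or to the unk id; skip them
        # and avoid listing the same id twice.
        if tok_id is not None and tok_id != unk_id and tok_id not in ids:
            ids.append(tok_id)
    return ids
```

With a LLaMA3-Instruct tokenizer loaded via AutoTokenizer.from_pretrained, collect_eos_ids(tok, ["<|eot_id|>"]) should return a list containing 128009. Whether --eos_ids accepts more than one value should be checked in the generation scripts.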

You may create a test_gen.sh file, copy the following contents into it, and run ``bash test_gen.sh''.

# First approach: initialize 4 VLLM processes and split the prompt set among the 4 agents
# The generated samples will be stored at output_dir + local_index + ".json"

my_world_size=4 # how many gpu you use
infer_model=meta-llama/Meta-Llama-3-8B-Instruct
prompt_dir=RLHFlow/test_generation_2k
mkdir data
output_dir=./data/gen_data

conda activate vllm
CUDA_VISIBLE_DEVICES=0 python ./generation/get_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 0 --my_world_size ${my_world_size} --eos_ids 128009 &
CUDA_VISIBLE_DEVICES=1 python ./generation/get_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 1 --my_world_size ${my_world_size} --eos_ids 128009 &
CUDA_VISIBLE_DEVICES=2 python ./generation/get_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 2 --my_world_size ${my_world_size} --eos_ids 128009 &
CUDA_VISIBLE_DEVICES=3 python ./generation/get_hf2.py --model_name_or_path ${infer_model} --dataset_name_or_path ${prompt_dir} --output_dir ${output_dir} --K 4 --temperature 1.0 --local_index 3 --my_world_size ${my_world_size} --eos_ids 128009 &

wait
python ./generation/merge_data.py --base_path ${output_dir} --output_dir ./data/gen_data.json --num_datasets ${my_world_size}
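The first approach shards the prompt set across my_world_size independent vLLM processes by --local_index and then merges the per-shard outputs. The real logic lives in ./generation/get_hf2.py and ./generation/merge_data.py; the sketch below only illustrates the idea, assuming contiguous shards and that each shard file (output_dir + local_index + ".json") holds a JSON list of {"prompt", "responses"} records. Both the sharding rule and the file layout are assumptions.

```python
import json
from typing import Dict, List


def shard(prompts: List[str], local_index: int, world_size: int) -> List[str]:
    """Return the contiguous slice of prompts handled by one worker (assumed rule)."""
    per_worker = (len(prompts) + world_size - 1) // world_size  # ceiling division
    start = local_index * per_worker
    return prompts[start:start + per_worker]


def merge_shards(base_path: str, num_datasets: int, output_path: str) -> List[Dict]:
    """Concatenate base_path + str(i) + ".json" for i in range(num_datasets)."""
    merged: List[Dict] = []
    for i in range(num_datasets):
        with open(f"{base_path}{i}.json", encoding="utf-8") as f:
            merged.extend(json.load(f))
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(merged, f, ensure_ascii=False)
    return merged
```

Contiguous shards keep the merged file in the original prompt order, which makes it easy to check that no prompt was dropped.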

Alternatively, we can use API servers to generate new responses.

# Second approach: register vLLM API servers (ports 8000-8007) and send the prompt set to them
# The generated samples will be stored at output_dir

my_world_size=4
infer_model=meta-llama/Meta-Llama-3-8B-Instruct
prompt_dir=RLHFlow/test_generation_2k
mkdir data
output_dir=./data/gen_data.json
conda activate vllm

# register the api server
bash ./generation/run_8gpu.sh $infer_model
python ./generation/gen_hf.py --ports 8000 8001 8002 8003 8004 8005 8006 8007 --eos_ids 128009 --tokenizer $infer_model --dataset_name_or_path $prompt_dir --output_dir $output_dir --K 4 --temperature 1.0
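In the second approach, gen_hf.py sends prompts to the API servers started by run_8gpu.sh, one per port. The sketch below only shows request building and round-robin dispatch. It assumes the servers expose vLLM's OpenAI-compatible /v1/completions endpoint, whose body accepts `n`, `temperature`, `max_tokens`, and the vLLM-specific `stop_token_ids`; check run_8gpu.sh for the server entry point actually used. `build_requests` and its max_tokens default are hypothetical.

```python
from itertools import cycle
from typing import Dict, List, Tuple


def build_requests(
    prompts: List[str],
    ports: List[int],
    model: str,
    K: int = 4,
    temperature: float = 1.0,
    eos_ids: Tuple[int, ...] = (128009,),
    max_tokens: int = 2048,  # hypothetical default
) -> List[Tuple[str, Dict]]:
    """Pair each prompt with a server URL (round-robin over ports) and a request body."""
    out: List[Tuple[str, Dict]] = []
    for prompt, port in zip(prompts, cycle(ports)):
        url = f"http://localhost:{port}/v1/completions"
        body = {
            "model": model,
            "prompt": prompt,
            "n": K,  # K sampled responses per prompt
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stop_token_ids": list(eos_ids),  # vLLM extension to the OpenAI schema
        }
        out.append((url, body))
    return out
```

Each body could then be POSTed, e.g. with requests.post(url, json=body), and the K texts read from response.json()["choices"][i]["text"].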

Step 3.2 Data Annotation

Then, we call the reward/preference model trained in step 2 to rank the generated responses.

accelerate launch ./annotate_data/get_rewards.py --dataset_name_or_path ./data/gen_data.json --output_dir ./data/data_with_rewards.json --K 4

If you encounter the error ``TypeError: Got unsupported ScalarType BFloat16'', consider running pip install transformers==4.38.2.

Remark: following the LLaMA2 project, the current implementation assumes that the RM shares the same chat template as the model to be aligned. In many cases, however, the RM has its own chat template. In that case, update the change_of_format function in get_rewards.py and enable the following line:

# Around line 123
test_texts = [change_of_format(sample['prompt'], tmp_output) for tmp_output in sample['responses']]
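When the RM uses a different chat template from the policy, change_of_format has to re-render each (prompt, response) pair in the RM's template. A hedged sketch, assuming single-turn samples whose sample['prompt'] is already rendered in the LLaMA3 template and an RM tokenizer that implements apply_chat_template; the names below are illustrative, not those in get_rewards.py.

```python
from typing import Dict, List

USER_HEADER = "<|start_header_id|>user<|end_header_id|>\n\n"
EOT = "<|eot_id|>"


def extract_llama3_user_turn(prompt: str) -> str:
    """Pull the single user message out of a LLaMA3-formatted prompt.

    Raises ValueError if the prompt does not contain a LLaMA3 user turn.
    """
    start = prompt.index(USER_HEADER) + len(USER_HEADER)
    end = prompt.index(EOT, start)
    return prompt[start:end]


def change_of_format_sketch(prompt: str, response: str, rm_tokenizer) -> str:
    """Re-render a LLaMA3-formatted (prompt, response) pair in the RM's chat template.

    Hypothetical helper: assumes a single user turn and an rm_tokenizer with
    apply_chat_template (e.g. a transformers AutoTokenizer).
    """
    messages: List[Dict[str, str]] = [
        {"role": "user", "content": extract_llama3_user_turn(prompt)},
        {"role": "assistant", "content": response},
    ]
    return rm_tokenizer.apply_chat_template(messages, tokenize=False)
```

If the returned string is tokenized again, watch for a duplicated BOS token, since the template may already include one.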

Step 3.3 Training

conda activate rlhflow
model_path=meta-llama/Meta-Llama-3-8B-Instruct
initial_model=meta-llama/Meta-Llama-3-8B-Instruct
mkdir models
accelerate launch --config_file ./configs/zero2.yaml ./dpo_iteration/run_dpo.py --run_name rlhflow_iter1 --output_dir ./models/rlhflow_iter1 --model_name_or_path $model_path --ref_model $initial_model --learning_rate 2e-7 --max_steps 1200 --choose_type max_min --train_dir ./data/data_with_rewards.json --eval_dir ./data/data_with_rewards.json --loss_type sigmoid --lr_scheduler_type cosine
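Two flags in the command above carry most of the logic. --choose_type max_min suggests that, among the K annotated responses per prompt, the highest-reward one is used as chosen and the lowest-reward one as rejected, and --loss_type sigmoid selects the standard DPO loss. The sketch below illustrates both under those assumptions; the beta value and the per-sample `responses`/`rewards` fields are assumptions, and run_dpo.py remains the reference.

```python
import math
from typing import List, Tuple


def choose_max_min(responses: List[str], rewards: List[float]) -> Tuple[str, str]:
    """Return (chosen, rejected) = (highest-reward, lowest-reward) response."""
    best = max(range(len(rewards)), key=rewards.__getitem__)
    worst = min(range(len(rewards)), key=rewards.__getitem__)
    return responses[best], responses[worst]


def dpo_sigmoid_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,  # assumed value; check the script's default
) -> float:
    """-log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    margin = beta * (
        (policy_chosen_logp - ref_chosen_logp)
        - (policy_rejected_logp - ref_rejected_logp)
    )
    # -log sigmoid(m) = log(1 + exp(-m)), computed stably for both signs of m
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))
```

When the policy equals the reference, every margin is zero and the loss is exactly log 2 (about 0.693), so a loss that stays near 0.69 means the policy's log-ratios are barely moving. If all K rewards tie, the chosen and rejected responses coincide and the pair carries no signal.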

If you encounter ``RuntimeError: CUDA error: invalid device ordinal, CUDA kernel errors might be asynchronously reported at some other API call'', adjust num_processes in the config file to match the number of available GPUs.

Putting Everything Together

We put everything together so that the iterative training runs automatically. Note that we set sleep 1m to wait for the inference API servers to register; you may need to adjust this value for your environment.

bash run_loop.sh
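As a hedged outline of what such a loop does (the iteration count, directory names, and prompt sets are assumptions; read run_loop.sh for the real values): each iteration generates with the latest policy, annotates, and trains a new policy initialized from it, while the reference model stays fixed at the initial model, matching --ref_model $initial_model in step 3.3. The helper `iteration_plan` is hypothetical.

```python
from typing import Dict, List


def iteration_plan(initial_model: str, num_iters: int, prefix: str = "rlhflow_iter") -> List[Dict[str, str]]:
    """Describe which checkpoint each iteration starts from and where it writes."""
    plan: List[Dict[str, str]] = []
    current = initial_model
    for it in range(1, num_iters + 1):
        out_dir = f"./models/{prefix}{it}"
        plan.append({
            "generate_with": current,    # step 3.1: sample from the latest policy
            "train_from": current,       # step 3.3: --model_name_or_path
            "ref_model": initial_model,  # step 3.3: --ref_model stays fixed
            "output_dir": out_dir,
        })
        current = out_dir  # the next iteration starts from this checkpoint
    return plan
```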

Acknowledgement

The authors would like to thank the great open-source communities, including the Hugging Face TRL team, the Hugging Face H4 team, the Allen Institute for AI RewardBench team, the Meta LLaMA team, and the Axolotl team, for sharing models, code, and training sets.

Citation

If you find the content of this repo useful, please consider citing it as follows:

@misc{dong2024rlhf,
      title={RLHF Workflow: From Reward Modeling to Online RLHF}, 
      author={Hanze Dong and Wei Xiong and Bo Pang and Haoxiang Wang and Han Zhao and Yingbo Zhou and Nan Jiang and Doyen Sahoo and Caiming Xiong and Tong Zhang},
      year={2024},
      eprint={2405.07863},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@inproceedings{xiong2023iterative,
  title={Iterative preference learning from human feedback: Bridging theory and practice for RLHF under KL-constraint},
  author={Xiong, Wei and Dong, Hanze and Ye, Chenlu and Wang, Ziqi and Zhong, Han and Ji, Heng and Jiang, Nan and Zhang, Tong},
  booktitle={ICLR 2024 Workshop on Mathematical and Empirical Understanding of Foundation Models}
}

online-rlhf's Issues

Questions about training data during iterative DPO

Hi, awesome work and thanks for open-sourcing it!

In each iteration, it seems that only the 20k dataset annotated by the reward model is used. However, the pipeline in the paper shows that the historical dataset grows over iterations. Please let me know if I have missed something important. Thanks a lot!

More RLHF algorithms in the implementation

I saw that the loss type option indicates several other loss functions can be used, such as hinge, ipo, raft ...

I am wondering whether we only need to change the loss choice, without modifying other parts of the code.

Reference policy ablations

Dear authors,

  1. I noticed that the reference policy is fixed as the initial policy, instead of updating as the last iter's policy. May I know the reason for it and have you tried updating the reference policy?
  2. The initial policy is "RLHFlow/LLaMA3-SFT", not "meta-llama/Meta-Llama-3-8B-Instruct", right?

One question about the loss function given a gold reward model

As illustrated, gold reward models such as sfairXC/FsfairX-LLaMA3-RM-v0.1 are trained under the BT model and perform very well. A natural question, then, is why the DPO loss function is still
$\log \frac{\pi(y_1)\pi_{ref}(y_2)}{\pi_{ref}(y_1)\pi(y_2)}$ instead of following the actual BT model, which would be expected to give $\sigma(r_1-r_2)\log \frac{\pi(y_1)\pi_{ref}(y_2)}{\pi_{ref}(y_1)\pi(y_2)}+\sigma(r_2-r_1)\log \frac{\pi(y_2)\pi_{ref}(y_1)}{\pi_{ref}(y_2)\pi(y_1)}$ (hyper-parameters omitted here). Could you provide some intuition?

question about dpo dataset

Hi, awesome work and thanks for open-sourcing it!

While reading your article 'RLHF Workflow: From Reward Modeling to Online RLHF', I noticed that Chapter 3 mentions: “Hybrid batch learning. We formulate a slightly more general framework to combine an initial offline dataset with online data collected during training.”

Does the mixing of the two types of data here refer to a hybrid method where you start from p0(sft), use p0 as the reference model, and include the 20k data annotated in each iteration? Or does it refer to the corresponding solution mentioned in LLaMA2 to address the 'alignment tax' issue?

Thanks!

Phi3 has a nearly constant DPO loss of 0.69xx

Issue: Implementing Iterative DPO on Phi3-4k-instruct

Hi, thanks for the great work and for open-sourcing it!

I am trying to implement iterative DPO on Phi3-4k-instruct. The following outlines my approach:

  1. Generation Step:

    python generation/gen_hf.py --ports 8000 8001 8002 8003 --tokenizer microsoft/Phi-3-mini-4k-instruct --dataset_name_or_path $jsonl_input --output_dir $json_output --K 8 --temperature 1.0
  2. Reward Annotation:

    accelerate launch annotate_data/get_rewards.py --dataset_name_or_path $json_output --output_dir $model_output

    Note: I have commented out line 124 and uncommented line 123 in this file to handle the chat template of Phi3 differently from that of the Llama3-based reward model. This might be incorrect, as I have not modified the change_of_format() function!

  3. DPO Iteration:

    accelerate launch dpo_iteration/run_dpo.py --run_name $iteration --output_dir $iteration --model_name_or_path microsoft/Phi-3-mini-4k-instruct --ref_model microsoft/Phi-3-mini-4k-instruct --learning_rate 5e-7 --max_steps 1200 --choose_type max_min --train_dir $model_output --eval_dir $model_output --loss_type sigmoid --lr_scheduler_type cosine

After performing these steps, the DPO loss is stuck at 0.69xx. I am running at a batch size of 128 and a learning rate of 5e-7.

Any insights to help get a Phi3 variant of iterative DPO would be greatly appreciated.

Thanks!

Distributed training in stage 3.3 keeps hanging

In stage 3.3, when I set distributed_type to NO, the code runs well; when I set distributed_type to DEEPSPEED or MULTI_GPU, however, the code hangs when constructing training_args = TrainingArguments(...). For DEEPSPEED, the terminal hangs after printing

[2024-06-11 00:23:36,254] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-06-11 00:23:36,254] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2024-06-11 00:23:36,296] [INFO] [comm.py:637:init_distributed] cdb=None

Do you have any idea what might cause this? My CUDA version is 12.4.

questions about dpo

Hi, I have some questions about DPO:

  1. Is there a reason for choosing the Nectar dataset to train offline vanilla DPO rather than the same dataset as iterative DPO, for a possibly fairer comparison?
  2. Have you applied iterative DPO to Llama3-70B? If so, what specific details should we pay attention to?

Thanks for your assistance.

Iterative pipeline question

I have some questions about the iterative pipeline. Please correct me if my understanding is wrong, thank you so much!

From the report, \pi_0 should be the SFT policy trained on SFT-OpenHermes-2.5-Standard, (LLaMA3-SFT I guess?), and \pi_1 is the policy further trained with DPO on a historical dataset, is the dataset iterative-prompt-v1-iter1-20K?

After we get \pi_1, we should use it to generate answers on iterative-prompt-v1-iter2-20K, label them with the reward model, and then use run_dpo to get \pi_2 (with the reference model still being the SFT model, but starting from \pi_1?). Thanks again!

numpy version and transformers version

I encountered the following package-version issue when training with Gemma. The solution is to upgrade transformers and downgrade NumPy.

/configuration_auto.py", line 795, in __getitem__
    raise KeyError(key)
KeyError: 'gemma'
 UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Fail to load weight from pair-preference-model-LLaMA3-8B

Hi, congratulations on the great work and thanks for open-sourcing it!

I am running step 3.2 with pair-preference-model-LLaMA3-8B. However, I encountered the warning "Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at RLHFlow/pair-preference-model-LLaMA3-8B and are newly initialized: ['score.weight']". Could you please help me with the issue? Thanks a lot!

Cannot Reproduce the DPO Checkpoint

Hi,

I tried to reproduce the training process from SFT to DPO. I ran the run_loop.sh script; the only change I made was setting initial_model="RLHFlow/LLaMA3-SFT". After 3 iterations, the final checkpoint of iteration 3 has an MT-Bench score of 7.95, which differs from the reported number. The initial SFT starting point has the same MT-Bench score as reported in the paper.

I did not modify other settings in run_loop.sh script. Please let me know if there is anything additional needed to reproduce.

large max_steps?

Hi, thanks for the great work!

I noticed that max_steps is set to 1200 for each DPO iteration. Given that there are 20K examples and a batch size of 128, isn't this almost 8 epochs? Do we need so many steps empirically?

Thanks!
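The estimate in this question can be checked directly: each optimizer step consumes one global batch, so the number of passes over the data is max_steps x global batch size / dataset size. This assumes 128 is the global batch size (all GPUs and gradient-accumulation steps combined); `epochs` is a hypothetical helper.

```python
def epochs(max_steps: int, global_batch_size: int, num_examples: int) -> float:
    """Passes over the data implied by a fixed step budget."""
    return max_steps * global_batch_size / num_examples


print(epochs(1200, 128, 20_000))  # prints 7.68
```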
