alpaca-rlhf

Fine-tuning LLaMA with RLHF (Reinforcement Learning from Human Feedback).

Online Demo

Modifications on DeepSpeed Chat

Step 1

  • alpaca_rlhf/deepspeed_chat/training/step1_supervised_finetuning/main.py#main()
    • Set special tokens
  • alpaca_rlhf/deepspeed_chat/training/utils/data/data_utils.py#create_dataset_split()
    • Train only on responses and add eos
    • Remove end_of_conversation_token
  • alpaca_rlhf/deepspeed_chat/training/utils/data/data_utils.py#PromptDataset#__getitem__
    • Labels differ from inputs
  • alpaca_rlhf/deepspeed_chat/training/utils/data/raw_datasets.py#MultiTurnAlpacaDataset
    • Add MultiTurnAlpacaDataset
  • alpaca_rlhf/deepspeed_chat/training/utils/module/lora.py#convert_linear_layer_to_lora
    • Support multiple module names for LoRA
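
As a rough illustration of "train only on responses and add eos", the sketch below builds one training example in which the labels differ from the inputs: prompt positions are masked out so that only the response and the appended eos contribute to the loss. The function name `build_sft_example`, its arguments, and the token ids are hypothetical and are not taken from the repository. The mask value -100 is the default `ignore_index` of PyTorch's cross-entropy loss; that the repository uses this exact value is an assumption.

```python
# Minimal sketch, assuming token ids are plain Python lists.
# IGNORE_INDEX matches the default ignore_index of torch.nn.CrossEntropyLoss;
# that the repository uses the same value is an assumption.
IGNORE_INDEX = -100


def build_sft_example(prompt_ids, response_ids, eos_id, max_len):
    """Concatenate prompt and response, append eos, and mask the prompt in the labels."""
    input_ids = (prompt_ids + response_ids + [eos_id])[:max_len]
    # Labels mirror the inputs, except that every prompt position is ignored.
    labels = ([IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id])[:max_len]
    return {"input_ids": input_ids, "labels": labels}


# Hypothetical token ids, for illustration only.
example = build_sft_example([11, 12, 13], [21, 22], eos_id=2, max_len=8)
truncated = build_sft_example([11, 12, 13], [21, 22], eos_id=2, max_len=4)
```

With this layout the model still sees the full prompt as context, but gradients come only from the response tokens and the eos.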

Step 2

  • alpaca_rlhf/deepspeed_chat/training/step2_reward_model_finetuning/main.py#main()
    • Set special tokens
  • alpaca_rlhf/deepspeed_chat/training/utils/model/reward_model.py#RewardModel#forward()
    • Fix numerical instability
  • alpaca_rlhf/deepspeed_chat/training/utils/data/data_utils.py#create_dataset_split()
    • Remove end_of_conversation_token
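
The list above does not say what the numerical instability in RewardModel#forward() was. One common source in pairwise reward losses is computing -log(sigmoid(chosen - rejected)) literally: when the rejected reward is much larger than the chosen one, the sigmoid underflows to 0 and the log blows up. The sketch below shows a stable log-sigmoid formulation in plain Python. Whether this is the exact change made in the repository is an assumption, and the function names are hypothetical.

```python
import math


def naive_pairwise_loss(chosen_reward, rejected_reward):
    """Literal -log(sigmoid(d)); breaks down for large negative d."""
    d = chosen_reward - rejected_reward
    return -math.log(1.0 / (1.0 + math.exp(-d)))


def log_sigmoid(x):
    """Stable log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def stable_pairwise_loss(chosen_reward, rejected_reward):
    """Pairwise ranking loss that stays finite for any reward gap."""
    return -log_sigmoid(chosen_reward - rejected_reward)


# Both forms agree when the gap is moderate.
moderate_naive = naive_pairwise_loss(1.0, 0.5)
moderate_stable = stable_pairwise_loss(1.0, 0.5)

# With a large negative gap only the stable form returns a finite value.
extreme_stable = stable_pairwise_loss(-400.0, 400.0)
```

In a PyTorch model the same idea is usually expressed with `torch.nn.functional.logsigmoid` instead of chaining `torch.log` and `torch.sigmoid`.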

Step 3

  • alpaca_rlhf/deepspeed_chat/training/step3_rlhf_finetuning/main.py#main()
    • Set special tokens
  • alpaca_rlhf/deepspeed_chat/training/utils/data/data_utils.py#create_dataset_split()
    • Fix max length bug
  • alpaca_rlhf/deepspeed_chat/training/utils/data/data_utils.py#DataCollatorRLHF#__call__
    • Fix padding side bug
  • alpaca_rlhf/deepspeed_chat/training/step3_rlhf_finetuning/ppo_trainer.py#DeepSpeedPPOTrainer#generate_experience
    • Normalize reward
  • alpaca_rlhf/deepspeed_chat/training/step3_rlhf_finetuning/ppo_trainer.py#DeepSpeedPPOTrainer#_generate_sequence
    • Mask the tokens after the eos
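
To make "mask the tokens after the eos" concrete, here is a minimal sketch on plain Python lists. It assumes a left-padded prompt of known length followed by generated tokens, and it leaves the first eos itself visible. The function name `mask_after_eos`, its parameters, and the token ids are hypothetical; the repository's _generate_sequence works on batched tensors and may differ in detail.

```python
def mask_after_eos(sequence, attention_mask, prompt_len, eos_id, pad_id):
    """Pad out and mask every generated token that follows the first eos.

    Only positions at or after prompt_len are inspected, so an eos or pad
    inside the prompt is left untouched.
    """
    seq = list(sequence)
    mask = list(attention_mask)
    seen_eos = False
    for i in range(prompt_len, len(seq)):
        if seen_eos:
            seq[i] = pad_id
            mask[i] = 0
        elif seq[i] == eos_id:
            # Keep the eos itself; everything after it is discarded.
            seen_eos = True
    return seq, mask


# Hypothetical ids: 0 = pad, 2 = eos. The prompt is left-padded to length 3.
tokens = [0, 6, 7, 8, 2, 9, 9]
mask = [0, 1, 1, 1, 1, 1, 1]
clean_tokens, clean_mask = mask_after_eos(tokens, mask, prompt_len=3, eos_id=2, pad_id=0)
```

Without this step, tokens that the model keeps emitting after eos would be scored by the reward and critic models and would enter the PPO loss.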

Step by Step

  • Running all three steps on 2 x A100 80G
  • Datasets
  • Enter the ./alpaca_rlhf directory first, then run the following commands:
    • step1: sh run.sh --num_gpus 2 /tmp/pycharm_project_227/alpaca_rlhf/deepspeed_chat/training/step1_supervised_finetuning/main.py --sft_only_data_path MultiTurnAlpaca --data_output_path /root/autodl-tmp/rlhf/tmp/ --model_name_or_path decapoda-research/llama-7b-hf --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --max_seq_len 512 --learning_rate 3e-4 --num_train_epochs 1 --gradient_accumulation_steps 8 --num_warmup_steps 100 --output_dir /root/autodl-tmp/rlhf/actor --lora_dim 8 --lora_module_name q_proj,k_proj --only_optimize_lora --deepspeed --zero_stage 2
      • If you pass --sft_only_data_path MultiTurnAlpaca, unzip data/data.zip first.
    • step2: sh run.sh --num_gpus 2 /tmp/pycharm_project_227/alpaca_rlhf/deepspeed_chat/training/step2_reward_model_finetuning/main.py --data_output_path /root/autodl-tmp/rlhf/tmp/ --model_name_or_path decapoda-research/llama-7b-hf --num_padding_at_beginning 0 --per_device_train_batch_size 4 --per_device_eval_batch_size 64 --learning_rate 5e-4 --num_train_epochs 1 --gradient_accumulation_steps 1 --num_warmup_steps 0 --zero_stage 2 --deepspeed --output_dir /root/autodl-tmp/rlhf/critic --lora_dim 8 --lora_module_name q_proj,k_proj --only_optimize_lora
      • [Figure: the training process of step 2]
      • The mean and standard deviation of the rewards of the chosen responses are collected during this step and used to normalize the reward in step 3. In one experiment they were -0.8677118420600891 and 0.2210693359375 respectively, and they are used in the alpaca_rlhf/deepspeed_chat/training/step3_rlhf_finetuning/ppo_trainer.py#DeepSpeedPPOTrainer#generate_experience method: 'rewards': (reward_score - (-0.8677118420600891)) / 0.2210693359375.
    • step3: sh run.sh --num_gpus 2 /tmp/pycharm_project_227/alpaca_rlhf/deepspeed_chat/training/step3_rlhf_finetuning/main.py --data_output_path /root/autodl-tmp/rlhf/tmp/ --actor_model_name_or_path /root/autodl-tmp/rlhf/actor/ --tokenizer_name_or_path decapoda-research/llama-7b-hf --critic_model_name_or_path /root/autodl-tmp/rlhf/critic --actor_zero_stage 2 --critic_zero_stage 2 --num_padding_at_beginning 0 --per_device_train_batch_size 4 --per_device_mini_train_batch_size 4 --ppo_epochs 2 --actor_learning_rate 9.65e-6 --critic_learning_rate 5e-6 --gradient_accumulation_steps 1 --deepspeed --actor_lora_dim 8 --actor_lora_module_name q_proj --critic_lora_dim 8 --critic_lora_module_name q_proj,k_proj --only_optimize_lora --output_dir /root/autodl-tmp/rlhf/final
      • [Figure: the training process of step 3]
  • Inference
    • nohup sh run_inference.sh 0 alpaca_rlhf/inference/llama_chatbot_gradio.py --path /root/autodl-tmp/rlhf/final/actor > rlhf_inference.log 2>&1 &
    • nohup sh run_inference.sh 0 alpaca_rlhf/inference/llama_chatbot_gradio.py --path /root/autodl-tmp/rlhf/actor > sft_inference.log 2>&1 &
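
The reward normalization described under step 2 can be sketched as follows. The two constants are the values quoted above from one experiment, so they would need to be re-measured for any other reward model. The helper names are hypothetical, and the use of the population standard deviation in `collect_reward_stats` is an assumption, since the text does not say how the statistics were computed.

```python
import statistics

# Values quoted in the step 2 notes; specific to that one experiment.
CHOSEN_MEAN = -0.8677118420600891
CHOSEN_STD = 0.2210693359375


def collect_reward_stats(chosen_rewards):
    """Mean and (population) standard deviation of chosen-response rewards."""
    return statistics.fmean(chosen_rewards), statistics.pstdev(chosen_rewards)


def normalize_reward(reward_score, mean=CHOSEN_MEAN, std=CHOSEN_STD):
    """Standardize a raw reward-model score: (score - mean) / std."""
    return (reward_score - mean) / std


# A score equal to the chosen-response mean maps to 0, one std above it to 1.
at_mean = normalize_reward(CHOSEN_MEAN)
one_std_above = normalize_reward(CHOSEN_MEAN + CHOSEN_STD)
```

Standardizing this way keeps the rewards fed to PPO on a roughly unit scale, which is the formula given above for the 'rewards' entry in generate_experience.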

Comparison between SFT and RLHF

  • Chat
    • SFT
    • RLHF
  • Write stories
    • SFT
    • RLHF
