
SelectIT: Selective Instruction Tuning for Large Language Models via Uncertainty-Aware Self-Reflection

Overview

We introduce SelectIT, a novel data selection strategy for LLM instruction tuning (IT) that exploits the LLM's own uncertainty to efficiently identify high-quality IT data, without requiring additional models or resources. SelectIT combines three grains of self-reflection: token-, sentence-, and model-level, each of which can improve the accuracy of IT data selection both individually and jointly. By applying SelectIT to the Alpaca-GPT4 dataset, we obtain a compact and strong IT dataset, called Selective Alpaca. Experiments across different foundation models and domain-specific tasks demonstrate the effectiveness of SelectIT. Our analysis further shows that SelectIT effectively excludes abnormal data and tends to select longer samples and samples that involve calculation.


Figure 1: The overall framework of SelectIT.

Environment

Framework Versions:

  • PyTorch >= 2.0.1
  • Python >= 3.10.0
  • Transformers == 4.35.0
git clone git@github.com:Blue-Raincoat/SelectIT.git
cd SelectIT
pip3 install -r requirements.txt

Data

We introduce a novel IT dataset, Selective Alpaca, by selecting high-quality IT data from the Alpaca-GPT4 dataset.
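Once every sample has a quality score (computed by the self-reflection stages described below), the selection step itself is a simple top-fraction cut. The following is a minimal sketch of that step; the function name, score source, and file paths are illustrative, not the repository's actual API:

# Hypothetical sketch of the final selection step (illustrative names).
import json

def select_top(samples, scores, proportion=0.2):
    # Rank samples by their SelectIT score and keep the top fraction.
    ranked = sorted(zip(samples, scores), key=lambda pair: pair[1], reverse=True)
    keep = int(len(ranked) * proportion)
    return [sample for sample, _ in ranked[:keep]]

# Example usage (paths follow the commands below):
# with open("./data/alpaca_gpt4.json") as f:
#     samples = json.load(f)
# selected = select_top(samples, scores, proportion=0.2)
# with open("./data/Selective_Alpaca.json", "w") as f:
#     json.dump(selected, f, indent=2)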

Self-Reflection

SelectIT exploits LLM uncertainty at three granularities: token, sentence, and model, each of which can effectively improve the accuracy of IT data selection.

Token-level Self-Reflection

We use the foundation model itself to rate each IT sample, based on the uncertainty of its next-token predictions over the rating scores (a conceptual sketch follows the command below).

# -model-name-or-path: path to the LLM
# -rp: path to the rating prompt
# -i: path to the instruction dataset
# -o: path to the output dataset
# -k: hyper-parameter K (the 1..K rating scale)
# -proportion: proportion of the instruction data to select
# -alpha: hyper-parameter alpha (uncertainty weighting)
python3 self_reflection/token_level.py \
    -model-name-or-path models--meta-llama--Llama-2-7b-hf \
    -rp ./data/rating_prompt.txt \
    -i ./data/alpaca_gpt4.json \
    -o ./data/test.json \
    -k 5 \
    -proportion 0.2 \
    -alpha 0.2
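Conceptually, the token-level score comes from the model's next-token distribution over the K rating tokens. Below is a minimal sketch of the idea using Hugging Face Transformers; the prompt template, function names, and the exact scoring formula are illustrative paraphrases of the paper, not the repository's implementation:

# Illustrative sketch of token-level self-reflection (not the repo's code).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical rating prompt; the real one is loaded from rating_prompt.txt.
RATING_PROMPT = ("Rate the quality of the following instruction-response pair "
                 "from 1 to {k}.\nInstruction: {instruction}\n"
                 "Response: {response}\nScore: ")

def token_level_score(model, tokenizer, sample, k=5, prompt=RATING_PROMPT):
    text = prompt.format(k=k, instruction=sample["instruction"],
                         response=sample["output"])
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        next_logits = model(**inputs).logits[0, -1]  # logits of the next token
    # Restrict attention to the K score tokens "1".."k" and renormalize
    # (assumes each digit is a single vocabulary token, as in LLaMA).
    score_ids = [tokenizer.convert_tokens_to_ids(str(i)) for i in range(1, k + 1)]
    probs = torch.softmax(next_logits[score_ids], dim=-1)
    best = int(probs.argmax())  # 0-based index of the predicted rating
    # Uncertainty-aware weight: average gap between the top probability and
    # the rest, so confident ratings count more than ambiguous ones.
    disparity = (probs[best] - probs).sum().item() / (k - 1)
    return (best + 1) * disparity

# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
#                                              torch_dtype=torch.bfloat16,
#                                              device_map="auto")
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")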
    

Sentence-level Self-Reflection

We use sentence-level uncertainty to improve the rating process by exploiting the effect of different rating prompts on the LLM (sketched after the command below).

python3 self_reflection/sentence_level.py \
    -model-name-or-path models--meta-llama--Llama-2-7b-hf \
    -rp ./data/rating_prompt.txt \
    -i ./data/alpaca_gpt4.json \
    -o ./data/test.json \
    -k 5 \
    -proportion 0.2 \
    -alpha 0.2
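The sentence-level score reuses the token-level rating under K paraphrased rating prompts and penalizes samples on which the prompts disagree. A hedged sketch, reusing token_level_score from above; the exact way alpha combines the mean and the deviation is our reading of the paper, not the repository's code:

# Illustrative sketch of sentence-level self-reflection (not the repo's code).
import statistics

def sentence_level_score(model, tokenizer, sample, rating_prompts,
                         k=5, alpha=0.2):
    # One token-level score per semantically equivalent rating prompt.
    scores = [token_level_score(model, tokenizer, sample, k=k, prompt=rp)
              for rp in rating_prompts]
    mean = statistics.mean(scores)
    std = statistics.pstdev(scores)
    # Agreement across prompts (low std) keeps the score near the mean;
    # disagreement shrinks it, with alpha controlling the penalty.
    return mean / (1 + alpha * std)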

Model-level Self-Reflection

We utilize the uncertainty across different LLMs, enabling a collaborative decision-making process for IT data selection (sketched after the command below).

python3 self_reflection/model_level.py \
    -model-name-or-path models--meta-llama--Llama-2-7b-hf,models--meta-llama--Llama-2-13b-hf,models--meta-llama--Llama-2-70b-hf \
    -rp ./data/rating_prompt.txt \
    -i ./data/alpaca_gpt4.json \
    -o ./data/test.json \
    -k 5 \
    -proportion 0.2 \
    -alpha 0.2
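At the model level, the sentence-level scores of several LLMs are merged into one collaborative rating. One natural instantiation, sketched here as an assumption rather than the repository's exact formula, weights each model's score by its parameter count so that larger models carry more influence:

# Illustrative sketch of model-level self-reflection (not the repo's code).
def model_level_score(per_model_scores, param_counts):
    # per_model_scores: sentence-level scores for one sample, one per LLM;
    # param_counts: matching model sizes, e.g. [7e9, 13e9, 70e9] for the
    # LLaMA-2 family used in the command above.
    total = sum(param_counts)
    return sum(n / total * s
               for n, s in zip(param_counts, per_model_scores))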

Train

We provide two training configurations for different LLMs, which can be used to verify the robustness of SelectIT across backbones.

LLaMA-2-7B

Use the following training configuration to fine-tune the LLaMA-2 LLMs. Example usage on a single node with 4 A800 GPUs:

export NCCL_SOCKET_IFNAME=eno1  # set to your node's network interface
export MASTER_PORT=9909
start_time="$(date "+%Y-%m-%d-%H-%M-%S")"

deepspeed --master_addr "localhost" --master_port $MASTER_PORT \
    ./train/train.py \
    --deepspeed ./train/deepspeed_zero2.conf \
    --model_name_or_path "./models--meta-llama--Llama-2-7b-hf" \
    --model_max_length 4096 \
    --data_path ./data/Selective_Alpaca.json \
    --output_dir ./output_ckpt/llama2_7b_Selective_Alpaca \
    --bf16 True \
    --num_train_epochs 3 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --save_strategy "steps" \
    --save_steps 500 \
    --evaluation_strategy "no" \
    --save_total_limit 999 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --gradient_checkpointing True
    

Mistral-7B

Use the following configuration to fine-tune the Mistral LLMs. Example usage on a single node with 4 A800 GPUs:

export NCCL_SOCKET_IFNAME=eno1  # set to your node's network interface
export MASTER_PORT=9909
start_time="$(date "+%Y-%m-%d-%H-%M-%S")"

deepspeed --master_addr "localhost" --master_port $MASTER_PORT \
    ./train/train.py \
    --deepspeed ./train/deepspeed_zero2.conf \
    --model_name_or_path "./models--Mistral-7B-hf" \
    --model_max_length 4096 \
    --data_path ./data/Selective_Alpaca.json \
    --output_dir ./output_ckpt/mistral_7b_Selective_Alpaca \
    --bf16 True \
    --num_train_epochs 1 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --save_strategy "steps" \
    --save_steps 500 \
    --evaluation_strategy "no" \
    --save_total_limit 999 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --gradient_checkpointing True
    

Evaluation

We provide scripts for evaluating large language models with open-instruct, which offers a suite of standard benchmarks targeting core capabilities. These benchmarks include:

  • MMLU
# Evaluate the LLaMA 7B model with 5-shot prompting
python -m eval.mmlu.run_eval \
    --ntrain 5 \
    --data_dir data/eval/mmlu \
    --save_dir results/mmlu/llama-7B-5shot \
    --model_name_or_path ../hf_llama_models/7B \
    --tokenizer_name_or_path ../hf_llama_models/7B \
    --eval_batch_size 4
  • GSM
# Evaluate the LLaMA 7B model with 8-shot chain-of-thought prompting
python -m eval.gsm.run_eval \
    --data_dir data/eval/gsm/ \
    --max_num_examples 200 \
    --save_dir results/gsm/llama-7B-cot-8shot \
    --model ../hf_llama_models/7B \
    --tokenizer ../hf_llama_models/7B \
    --n_shot 8
  • BBH
# Evaluate the LLaMA 7B model with chain-of-thought prompting
python -m eval.bbh.run_eval \
    --data_dir data/eval/bbh \
    --save_dir results/bbh/llama-7B-cot/ \
    --model ../hf_llama_models/7B \
    --tokenizer ../hf_llama_models/7B \
    --max_num_examples_per_task 40
  • TydiQA
# Evaluate the LLaMA 7B model with the gold passage provided
python -m eval.tydiqa.run_eval \
    --data_dir data/eval/tydiqa/ \
    --n_shot 1 \
    --max_num_examples_per_lang 100 \
    --max_context_length 512 \
    --save_dir results/tydiqa/llama-7B-goldp \
    --model ../hf_llama_models/7B \
    --tokenizer ../hf_llama_models/7B \
    --eval_batch_size 20
  • AlpacaEval
# Use V1 of the AlpacaFarm evaluation.
export IS_ALPACA_EVAL_2=False

python -m eval.alpaca_farm.run_eval \
    --model_name_or_path ../checkpoints \
    --save_dir results/alpaca_farm/checkpoints/ \
    --eval_batch_size 20 \
    --use_chat_format \
    --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format

Acknowledgement

This project could not have been developed without the following resources:

Citation

If you find our work useful, please cite:

@article{liu2024selectit,
   title={SelectIT: Selective Instruction Tuning for Large Language Models via Uncertainty-Aware Self-Reflection}, 
   author={Liangxin Liu and Xuebo Liu and Derek F. Wong and Dongfang Li and Ziyi Wang and Baotian Hu and Min Zhang},
   year={2024},
   journal={arXiv preprint arXiv:2402.16705},
}
