
sppo's Introduction

SPPO: Self-Play Preference Optimization for Language Model Alignment

[Badges: Mistral-7B-Instruct | Llama-3-8B-Instruct | AlpacaEval 2.0 | Open LLM | MT-Bench]

This repository contains the official code and released models for the paper Self-Play Preference Optimization for Language Model Alignment.

Authors: Yue Wu*, Zhiqing Sun*, Huizhuo Yuan*, Kaixuan Ji, Yiming Yang, Quanquan Gu

[Webpage] [Huggingface] [Paper]


About SPPO

We propose a new self-play framework dubbed SPPO for language model alignment and a new learning objective (called SPPO loss) derived from the self-play framework to fine-tune large language models efficiently.


Figure: AlpacaEval 2.0 leaderboard results of normal and length-controlled (LC) win rates, in percentage (%). Mistral-7B-SPPO can outperform larger models, and Mistral-7B-SPPO (best-of-16) can outperform proprietary models such as GPT-4 (6/13). Llama-3-8B-SPPO exhibits even better performance.

SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4, and it outperforms models trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., the Nash equilibrium) under general, potentially intransitive preferences, and it is empirically validated through extensive evaluations on multiple datasets.
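At a high level, each SPPO iteration solves a squared-loss regression that pushes the log-density ratio toward a scaled, centered win probability. The following is a rough sketch of that objective in our own notation (a paraphrase; see the paper for the exact formulation):

    % Sketch of the iterative SPPO update (paraphrased; notation is ours).
    % \pi_t is the current policy, \eta a scaling hyperparameter, and
    % P(y \succ \pi_t \mid x) the probability that response y is preferred
    % over a response sampled from \pi_t for prompt x.
    \pi_{t+1} = \arg\min_{\pi}\;
      \mathbb{E}_{x \sim \mathcal{X},\, y \sim \pi_t(\cdot \mid x)}
      \left[\left(\log\frac{\pi(y \mid x)}{\pi_t(y \mid x)}
      - \eta\left(\mathbb{P}(y \succ \pi_t \mid x) - \tfrac{1}{2}\right)\right)^{2}\right]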

For more details, you can check our paper here.

Base Models and Released Models

Model | AlpacaEval 2.0 LC Win Rate (%) | AlpacaEval 2.0 Win Rate (%)
--- | --- | ---
🤗 Mistral-7B-Instruct-v0.2 | 17.11 | 14.72
🤗 Mistral-7B-SPPO Iter1 | 24.79 | 23.51
🤗 Mistral-7B-SPPO Iter2 | 26.89 | 27.62
🤗 Mistral-7B-SPPO Iter3 | 28.53 | 31.02
🤗 Llama-3-8B-Instruct | 22.92 | 22.57
🤗 Llama-3-8B-SPPO Iter1 | 31.73 | 31.74
🤗 Llama-3-8B-SPPO Iter2 | 35.15 | 35.98
🤗 Llama-3-8B-SPPO Iter3 | 38.77 | 39.85
🤗 Gemma-2-9B-It | 45.08 | 35.62
🤗 Gemma-2-9B-SPPO Iter1 | 48.70 | 40.76
🤗 Gemma-2-9B-SPPO Iter2 | 50.93 | 44.64
🤗 Gemma-2-9B-SPPO Iter3 | 53.27 | 47.74

Environment Setup

Our training code is based on the alignment-handbook codebase. We use vllm for generation and PairRM for ranking. Follow the steps below to set up your environment:

  1. Create a Virtual Environment:

    conda create -n sppo python=3.10
    conda activate sppo
  2. Install vllm for Generation:

    pip install vllm
  3. Install PairRM:

    git clone https://github.com/yuchenlin/LLM-Blender.git
    cd LLM-Blender
    pip install -e .
  4. Download and Install Training Dependencies:

    git clone https://github.com/uclaml/SPPO.git
    cd SPPO
    pip install -e .
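As an optional sanity check (our suggestion, not part of the official setup), confirm that the generation and ranking dependencies import cleanly inside the sppo environment:

    # Optional sanity check: verify vllm and LLM-Blender are importable.
    python -c "import vllm; print('vllm', vllm.__version__)"
    python -c "import llm_blender; print('llm_blender imported OK')"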

Training Scripts

Execute the training scripts based on the base model you choose:

  • For Mistral-7B-Instruct-v0.2:

    bash run_sppo_mistral.sh
  • For Llama-3-8B-Instruct:

    bash run_sppo_llama-3.sh

These scripts manage the training iterations, generation, and PairRM ranking processes. Note that some scripts may attempt to push datasets to the Hugging Face Hub under the UCLA-AGI organization. Ensure you have write access, modify the organization name accordingly, or comment out any push_to_hub commands. Detailed scripts for each component are listed as follows:

Breakdown of Scripts:

  1. Generation:
    python scripts/generate.py --model $MODEL --maxlen 2048 --output_dir $OUTPUT_DIR --prompts $PROMPTS

Main parameters:

  • model: Specifies the model used for generation. In the first iteration, the model should be either mistralai/Mistral-7B-Instruct-v0.2 or meta-llama/Meta-Llama-3-8B-Instruct.
  • maxlen: Sets the token length for generation, defining the maximum number of tokens generated.
  • pairs: Determines the number of generated samples per prompt, with a default setting of 5. Please note that changing this number is not supported by the overall pipeline.
  • output_dir: Specifies the directory paths for saving intermediate results.
  • prompts: Defines the set of prompts used for generation.
  • frac_len: Enables running vllm on multiple GPUs by dividing the prompts into fractions; frac_len sets the number of prompts in each fraction. For usage examples, see generate.sh and the sketch after this list.
  • data_frac: Used together with frac_len for multi-GPU setups; data_frac indicates which fraction of the data the current GPU processes. Refer to generate.sh for more details.
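For illustration, a two-GPU generation run might shard the prompts as follows. This is a hedged sketch, not the official script: the real orchestration lives in generate.sh, and the --frac_len value below is made up.

    # Hypothetical 2-GPU sharding; flag values are illustrative only (see generate.sh).
    CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --model mistralai/Mistral-7B-Instruct-v0.2 \
        --maxlen 2048 --output_dir $OUTPUT_DIR --prompts $PROMPTS \
        --frac_len 2500 --data_frac 0 &
    CUDA_VISIBLE_DEVICES=1 python scripts/generate.py --model mistralai/Mistral-7B-Instruct-v0.2 \
        --maxlen 2048 --output_dir $OUTPUT_DIR --prompts $PROMPTS \
        --frac_len 2500 --data_frac 1 &
    wait  # block until both shards finish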
  2. Ranking:
    python scripts/rank.py --output_dir $OUTPUT_DIR --prompts $PROMPTS

Main Parameters:

  • output_dir: Specifies the directory paths where intermediate results are saved. Note that the default script attempts to push datasets to Hugging Face under the UCLA-AGI organization. You may need to adjust this to your organization, obtain write access for UCLA-AGI, or disable the push_to_hub command if necessary.
  • pairs: Sets the number of generated samples per prompt, with a default of 5. Please note that other numbers are not supported by the overall pipeline.
  • frac_len: Enables running PairRM on multiple GPUs by dividing the prompts into fractions; frac_len determines the number of prompts in each fraction. For usage examples, see generate.sh and the sketch after this list.
  • data_frac: Used together with frac_len for multi-GPU ranking; it specifies which fraction of the data the current GPU processes. See generate.sh for examples.
  • prompts: Defines the set of prompts used for generation.
  • gpu: Indicates the GPU index used for ranking; it should match the data_frac parameter.
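As with generation, ranking can be sharded across GPUs. The sketch below is illustrative only; consult generate.sh for the actual commands and values.

    # Hypothetical 2-GPU PairRM ranking; flag values are illustrative only.
    python scripts/rank.py --output_dir $OUTPUT_DIR --prompts $PROMPTS \
        --frac_len 2500 --data_frac 0 --gpu 0 &
    python scripts/rank.py --output_dir $OUTPUT_DIR --prompts $PROMPTS \
        --frac_len 2500 --data_frac 1 --gpu 1 &
    wait  # block until both shards finish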
  3. Training:
    bash scripts/pipeline.sh --model $MODEL --iter $ITER --dataset $DATASET --output_dir $OUTPUT_DIR --num 1

Main Parameters:

  • model: The base model for training.
  • dataset: The dataset used for training.
  • output_dir: The name of the output model.
  • num: The number of training epochs. (An illustrative invocation is sketched below.)
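For instance, a first-iteration Mistral run might be invoked as follows. The dataset and output names here are hypothetical placeholders; the run_sppo_*.sh scripts wire up the real values for each iteration.

    # Hypothetical invocation: one training epoch for iteration 1 on Mistral.
    # $DATASET would be the synthetic preference data produced by generation + ranking.
    bash scripts/pipeline.sh --model mistralai/Mistral-7B-Instruct-v0.2 --iter 1 \
        --dataset $DATASET --output_dir Mistral-7B-SPPO-Iter1 --num 1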

Evaluation

We adhere to the established evaluation guidelines and use the corresponding evaluation repositories (AlpacaEval 2.0, the Open LLM Leaderboard, and MT-Bench).

We provide the model configurations used during AlpacaEval 2 evaluation in the models_configs directory. Please note that after the initial release of our model, we retrained it using a slightly modified prompt. The win rates observed after retraining are comparable to the original results.

Troubleshooting

For questions related to the paper, please contact the authors via email. If you encounter any issues with the code or wish to report a bug, feel free to open an issue on our GitHub repository.

Citation

@article{wu2024self,
  title={Self-play preference optimization for language model alignment},
  author={Wu, Yue and Sun, Zhiqing and Yuan, Huizhuo and Ji, Kaixuan and Yang, Yiming and Gu, Quanquan},
  year={2024}
}

Acknowledgements

We thank the authors of The Alignment Handbook for their foundational contributions to the training code. We also acknowledge the use of PairRM for ranking and vllm for generation.

sppo's People

Contributors

angelahzyuan, eltociear, sanowl, xiaohangt


sppo's Issues

Some packages' versions are too old

When using accelerate==0.23.0 as pinned in setup.py, I got the following error:
Accelerator.__init__() got an unexpected keyword argument 'use_seedable_sampler'
Upgrading accelerate to 0.31.0 fixed this error.

Is it normal that the pipeline starts with a huge loss?

step 10:
{'loss': 119743.8516, 'grad_norm': 938286.7284407256, 'learning_rate': 2.0161290322580643e-09, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -128.30323791503906, 'logps/chosen': -178.66146850585938, 'logits/rejected': -0.7681801915168762, 'logits/chosen': -0.792536735534668, 'epoch': 0.0}

step 20:
{'loss': 119688.3056, 'grad_norm': 1090985.982531398, 'learning_rate': 2.0161290322580644e-08, 'rewards/chosen': -8.749030530452728e-05, 'rewards/rejected': 0.00024323315301444381, 'rewards/accuracies': 0.2222222238779068, 'rewards/margins': -0.00033072344376705587, 'logps/rejected': -102.9691390991211, 'logps/chosen': -104.48147583007812, 'logits/rejected': -0.30933287739753723, 'logits/chosen': -0.3230978548526764, 'epoch': 0.0}

step 30:
{'loss': 122734.3, 'grad_norm': 677227.7630694123, 'learning_rate': 4.032258064516129e-08, 'rewards/chosen': -0.00015188578981906176, 'rewards/rejected': 3.675480911624618e-05, 'rewards/accuracies': 0.20000000298023224, 'rewards/margins': -0.00018864059529732913, 'logps/rejected': -132.24008178710938, 'logps/chosen': -116.12632751464844, 'logits/rejected': -0.4473434388637543, 'logits/chosen': -0.4207238554954529, 'epoch': 0.01}

I am surprised by such a huge loss; is this normal?

Is it possible to run Llama-3-70B and/or Mixtral 8x22B through this process?

I'm running the Llama-3-Instruct-8B-SPPO-Iter3 model locally and am very impressed by the improved quality from the original model. I can't help but wonder what the results would be if this finetuning process were run on larger models.

Is it possible to run the code on these larger models, or are the smaller versions too different from their larger counterparts, requiring a rework of the training scripts?

Thank you for what you have contributed, this is great stuff!

Questions about the training code

Thank you for sharing your code and making it open source.

However, I noticed that the training code still uses the DPO trainer, which seems inconsistent with the SPPO method in the paper. Can you explain why the code uses DPO instead of the proposed SPPO loss?

Which version of vllm should be installed

Hi, when I follow the default steps to set up the environment:
pip install vllm
it automatically installs vllm 0.5.0.post1, which requires transformers>=4.40.0.

When installing SPPO (which requires transformers==4.36.2), I got the following errors:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
vllm 0.5.0.post1 requires tokenizers>=0.19.1, but you have tokenizers 0.15.2 which is incompatible.
vllm 0.5.0.post1 requires transformers>=4.40.0, but you have transformers 4.36.2 which is incompatible.

Should I downgrade the vllm version or ignore this error? How can I fix it?


ShareGPT appending

How is the ShareGPT format handled in this workflow? I'm currently developing a dataset that could greatly benefit from this technique. However, I prefer not to train on the "User" and "Assistant" tokens; it goes against my intentions when working with language models. With Axolotl, there's a way to change the header IDs for ShareGPT datasets. I was wondering if there is something similar I could do here, or perhaps I could just do some data processing to change the format.

ConnectionError: Couldn't reach 'synthetic_data_llama-3-8b-instruct-sppo-iter3_score' on the Hub (ConnectionError)

Great work!
I commented out all the push_to_hub calls in the code. Is the synthetic_data_llama-3-8b-instruct-sppo-iter3_score dataset generated by PairRM?

[rank4]: Traceback (most recent call last):
[rank4]: File "/training-data/huangxing/software/SPPO/sppo/run_dpo.py", line 249, in
[rank4]: main()
[rank4]: File "/training-data/huangxing/software/SPPO/sppo/run_dpo.py", line 43, in main
[rank4]: main_inner(model_args, data_args, training_args)
[rank4]: File "/training-data/huangxing/software/SPPO/sppo/run_dpo.py", line 78, in main_inner
[rank4]: raw_datasets = get_datasets(data_args, splits=["train"])
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/training-data/huangxing/software/SPPO/sppo/alignment/data.py", line 164, in get_datasets
[rank4]: raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/training-data/huangxing/software/SPPO/sppo/alignment/data.py", line 189, in mix_datasets
[rank4]: dataset = load_dataset(ds, split=split)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/training-data/software/miniconda3/envs/mcts/lib/python3.11/site-packages/datasets/load.py", line 2129, in load_dataset
[rank4]: builder_instance = load_dataset_builder(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/training-data/software/miniconda3/envs/mcts/lib/python3.11/site-packages/datasets/load.py", line 1815, in load_dataset_builder
[rank4]: dataset_module = dataset_module_factory(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/training-data/software/miniconda3/envs/mcts/lib/python3.11/site-packages/datasets/load.py", line 1512, in dataset_module_factory
[rank4]: raise e1 from None
[rank4]: File "/training-data/software/miniconda3/envs/mcts/lib/python3.11/site-packages/datasets/load.py", line 1468, in dataset_module_factory
[rank4]: raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({type(e).__name__})")
[rank4]: ConnectionError: Couldn't reach 'synthetic_data_llama-3-8b-instruct-sppo-iter3_score' on the Hub (ConnectionError)


Adaptation for 4-bit Quantization Training/Responses Generation (with 2 Home GPUs)

Hello everyone!

For those who want to run this algorithm in your own home lab, I am working on adapting this code to run on my setup:

  • AMD Ryzen 9 7950X
  • 2x 32GB DDR5 @ 6000 MHz
  • 2x RTX 4090 24GB (48GB total)

If you're interested in testing and contributing, you can find my progress on my fork of this repository:
https://github.com/kaykyr/SPPO

Once I accomplish this objective, I intend to fully refactor this repository to allow both full-precision training on multi-GPUs and half-precision training on home GPUs with an easy-to-use script.

I appreciate all the contributions to the papers and the original code and will be even more grateful for any contributions from the community.

Thank you!

Ranking speed & training hyperparameters

I'm trying to replicate the Llama-3-8B setup with one of our custom finetunes and stumbled across some questions:

The ranking model gets called with a batch size of 1, and increasing that didn't seem to make the ranking any faster. With my current setup, the ranking takes longer than the actual training. Is there a way to speed up the ranking part of the pipeline?

You mention in the paper that you train for 18 epochs per iteration. Usually instruction tuning is done for a single epoch, since models can overfit on the data pretty quickly. Did you really train each iteration for 18 epochs, and didn't that lead to massive overfitting?

Could you provide some loss numbers for iterations 1/2/3, just so people have something to compare their runs against? The loss seems very high, but I'm not sure what the numbers are supposed to look like, since SPPO uses a custom loss function.

Overall, it's a pretty nice pipeline that you built with the iterative generation -> ranking -> training setup.

Scores and probability calculations

prb[i][j] = 1 / (1 + np.exp(score[j] - score[i]))

From my understanding of the code, the score list here is the output of blender.rank(*, return_scores=True), which should give the average relative score of the response at each index being better than the other responses. Please correct me if I'm wrong.

For example, given three responses {y1, y2, y3}, the first element of the scores output by the blender model (s1, s2, s3) is s1 = P(y1 > y2) + P(y1 > y3), disregarding the constant coefficient, where P is a general preference score function, not a probability. [references from the blender code and their paper]

Thus, subtracting two scores, i.e., s1 - s2, also depends on the third response y3, which seems a bit different from what is described in the paper.

In summary, I feel it would be more appropriate to use the score output from the blender with just two responses (although I don't think this would make a significant difference in performance), e.g.,

score = blender.rank([x], [[yj, yi]], return_scores=True)[0, 0]
prb[i][j] = 1 / (1 + np.exp(score))

(sorry for the badly coded example)

Any chance it works on my homelab?

Hello everyone,

I would like to know if these scripts are capable of running on a home lab setup with the following specifications:

  • 64GB RAM
  • 2x RTX 4090 GPUs

Thank you for your assistance.


I noticed that the scripts are configured to use 8 GPUs.

Dataset used and Gemma-2-9B results

Thanks for the great work.
I am so impressed with your research that I have tried it many times.
However, my results with Gemma-2-9B are very different from yours.

The Iter-3 score was even lower than the original Gemma-2-9B-it.

My question is: which dataset did you use,
UCLA-AGI/data-mistral-7b-instruct-sppo-iter[x]?

I am aware that these (or others) were based on UltraFeedback, and the GitHub code was the same.

Sincerely, Kazuya
