
rlhf-reward-modeling's People

Contributors

guhfeng, haoxiang-wang, hendrydong, violet24k, weixiongust, yangrui2015


rlhf-reward-modeling's Issues

How do you implement SLiC on the pair_pm model?

Hi, thanks for uploading the code for pair_pm! The blog suggests that you use SLiC for the pair_pm models, but I can't find any code for the SLiC method in the pair_pm directory.
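For reference, a minimal sketch of a SLiC-style pairwise loss, assuming the model exposes scalar preference scores for the chosen and rejected responses (the function name and the margin value are illustrative, not taken from the repository):

import torch
import torch.nn.functional as F

def slic_pairwise_loss(score_chosen: torch.Tensor,
                       score_rejected: torch.Tensor,
                       margin: float = 1.0) -> torch.Tensor:
    # SLiC-style hinge (calibration) loss: penalize pairs whose
    # chosen-minus-rejected margin falls below `margin`.
    return F.relu(margin - (score_chosen - score_rejected)).mean()

def bt_pairwise_loss(score_chosen: torch.Tensor,
                     score_rejected: torch.Tensor) -> torch.Tensor:
    # Bradley-Terry logistic loss, shown for comparison.
    return -F.logsigmoid(score_chosen - score_rejected).mean()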

Low Safety Score for RM-Gemma-2B Model

Hi, I am trying to reproduce the evaluation results from eval_reward_bench_bt.py, using weqweasdas/RM-Gemma-2B as the checkpoint passed to the reward_name_or_path argument. The script runs fine, but the Safety result does not match Table 2: my Safety score is 42.74, which is very different from the reported 81.2.

Here are the complete evaluation results:

      category                 subset  accuracy      n
0        chat        alpacaeval-easy  0.960000  100.0
1        chat      alpacaeval-length  0.926316   95.0
2        chat        alpacaeval-hard  0.978947   95.0
3        chat          mt-bench-easy  0.964286   28.0
4        chat           mt-bench-med  0.900000   40.0
5   chat-hard          mt-bench-hard  0.675676   37.0
6   chat-hard         llmbar-natural  0.750000  100.0
7   chat-hard  llmbar-adver-neighbor  0.358209  134.0
8   chat-hard   llmbar-adver-GPTInst  0.152174   92.0
9   chat-hard    llmbar-adver-GPTOut  0.382979   47.0
10  chat-hard    llmbar-adver-manual  0.173913   46.0
11     safety     refusals-dangerous  0.070000  100.0
12     safety     refusals-offensive  0.240000  100.0
13     safety   xstest-should-refuse  0.422078  154.0
14     safety  xstest-should-respond  0.940000  250.0
15     safety            donotanswer  0.257353  136.0
16  reasoning               math-prm  0.729306  447.0
17  reasoning                hep-cpp  0.804878  164.0
18  reasoning                 hep-go  0.810976  164.0
19  reasoning               hep-java  0.807927  164.0
20  reasoning                 hep-js  0.844512  164.0
21  reasoning             hep-python  0.795732  164.0
22  reasoning               hep-rust  0.804878  164.0
model: models/RM-Gemma-2B
Chat: 94.97
Chat Hard: 41.23
Safety: 42.74
Reasoning: 77.04

P.S. The evaluation results for FsfairX-LLaMA3-RM-v0.1 look fine.

I am wondering whether this is the right checkpoint for evaluation? Thanks in advance!
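For reference, a minimal sanity check one could run on the checkpoint named in the issue, assuming it loads as a sequence-classification reward model (the conversation content is an arbitrary illustration, not from the benchmark):

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

name = "weqweasdas/RM-Gemma-2B"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(
    name, num_labels=1, torch_dtype=torch.bfloat16
).eval()

# Score one conversation and confirm the model returns a sensible scalar reward.
chat = [{"role": "user", "content": "Tell me how to hotwire a car."},
        {"role": "assistant", "content": "Sorry, I can't help with that."}]
text = tokenizer.apply_chat_template(chat, tokenize=False).replace(tokenizer.bos_token, "")
with torch.no_grad():
    reward = model(**tokenizer(text, return_tensors="pt")).logits[0, 0].item()
print(reward)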

Training and evaluating the pair_pm model

Hi,

I have replicated the training and evaluation for the pair_pm model, but I haven't achieved the results reported in Table 2 of the paper. The best results I obtained were with pm_models/llama3-8b-it_bs128_lr1e-5/checkpoint-1306:

Chat: 63.55
Chat Hard: 63.27
Safety: 82.59
Reasoning: 53.53
The main difference I've noticed in your script is that base_model in pair_pm/llama3-8b-it.yaml is /home/wx/axtool/models/llama3_it_with_padding_token, but I couldn't find this model on Hugging Face or anywhere else, so I trained pair_pm with meta-llama/Meta-Llama-3-8B-Instruct instead.

Another difference is in eval_reward_bench_pm.py: you use /home/cyeab/axtool/models/llama3_it_427_update for tokenizer and tokenizer_plain, whereas I used meta-llama/Meta-Llama-3-8B-Instruct.

Could you please share the llama3_it_with_padding_token and llama3_it_427_update models with me? Additionally, could you provide details on how you trained them?

Thank you!
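One guess, purely as a sketch: llama3_it_with_padding_token may just be Meta-Llama-3-8B-Instruct with a pad token added, since the base model ships without one. Whether the embeddings were resized and which pad token was used are assumptions here:

from transformers import AutoTokenizer, AutoModelForCausalLM

name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

# Hypothetical preprocessing: add a pad token and resize the embeddings.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id

model.save_pretrained("llama3_it_with_padding_token")
tokenizer.save_pretrained("llama3_it_with_padding_token")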

Question about chat templates

Nice work! Starred already.
Sorry for asking, but why is the bos_token replaced with an empty string?

sample['positive'] = tokenizer.apply_chat_template(
    sample['chosen'], tokenize=False, add_generation_prompt=False).replace(tokenizer.bos_token, "")
sample['negative'] = tokenizer.apply_chat_template(
    sample['rejected'], tokenize=False, add_generation_prompt=False).replace(tokenizer.bos_token, "")
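A likely reason, shown as a small check (assuming a Llama-3-style tokenizer with its default settings): apply_chat_template with tokenize=False already puts the BOS token in the rendered string, and tokenizing that string later prepends BOS again, so stripping it avoids a duplicated BOS.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
chat = [{"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"}]

text = tokenizer.apply_chat_template(chat, tokenize=False)
print(text.startswith(tokenizer.bos_token))  # True: BOS is already in the string
ids = tokenizer(text)["input_ids"]
print(ids[:2])  # without the .replace(...), the first two ids are both bos_token_id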

Cannot run the training script

Thank you for sharing your scripts for training reward models based on mistral-7b and gemma-2b! May I ask which versions of transformers and accelerate you are using? I followed the instructions and built the environment from alignment-handbook. However, with their latest versions (transformers==4.39.0 and accelerate==0.27.2), I keep getting the following error when using deepspeed:

ValueError: DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`.

But the training script does not set device_map, so it should be None.

And when running the script without deepspeed, I got another error:

File "/p/anaconda3/envs/reward-model/lib/python3.10/site-packages/accelerate/utils/dataclasses.py", line 1061, in set_auto_wrap_policy
    raise Exception("Could not find the transformer layer class to wrap in the model.")

Both errors go away once I switch transformers back to 4.36.2 and accelerate back to 0.23.0 (the versions used by alignment-handbook before Mar 1st, 2024), but then transformers doesn't support gemma...
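For what it's worth, a minimal sketch of a model-loading call that stays compatible with DeepSpeed ZeRO-3: the key constraint is to pass neither device_map nor low_cpu_mem_usage, so DeepSpeed can partition the weights itself (the base model name here is only an illustration):

import torch
from transformers import AutoModelForSequenceClassification

# No device_map and no low_cpu_mem_usage: both conflict with ZeRO-3 partitioning.
model = AutoModelForSequenceClassification.from_pretrained(
    "google/gemma-2b-it",   # assumed base model, for illustration only
    num_labels=1,
    torch_dtype=torch.bfloat16,
)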

Bradley-Terry model removes lm head while saving

Hello and thanks for your work!

When running bradley-terry-rm/llama3_rm.py, the final saved model does not have an lm_head, since the script uses an AutoModelForSequenceClassification model rather than a CausalLM. Because of this, the lm_head is initialized from scratch when loading the saved model. Is this correct, or do you manually add the lm_head weights?
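As a sketch of the intended usage (the checkpoint path is hypothetical): a Bradley-Terry reward model built with AutoModelForSequenceClassification has a scalar classification head instead of an lm_head, and the "newly initialized" warning only appears if the checkpoint is later loaded as a causal LM. Loading it back as a sequence classifier uses the saved head directly:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

ckpt = "./bt_rm_checkpoint"  # hypothetical path to the saved reward model
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSequenceClassification.from_pretrained(
    ckpt, num_labels=1, torch_dtype=torch.bfloat16
).eval()

# The reward is the single classification logit; no lm_head is involved.
inputs = tokenizer("Hello there!", return_tensors="pt")
with torch.no_grad():
    print(model(**inputs).logits)  # shape (1, 1): the scalar reward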

Question about the output

I'm using your llama3 script with DeepSpeed ZeRO-3 to train a reward model, and I set save_every_steps to 20.

However, I'm confused about the files in the output. For example, in checkpoint-40 there are 4 safetensors shards: model-00001-of-00004.safetensors, model-00002-of-00004.safetensors, model-00003-of-00004.safetensors, and model-00004-of-00004.safetensors. Are these the fine-tuned checkpoint?

If so, what is in global_step40? There are some .pt files, including bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt through bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt, and zero_pp_rank_0_mp_rank_00_model_states.pt through zero_pp_rank_3_mp_rank_00_model_states.pt.

Can I just delete those .pt files, since they take up a huge amount of space?
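For context: the sharded safetensors are the consolidated model weights, while the global_stepNN folder holds DeepSpeed's per-rank optimizer and model partition states, which are only needed for resuming training. Before deleting them, one way to double-check that the full weights are recoverable is DeepSpeed's zero_to_fp32 helper (the checkpoint path here is hypothetical):

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Rebuild a full fp32 state dict from the ZeRO partition files, as a sanity check
# that nothing in global_step40 is still needed for inference.
state_dict = get_fp32_state_dict_from_zero_checkpoint("output/checkpoint-40")
print(len(state_dict), "tensors recovered")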

KeyError: 'input_ids_j' in training

Thanks for your wonderful project!

I got an error during training with a custom dataset:

 11%|██████████████                          | 125/1125 [1:04:25<8:30:35, 30.64s/it]
Traceback (most recent call last):
    "input_ids": feature["input_ids_j"],
KeyError: 'input_ids_j'

There are no empty values in my chosen and rejected data, so I don't understand why this happens.
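Judging from the key name in the error, the collator expects tokenized columns named input_ids_j / attention_mask_j (chosen) and input_ids_k / attention_mask_k (rejected); if a custom dataset skips that preprocessing step or keeps different column names, this KeyError appears. A minimal sketch of such a map function, with names inferred from the error rather than taken from the repository:

def tokenize_pair(sample, tokenizer, max_length=1024):
    # Render both sides of the pair and tokenize them into the columns
    # the pairwise collator looks up.
    chosen = tokenizer.apply_chat_template(sample["chosen"], tokenize=False)
    rejected = tokenizer.apply_chat_template(sample["rejected"], tokenize=False)
    tok_j = tokenizer(chosen, truncation=True, max_length=max_length)
    tok_k = tokenizer(rejected, truncation=True, max_length=max_length)
    return {
        "input_ids_j": tok_j["input_ids"],
        "attention_mask_j": tok_j["attention_mask"],
        "input_ids_k": tok_k["input_ids"],
        "attention_mask_k": tok_k["attention_mask"],
    }

# dataset = dataset.map(tokenize_pair, fn_kwargs={"tokenizer": tokenizer})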

How to construct new pairs to add to the dataset

Thanks for your great work! I have some questions about how to add new pairs to the training set.
According to your paper:

Instead of fixing ... , we take $\pi_t^1$ and $\pi_t^2$ as the best-of-8 policy and worst-of-8 policy induced by $\pi_t^{MLE}$

In my view, the normal process is as follows:
1. Suppose we have a DPO model after training in the $t$-th iteration, $\pi_t^{MLE}$.
2. We use different temperatures (0.7 for $\pi_t^1$ and 1.0 for $\pi_t^2$, to allow more exploration).
3. Then, we sample best-of-8 from $\pi_t^2$ and rank the samples using the reward model.
4. Finally, we use the top-1 of this set and the generation from $\pi_t^1$ as $(a_{t,i}^1, a_{t,i}^2)$.

But how do I use best-of-8 and worst-of-8 to construct pairs like that?
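One plausible reading, as a sketch only (policy_generate and reward_fn are hypothetical helpers, not from the repository): sample n = 8 completions per prompt from the current policy, score them with the reward model, and take the argmax as chosen and the argmin as rejected.

import torch

def build_pair(prompt, policy_generate, reward_fn, n=8):
    # policy_generate(prompt, n) -> list of n sampled responses (assumed helper)
    # reward_fn(prompt, response) -> scalar reward            (assumed helper)
    responses = policy_generate(prompt, n)
    rewards = torch.tensor([reward_fn(prompt, r) for r in responses])
    chosen = responses[rewards.argmax().item()]    # best-of-n response
    rejected = responses[rewards.argmin().item()]  # worst-of-n response
    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}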

Cannot understand the code in the README.md of pair-pm

Hi, I'm very confused by your code.

logit_A = output.logits[0, -1, token_id_A].item()
logit_B = output.logits[0, -1, token_id_B].item()

I guess the last dimension of output.logits is the same as the vocab_size.
But why do the elements at index token_id_A and token_id_B matter? (They are 32 and 33 respectively, which seems arbitrary to me.)

Also, I found that avg_prob_chosen (= np.mean(probs_chosen)) is exactly 0.5. How can this be?
Since the positions of the two responses are swapped as we iterate through for chosen_position in [0, 1]:, the sum of the two probs should be very unlikely to equal 1, but it is...
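For context, a sketch of how the two-ordering comparison presumably works (prompt_for is a hypothetical helper that renders the comparison prompt with the two responses labelled A and B; token_id_A / token_id_B are the vocabulary ids of those labels). Note that if the model's A/B logits depend only on position and not on content, the two orderings give p and 1 - p, and their mean collapses to exactly 0.5, which is one way the observed value could arise.

import numpy as np
import torch

def prob_chosen(model, tokenizer, prompt_for, token_id_A, token_id_B,
                chosen, rejected):
    probs = []
    for chosen_position in [0, 1]:
        # Put the chosen response in slot A, then in slot B, and re-score.
        responses = [chosen, rejected] if chosen_position == 0 else [rejected, chosen]
        input_ids = tokenizer(prompt_for(responses), return_tensors="pt").input_ids
        with torch.no_grad():
            logits = model(input_ids).logits
        logit_A = logits[0, -1, token_id_A].item()
        logit_B = logits[0, -1, token_id_B].item()
        # Two-way softmax over the label logits gives P(A) and P(B).
        pA, pB = torch.softmax(torch.tensor([logit_A, logit_B]), dim=0).tolist()
        probs.append(pA if chosen_position == 0 else pB)
    return np.mean(probs)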
