
reproduction-of-sherpherd-paper

Shepherd

(Shepherd logo, generated by DALL·E 3)

Introduction


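Automatic annotation of PRM data: whether a step is correct is judged by continuing generation from that step several times. For example, three completions are generated from a given step; if at least one of them reaches the correct answer, the step is labeled correct, otherwise it is labeled incorrect. The GSM8K and MATH datasets are currently supported.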
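The labeling rule above can be sketched in Python as follows. This is a minimal illustration, not the repository's code: the function names and the `naive_match` checker are hypothetical, and a real checker would use rule-based answer matching on the final answer. `soft_label_step` shows the soft variant (the fraction of rollouts that reach the answer), which the Math-Shepherd paper discusses alongside the hard 0/1 rule used here.

```python
from typing import Callable, List

def label_step(rollouts: List[str], gold: str,
               is_correct: Callable[[str, str], bool]) -> int:
    """Hard label: 1 if any rollout continued from this step reaches the gold answer."""
    return int(any(is_correct(r, gold) for r in rollouts))

def soft_label_step(rollouts: List[str], gold: str,
                    is_correct: Callable[[str, str], bool]) -> float:
    """Soft label: fraction of rollouts from this step that reach the gold answer."""
    if not rollouts:
        return 0.0
    return sum(is_correct(r, gold) for r in rollouts) / len(rollouts)

def naive_match(completion: str, gold: str) -> bool:
    # Hypothetical checker: compare the last whitespace-separated token with the gold answer.
    tokens = completion.strip().split()
    return bool(tokens) and tokens[-1].rstrip(".") == gold

# Three rollouts sampled from the same step prefix.
rollouts = ["... so the total is 18", "... which gives 20", "... the answer is 18."]
print(label_step(rollouts, "18", naive_match))       # 1
print(soft_label_step(rollouts, "18", naive_match))  # 0.6666666666666666
```

With three rollouts, a single lucky completion is enough to mark the step as correct, which is why the number of rollouts matters for label quality.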

Running SFT

bash ./Shepherd/SFT/scripts/run_sft.sh
  • In the GSM8K and MATH datasets, replace \n and \n\n with <step> so that vLLM can pause generation at that point.
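A minimal sketch of that preprocessing step, assuming solutions are plain strings; the helper name is hypothetical. The order of replacements matters: replacing "\n" first would turn every "\n\n" into two consecutive tags. Once the SFT model emits `<step>` between steps, generation can be halted there, for example by listing `"<step>"` in the `stop` strings of vLLM's `SamplingParams`.

```python
STEP_TAG = "<step>"

def insert_step_tags(solution: str, step_tag: str = STEP_TAG) -> str:
    # Replace the blank-line separator first so that "\n\n" yields a single tag.
    return solution.replace("\n\n", step_tag).replace("\n", step_tag)

raw = "Janet has 3 apples.\nShe buys 2 more.\n\nSo she has 5 apples.\n#### 5"
print(insert_step_tags(raw))
# Janet has 3 apples.<step>She buys 2 more.<step>So she has 5 apples.<step>#### 5
```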

Running PRM data generation

bash ./Shepherd/scripts/vllm_shepherd.sh
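Before rollouts are sampled, each generated solution has to be cut at its `<step>` markers. The sketch below shows one way to enumerate the prefixes from which continuations would be sampled; `step_prefixes` is a hypothetical helper and is not taken from `vllm_shepherd.sh`. In a vLLM setup, each prefix (appended to the question prompt) could be sent to the sampler with `n=3` in `SamplingParams` to obtain the three rollouts per step.

```python
from typing import List

STEP_TAG = "<step>"

def step_prefixes(solution: str, step_tag: str = STEP_TAG) -> List[str]:
    """Return every prefix of `solution` that ends right after a step tag.

    Each prefix is a point from which further completions would be sampled.
    Text after the last tag is not returned; it is assumed to contain the
    final answer, which can be checked directly.
    """
    prefixes: List[str] = []
    start = 0
    while (idx := solution.find(step_tag, start)) != -1:
        end = idx + len(step_tag)
        prefixes.append(solution[:end])
        start = end
    return prefixes

solution = "3 + 2 = 5<step>5 * 2 = 10<step>The answer is 10"
for p in step_prefixes(solution):
    print(repr(p))
# '3 + 2 = 5<step>'
# '3 + 2 = 5<step>5 * 2 = 10<step>'
```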

Running the ORM

bash ./Shepherd/scripts/run_orm_train.sh
bash ./Shepherd/scripts/run_orm_infer.sh
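  • Train for one epoch on GSM8K or MATH, then generate 15 responses per prompt on GSM8K and MATH and label every response 0 or 1 by rule-based answer matching. The reward model (RM) is then trained on the resulting dataset.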
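For GSM8K, rule-based matching can be as simple as comparing final numbers. The sketch below assumes responses end with either a GSM8K-style `#### <answer>` line or a final number; `extract_answer` and `label_response` are hypothetical names, not the repository's matcher. MATH answers (usually LaTeX inside `\boxed{...}`) would need a separate, more careful normalizer.

```python
import re
from typing import Optional

# A number such as "72", "-3", "1,000" or "2.5".
_NUM = r"-?[\d,]*\.?\d+"

def extract_answer(text: str) -> Optional[str]:
    """Return the final numeric answer in `text`, or None if there is none.

    A GSM8K-style "#### <answer>" marker is preferred; otherwise the last
    number in the text is used. Thousands separators are removed.
    """
    marked = re.search(r"####\s*(" + _NUM + r")", text)
    if marked:
        return marked.group(1).replace(",", "")
    numbers = re.findall(_NUM, text)
    return numbers[-1].replace(",", "") if numbers else None

def label_response(response: str, gold_solution: str) -> int:
    """Rule-based ORM label: 1 if the response's answer equals the gold answer, else 0."""
    pred, gold = extract_answer(response), extract_answer(gold_solution)
    if pred is None or gold is None:
        return 0
    return int(abs(float(pred) - float(gold)) < 1e-6)

gold = "Natalia sold 48/2 = 24 clips in May.\nNatalia sold 48+24 = 72 clips altogether.\n#### 72"
print(label_response("She sold 48 + 24 = 72 clips in total.", gold))  # 1
print(label_response("The answer is 70.", gold))                      # 0
print(label_response("#### 1,000", "#### 1000"))                      # 1
```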

Running the PRM

bash ./Shepherd/scripts/run_prm_train.sh
bash ./Shepherd/scripts/run_prm_infer.sh
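  • As with the ORM, train for one epoch on GSM8K or MATH. At inference, each time generation for a prompt reaches a <step> marker, three further completions are generated from that point, and rule-based matching checks whether any of them reaches the correct answer. If at least one does, that <step> is labeled 1; otherwise it is labeled 0.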
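For PRM training, these 0/1 step labels have to be attached to the right positions in the tokenized solution. One common setup, assumed here rather than taken from `run_prm_train.sh`, supervises only the `<step>` token positions and ignores all other tokens. The sketch also assumes `<step>` was added to the tokenizer as a single token with id `step_tag_id`; the helper name is hypothetical.

```python
from typing import List

# Default ignore_index of torch.nn.CrossEntropyLoss; positions with this label get no loss.
IGNORE_INDEX = -100

def prm_token_labels(input_ids: List[int], step_tag_id: int,
                     step_labels: List[int]) -> List[int]:
    """Place each step's 0/1 label on its step-tag token; all other positions are ignored."""
    labels = [IGNORE_INDEX] * len(input_ids)
    tag_positions = [i for i, tok in enumerate(input_ids) if tok == step_tag_id]
    if len(tag_positions) != len(step_labels):
        raise ValueError(f"{len(tag_positions)} step tags but {len(step_labels)} labels")
    for pos, lab in zip(tag_positions, step_labels):
        labels[pos] = lab
    return labels

ids = [11, 12, 99, 13, 14, 99]  # toy token ids; 99 stands in for <step>
print(prm_token_labels(ids, step_tag_id=99, step_labels=[1, 0]))
# [-100, -100, 1, -100, -100, 0]
```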

Summary of experimental results

Reranking results on MATH and GSM8K


  • Reranking results on GSM8K and MATH
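Reranking here means selecting one of the N responses sampled per prompt (best-of-N) by reward-model score and checking whether the selected one is correct. A minimal sketch of that selection follows. The per-step aggregation for the PRM is an assumption: taking the minimum step score follows the Math-Shepherd paper, and the product of step scores is another common choice. Function names are hypothetical and the scores are toy values, not results.

```python
from typing import Sequence

def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def orm_pick(outcome_scores: Sequence[float]) -> int:
    """ORM reranking: keep the response with the highest solution-level score."""
    return argmax(outcome_scores)

def prm_pick(step_scores: Sequence[Sequence[float]]) -> int:
    """PRM reranking: score each response by its weakest step, then keep the best one."""
    return argmax([min(steps) for steps in step_scores])

def best_of_n_accuracy(picks: Sequence[int], labels: Sequence[Sequence[int]]) -> float:
    """Fraction of prompts whose selected response has rule-based label 1."""
    return sum(labels[q][p] for q, p in enumerate(picks)) / len(picks)

# Two prompts, three responses each; toy per-step PRM scores.
prm_scores = [
    [[0.9, 0.2], [0.7, 0.6], [0.8, 0.9]],
    [[0.5], [0.95], [0.3, 0.1]],
]
labels = [[0, 0, 1], [1, 0, 0]]
picks = [prm_pick(s) for s in prm_scores]
print(picks)                              # [2, 1]
print(best_of_n_accuracy(picks, labels))  # 0.5
```

Taking the minimum means one low-scoring step sinks the whole solution, which matches the intuition that a single wrong step usually invalidates the final answer.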

Reranking results on MetaMath


  • Reranking results on MetaMath

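So far, the reproduced PRM does not outperform the ORM; however, a PRM trained on the dataset released with the paper slightly outperforms the ORM. Possible improvements: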

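  1. Increase the number of rollouts (currently 3). More rollouts per step should give a more accurate estimate of whether the step can lead to the correct answer.
  2. Increase the number of responses per prompt. The gain from this is probably small, since a sufficiently large amount of data is already used.
  3. Label every step before the last correct step as 1, which would be more consistent with how human annotators label steps.
  4. Improve the generation procedure: organize the generations as a tree and label all steps by traversal and backtracking, which would make annotation much more efficient.
  5. Following the Core approach, run Monte Carlo tree search, train token-level and solution-level RMs to score all reasoning paths, and finally label every step by thresholding the scores.
  6. Add an API implementation.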
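Item 3 follows from the labeling rule itself: if a rollout from step k reaches the correct answer, that same continuation also extends every earlier prefix, so under the "can this step still reach the correct answer" definition a 0 on an earlier step only reflects the limited number of rollouts. A minimal sketch of the relabeling, with a hypothetical helper name:

```python
from typing import List

def fill_before_last_correct(step_labels: List[int]) -> List[int]:
    """Set every step up to the last step labeled 1 to 1; later steps stay 0."""
    last_correct = max((i for i, lab in enumerate(step_labels) if lab == 1), default=-1)
    return [1 if i <= last_correct else lab for i, lab in enumerate(step_labels)]

print(fill_before_last_correct([1, 0, 1, 0, 0]))  # [1, 1, 1, 0, 0]
print(fill_before_last_correct([0, 0, 0]))        # [0, 0, 0]
```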

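In my view, the step labels produced by this method are probably less accurate than step labels annotated by GPT-4, but the reproduction is not yet complete, so no comparison has been made so far.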

Both the ORM and the PRM were obtained from the GSM8K and MATH datasets (not from responses generated on MetaMath).

Learning curves

(Figure: GSM8K ORM learning curve)

(Figure: MATH ORM learning curve)

(Figure: GSM8K PRM learning curve)



Contributors

yhc582825016
