Comments (13)

sunzeyeah commented on June 10, 2024

The RLHF code has been updated to an implementation based on DeepSpeedChat; see train_rlhf.sh for an example. Just replace SFT_MODEL and REWARD_MODEL in it with the corresponding models.

However, running ChatGLM-6B RLHF on multiple V100 16G cards still OOMs; you can verify it on GPUs with more memory.

GUORUIWANG commented on June 10, 2024

The RLHF code has been updated to an implementation based on DeepSpeedChat; see train_rlhf.sh for an example. Just replace SFT_MODEL and REWARD_MODEL in it with the corresponding models.

However, running ChatGLM-6B RLHF on multiple V100 16G cards still OOMs; you can verify it on GPUs with more memory.

Thanks for the reply. When I try train_rlhf.sh, I hit: assert not args.enable_hybrid_engine, "DeepSpeed currently does not support Pangu-based or GLM-based model in hybrid engine"
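A note on that assert: if --enable_hybrid_engine is a plain argparse store_true switch, as it is in DeepSpeedChat-style scripts, it defaults to False when omitted, so simply dropping the flag from the launch command avoids the assert without editing code; changing the action to store_false instead inverts the default for every caller. A minimal sketch of the assumed flag definition:

    import argparse

    parser = argparse.ArgumentParser()
    # A store_true switch defaults to False; it becomes True only when the
    # flag is actually passed on the command line.
    parser.add_argument("--enable_hybrid_engine", action="store_true")

    assert parser.parse_args([]).enable_hybrid_engine is False                          # flag omitted
    assert parser.parse_args(["--enable_hybrid_engine"]).enable_hybrid_engine is True   # flag passed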

After changing the enable_hybrid_engine argument's action to store_false, I then get:
│ /mnt/data/anaconda3/envs/coati2/lib/python3.9/site-packages/deepspeed/ │
│ runtime/config_utils.py:62 in __init__ │
│ │
│ 59 │ │ │ │ for k, │
│ 60 │ │ │ │ v in data.items() if (v != "auto" or k == "replace_met │
│ 61 │ │ │ } │
│ ❱ 62 │ │ super().__init__(**data) │
│ 63 │ │ self._deprecated_fields_check(self) │
│ 64 │ │
│ 65 │ def _process_deprecated_field(self, pydantic_config, field): │
│ │
│ in pydantic.main.BaseModel.__init__:341 │
╰──────────────────────────────────────────────────────────────────────────────╯
ValidationError: 1 validation error for DeepSpeedZeroConfig
memory_efficient_linear
extra fields not permitted (type=value_error.extra)

My sh file is:
#!/bin/bash

MODEL="chatglm-6B"
ROOT="/mnt/data"

DATA_DIR=$ROOT/LLM/RLHF_v1/data
MAIN=$ROOT/LLM/RLHF/src/train_rlhf.py
TOKENIZER_PATH=$ROOT/LLM/model/chatglm
ACTOR_MODEL_PATH=$ROOT/LLM/model/chatglm
CRITIC_MODEL_PATH=$ROOT/LLM/model/chatglm
#CRITIC_MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$REWARD_MODEL
CRITIC_CHECKPOINT=$ROOT/LLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-ft-31-1e-5/checkpoint-10000/pytorch_model.bin
OUTPUT_DIR=$ROOT/LLM/RLHF/output
TRAIN_FILENAME="dev_data_external_v1.jsonl"
PRETRAIN_FILENAME="dev_data_external_v1.jsonl"

#python $MAIN
#accelerate launch --main_process_port 5007 --config_file $ACCELERATE_CONFIG $MAIN
deepspeed --num_gpus 8 $MAIN \
    --data_dir $DATA_DIR \
    --output_dir $OUTPUT_DIR \
    --tokenizer_path $TOKENIZER_PATH \
    --actor_model_path $ACTOR_MODEL_PATH \
    --critic_model_path $CRITIC_MODEL_PATH \
    --critic_checkpoint $CRITIC_CHECKPOINT \
    --max_length 512 \
    --max_gen_length 256 \
    --logging_steps 10 \
    --do_train \
    --train_filename $TRAIN_FILENAME \
    --pretrain_filename $PRETRAIN_FILENAME \
    --actor_learning_rate 1e-5 \
    --critic_learning_rate 1e-5 \
    --lr_scheduler_type cosine \
    --train_batch_size 4 \
    --ppo_train_batch_size 4 \
    --gradient_accumulation_steps 16 \
    --num_epochs 1 \
    --ppo_epochs 1 \
    --enable_hybrid_engine \
    --actor_zero_stage 3 \
    --critic_zero_stage 3 \
    --offload_reference_model \
    --actor_gradient_checkpointing \
    --critic_gradient_checkpointing \
    --release_inference_cache

sunzeyeah commented on June 10, 2024

This has been verified on deepspeed==0.9.1. If your deepspeed version is too old, you can comment out the memory_efficient_linear line.
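For readers hitting the same ValidationError: a minimal sketch of the version guard this advice implies; build_zero_config and the surrounding keys are illustrative, not the repo's actual code.

    from packaging import version
    import deepspeed

    def build_zero_config(stage: int = 3) -> dict:
        # Older DeepSpeed rejects memory_efficient_linear as an "extra field"
        # in DeepSpeedZeroConfig, so only emit it on versions known to accept it.
        zero_config = {"stage": stage}
        if version.parse(deepspeed.__version__) >= version.parse("0.9.1"):
            zero_config["memory_efficient_linear"] = False
        return zero_config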

GUORUIWANG commented on June 10, 2024

This has been verified on deepspeed==0.9.1. If your deepspeed version is too old, you can comment out the memory_efficient_linear line.

I upgraded deepspeed, but now I hit this problem when running:
│ /mnt/data/anaconda3/envs/coati2/lib/python3.9/site-packages/deepspeed/ │
│ runtime/zero/partition_parameters.py:1164 in partition_param │
│ │
│ 1161 │ │ │ if start < param.ds_numel and end <= param.ds_numel: │
│ 1162 │ │ │ │ src_tensor = one_dim_param.narrow(0, start, partition │
│ 1163 │ │ │ │ │
│ ❱ 1164 │ │ │ │ param.ds_tensor.copy_(src_tensor) │
│ 1165 │ │ │ │ #partitioned_tensor = src_tensor.clone().detach().to( │
│ 1166 │ │ │ │
│ 1167 │ │ │ else: │
╰──────────────────────────────────────────────────────────────────────────────╯
NotImplementedError: Cannot copy out of meta tensor; no data!
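This error usually means some parameters were never materialized, for example a model loaded with from_pretrained(..., low_cpu_mem_usage=True) or constructed under deepspeed.zero.Init without real weights, so ZeRO-3 has no data to copy when partitioning. A hypothetical diagnostic to locate such parameters, not the fix itself:

    import torch

    def report_meta_params(model: torch.nn.Module) -> None:
        # Parameters on the "meta" device have no storage, so ZeRO-3 cannot
        # copy them during partitioning; list any that are still unmaterialized.
        meta = [name for name, p in model.named_parameters() if p.is_meta]
        if meta:
            print(f"{len(meta)} meta parameters, e.g. {meta[:5]}")
        else:
            print("all parameters are materialized")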

taofennanhai commented on June 10, 2024

From the code it looks like only ChatGLM models can run and GLM models don't work at all. Is that right?

This has been verified on deepspeed==0.9.1. If your deepspeed version is too old, you can comment out the memory_efficient_linear line.

GUORUIWANG commented on June 10, 2024

The RLHF code has been updated to an implementation based on DeepSpeedChat; see train_rlhf.sh for an example. Just replace SFT_MODEL and REWARD_MODEL in it with the corresponding models.

However, running ChatGLM-6B RLHF on multiple V100 16G cards still OOMs; you can verify it on GPUs with more memory.

Hello, I'd like to ask: 1) what kind of environment is needed to run the RLHF pipeline with both SFT_MODEL and REWARD_MODEL set to ChatGLM-6B; 2) with SFT_MODEL as ChatGLM-6B and REWARD_MODEL as glm-350m, I hit OOM at ZeRO stage 3 on 4x V100 32G; 3) could you share the environment configuration you used for training?

GUORUIWANG commented on June 10, 2024

Also, when I tried glm-350m for both SFT_MODEL and REWARD_MODEL I ran into a lot of bizarre errors. Which (SFT_MODEL, REWARD_MODEL) combinations have you actually run? It looks like the code needs adjusting for each combination. Could you share working code for the combination SFT_MODEL = ChatGLM-6B, REWARD_MODEL = glm-350m?

When testing with SFT_MODEL and REWARD_MODEL both set to glm-350m, I get:

│ /mnt/data/LLM/RLHF/src/train_rlhf.py:276 in main │
│ │
│ 273 │ │ │ │ # prompts = prompts[:, length - args.max_prompt_le │
│ 274 │ │ │ │ # raise ValueError("Prompt length is too long") │
│ 275 │ │ │ │ │
│ ❱ 276 │ │ │ │ out = trainer.generate_experience(batch_prompt) │
│ 277 │ │ │ │ exp_dataset = exp_mini_dataset.add(out) │
│ 278 │ │ │ │ │
│ 279 │ │ │ │ if exp_dataset is not None: │
│ │
│ /mnt/data/LLM/RLHF/src/models/trainer.py:1411 in generate_experience │
│ │
│ 1408 │ │ return outputs │
│ 1409 │ │
│ 1410 │ def generate_experience(self, inputs): │
│ ❱ 1411 │ │ self.eval() │
│ 1412 │ │ outputs = self._generate_sequence(inputs) │
│ 1413 │ │ self.train()

│ /mnt/data/LLM/RLHF/src/models/trainer.py:1568 in eval │
│ │
│ 1565 │ │ self.critic_model.train() │
│ 1566 │ │
│ 1567 │ def eval(self): │
│ ❱ 1568 │ │ self.actor_model.eval() │
│ 1569 │ │ self.critic_model.eval() │
│ 1570 │ │ self.reward_model.eval() │
│ 1571 │ │ self.ref_model.eval()

│ /home/admin/anaconda3/envs/chatglm/lib/python3.9/site-packages/deepspeed/run │
│ time/hybrid_engine.py:379 in eval │
│ │
│ 376 │ │ │ │ │ f'|Generate time={(self._generate_latency):.2f}s │
│ 377 │ │ │ │ │ f'|Training time={(self._training_latency):.2f}s │
│ 378 │ │ │ │ │ f'|Others={others:.2f} ({(others / latency * 100 │
│ ❱ 379 │ │ │ │ │ f'|CurSamplesPerSec={(1 / latency * self.total
│ 380 │ │ │ │ │ f'|AvgSamplesPerSec={(1 / (self._total_latency / │
│ 381 │ │ │ self._t_start = time.time() │
│ 382 │ │ self._training_latency = 0 │
╰──────────────────────────────────────────────────────────────────────────────╯

The deepspeed version is 0.9.2.

GUORUIWANG commented on June 10, 2024

From the code it looks like only ChatGLM models can run and GLM models don't work at all. Is that right?

This has been verified on deepspeed==0.9.1. If your deepspeed version is too old, you can comment out the memory_efficient_linear line.

For GLM: I tried glm-350m for both SFT_MODEL and REWARD_MODEL and hit the problem above; the code needs modification.

MAJIN123 commented on June 10, 2024

Also, when I tried glm-350m for both SFT_MODEL and REWARD_MODEL I ran into a lot of bizarre errors. Which (SFT_MODEL, REWARD_MODEL) combinations have you actually run? It looks like the code needs adjusting for each combination. Could you share working code for the combination SFT_MODEL = ChatGLM-6B, REWARD_MODEL = glm-350m?

When testing with SFT_MODEL and REWARD_MODEL both set to glm-350m, I get:

│ /mnt/data/LLM/RLHF/src/train_rlhf.py:276 in main │
│ │
│ 273 │ │ │ │ # prompts = prompts[:, length - args.max_prompt_le │
│ 274 │ │ │ │ # raise ValueError("Prompt length is too long") │
│ 275 │ │ │ │ │
│ ❱ 276 │ │ │ │ out = trainer.generate_experience(batch_prompt) │
│ 277 │ │ │ │ exp_dataset = exp_mini_dataset.add(out) │
│ 278 │ │ │ │ │
│ 279 │ │ │ │ if exp_dataset is not None: │
│ │
│ /mnt/data/LLM/RLHF/src/models/trainer.py:1411 in generate_experience │
│ │
│ 1408 │ │ return outputs │
│ 1409 │ │
│ 1410 │ def generate_experience(self, inputs): │
│ ❱ 1411 │ │ self.eval() │
│ 1412 │ │ outputs = self._generate_sequence(inputs) │
│ 1413 │ │ self.train()

│ /mnt/data/LLM/RLHF/src/models/trainer.py:1568 in eval │
│ │
│ 1565 │ │ self.critic_model.train() │
│ 1566 │ │
│ 1567 │ def eval(self): │
│ ❱ 1568 │ │ self.actor_model.eval() │
│ 1569 │ │ self.critic_model.eval() │
│ 1570 │ │ self.reward_model.eval() │
│ 1571 │ │ self.ref_model.eval()

│ /home/admin/anaconda3/envs/chatglm/lib/python3.9/site-packages/deepspeed/run │
│ time/hybrid_engine.py:379 in eval │
│ │
│ 376 │ │ │ │ │ f'|Generate time={(self._generate_latency):.2f}s │
│ 377 │ │ │ │ │ f'|Training time={(self._training_latency):.2f}s │
│ 378 │ │ │ │ │ f'|Others={others:.2f} ({(others / latency * 100 │
│ ❱ 379 │ │ │ │ │ f'|CurSamplesPerSec={(1 / latency * self.total │
│ 380 │ │ │ │ │ f'|AvgSamplesPerSec={(1 / (self._total_latency / │
│ 381 │ │ │ self._t_start = time.time() │
│ 382 │ │ self._training_latency = 0 │
╰──────────────────────────────────────────────────────────────────────────────╯

The deepspeed version is 0.9.2.

I have the same need; I've been looking for an RLHF implementation with ChatGLM for a while and haven't found one. This is urgent!

sunzeyeah commented on June 10, 2024

From the code it looks like only ChatGLM models can run and GLM models don't work at all. Is that right?

This has been verified on deepspeed==0.9.1. If your deepspeed version is too old, you can comment out the memory_efficient_linear line.

The latest code has fixed the runtime errors for GLM-type models; see issue 14.

sunzeyeah commented on June 10, 2024

The RLHF code has been updated to an implementation based on DeepSpeedChat; see train_rlhf.sh for an example. Just replace SFT_MODEL and REWARD_MODEL in it with the corresponding models.
However, running ChatGLM-6B RLHF on multiple V100 16G cards still OOMs; you can verify it on GPUs with more memory.

Hello, I'd like to ask: 1) what kind of environment is needed to run the RLHF pipeline with both SFT_MODEL and REWARD_MODEL set to ChatGLM-6B; 2) with SFT_MODEL as ChatGLM-6B and REWARD_MODEL as glm-350m, I hit OOM at ZeRO stage 3 on 4x V100 32G; 3) could you share the environment configuration you used for training?

Currently all I have access to is 8x V100 16G. The largest model size that works is 350M for both SFT and Reward, at roughly 14.6G of GPU memory per card. I've verified that SFT=2.6B + Reward=350M already errors out, and SFT=6B + Reward=6B would certainly OOM as well.

That said, the code paths for ChatGLM, GLM, and Pangu-type models have all been verified to run end to end with small models.
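As a rough sanity check on these numbers: PPO-style RLHF keeps four models resident (actor, reference, critic, reward), and the trainable ones also carry gradients and Adam optimizer states. A back-of-envelope sketch, ignoring activations and assuming the byte counts in the comments:

    def rlhf_mem_gb(actor_b, critic_b, n_gpus=1):
        # fp16 weights (2B) + fp16 grads (2B) + fp32 Adam master copy,
        # momentum, and variance (12B) ≈ 16 bytes per trainable parameter;
        # frozen models (reference, reward) hold only fp16 weights ≈ 2 bytes.
        trainable = (actor_b + critic_b) * 1e9 * 16
        frozen = (actor_b + critic_b) * 1e9 * 2  # ref mirrors actor, reward mirrors critic
        return (trainable + frozen) / n_gpus / 2**30

    print(rlhf_mem_gb(0.35, 0.35, n_gpus=1))  # ~11.7 GB
    print(rlhf_mem_gb(6.0, 6.0, n_gpus=8))    # ~25 GB/card even with ideal ZeRO-3 sharding

With n_gpus=1, that is, no effective sharding as described in the next comment, the 350M case lands near 12 GB per card, roughly consistent with the reported 14.6G once activations are included.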

sunzeyeah commented on June 10, 2024

There is still room to optimize the RLHF code, though. Unlike the SFT and Reward training code, it does not launch DeepSpeed via transformers' TrainingArguments and Trainer, but calls deepspeed.initialize() directly. For now this appears to hurt model parallelism across cards: per-card memory usage does not decrease as the number of cards increases, whereas launching via TrainingArguments and Trainer does achieve model parallelism. A follow-up optimization is planned for this.
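For context, a minimal sketch of the two launch paths being contrasted; the model, dataset, and config objects are placeholders, not the repo's code.

    import deepspeed
    from transformers import Trainer, TrainingArguments

    def launch_direct(model, ds_config: dict):
        # Current RLHF path: construct the engine by hand; how well ZeRO-3
        # shards across ranks then depends on this script's own setup.
        engine, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config=ds_config,
        )
        return engine

    def launch_via_trainer(model, train_dataset):
        # SFT/Reward path: TrainingArguments.deepspeed accepts a config dict
        # or a JSON path, and Trainer wires up DeepSpeed across ranks itself.
        args = TrainingArguments(output_dir="output", deepspeed="ds_config.json")
        return Trainer(model=model, args=args, train_dataset=train_dataset)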

GUORUIWANG commented on June 10, 2024

There is still room to optimize the RLHF code, though. Unlike the SFT and Reward training code, it does not launch DeepSpeed via transformers' TrainingArguments and Trainer, but calls deepspeed.initialize() directly. For now this appears to hurt model parallelism across cards: per-card memory usage does not decrease as the number of cards increases, whereas launching via TrainingArguments and Trainer does achieve model parallelism. A follow-up optimization is planned for this.

Thanks for the reply, that explains it. Looking forward to the follow-up optimization.
