
Comments (9)

1moye commented on August 24, 2024

To add: this was trained on 8× A100 GPUs.

Ledzy commented on August 24, 2024

Could you share the model, the loss curves, and the LoRA rank?

1moye commented on August 24, 2024

Could you share the model, the loss curves, and the LoRA rank?

The runs have not finished; I am going by the estimated time shown. For BAdam I selected full-parameter training in LLaMA-Factory, and for LoRA I targeted the fixed set of linear layers.

BAdam config:

method

stage: pt
do_train: true
finetuning_type: full
use_badam: true
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2
deepspeed: examples/deepspeed/ds_z3_config_badam.json
badam_mode: layer

LoRA config:

method

stage: pt
do_train: true
finetuning_type: lora
lora_target: q_proj,k_proj,v_proj,o_proj,up_proj,gate_proj,down_proj
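Note that this LoRA config does not set lora_rank, so LLaMA-Factory's default rank is used. If the comparison should be pinned to a specific rank, it has to be added explicitly, e.g. (the value below is only an illustration, not what was actually run):

lora_rank: 8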

One more note: even on the same dataset, a full-parameter continued pre-training run takes less time than BAdam's current estimate.

This is BAdam's current progress:
| 1013/6057 [21:06:32<78:38:17, 56.13s/it]
LoRA takes roughly 34 hours, and full-parameter training just under 4 days.

The model is Qwen1.5-7B.

Ledzy commented on August 24, 2024

Thanks for the feedback. Versions before 1.2.2 did perform some unnecessary computation during the backward pass. Please install the latest BAdam:
pip install badam==1.2.2
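
If you want to double-check that the upgrade actually took effect in your environment (a standard pip setup is assumed here), you can inspect the installed version with:

pip show badam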

Below are timings for ZeRO-3 on 2× RTX 3090 (10 steps each):

  • BAdam ascending mode: 10/93 [03:31<29:01, 20.98s/it]
  • BAdam descending mode: 10/93 [01:58<16:19, 11.81s/it]
  • LoRA rank 100: 10/93 [03:34<29:09, 21.07s/it]

The closer BAdam's active block is to the input layer, the longer an iteration takes. With ascending mode (updating blocks from input to output), iterations are therefore slow at first, but as the active block moves toward the output layer the backward cost drops and each iteration speeds up. The average iteration time is roughly the mean of the times for updating the first and last blocks, about 16.5 s, which is faster than LoRA.
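
As a rough sanity check with the numbers above (treating the ascending run's first steps as the cost of the blocks nearest the input, and the descending run's first steps as the cost of the blocks nearest the output):

(20.98 + 11.81) / 2 ≈ 16.4 s/it

which is in line with the ~16.5 s estimate and below LoRA's 21.07 s/it.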

Note that in distributed training the fraction of time BAdam saves is lower than in the paper: inter-GPU communication takes up a large share of each step, so the backward-pass computation that BAdam saves becomes a smaller share of the total.

The scripts are as follows:

export CUDA_VISIBLE_DEVICES=0,1
export USE_MODELSCOPE_HUB=1

# BAdam
llamafactory-cli train \
    --stage sft \
    --do_train True \
    --model_name_or_path LLM-Research/Meta-Llama-3-8B \
    --preprocessing_num_workers 16 \
    --finetuning_type full \
    --quantization_method bitsandbytes \
    --template default \
    --flash_attn auto \
    --dataset_dir data \
    --dataset alpaca_en_demo \
    --cutoff_len 1024 \
    --learning_rate 5e-05 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 100 \
    --warmup_steps 0 \
    --optim adamw_torch \
    --packing False \
    --report_to none \
    --use_badam True \
    --output_dir saves/LLaMA3-8B/full/temp \
    --bf16 True \
    --plot_loss True \
    --ddp_timeout 180000000 \
    --include_num_input_tokens_seen True \
    --badam_mode layer \
    --badam_switch_mode descending \
    --badam_switch_interval 50 \
    --badam_update_ratio 0.05 \
    --deepspeed cache/ds_z3_config.json 

# LoRA
llamafactory-cli train \
    --stage sft \
    --do_train True \
    --model_name_or_path LLM-Research/Meta-Llama-3-8B \
    --preprocessing_num_workers 16 \
    --finetuning_type lora \
    --quantization_method bitsandbytes \
    --template default \
    --flash_attn auto \
    --dataset_dir data \
    --dataset alpaca_en_demo \
    --cutoff_len 1024 \
    --learning_rate 5e-05 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 100 \
    --warmup_steps 0 \
    --optim adamw_torch \
    --packing False \
    --report_to none \
    --output_dir saves/LLaMA3-8B/lora/temp \
    --bf16 True \
    --plot_loss True \
    --ddp_timeout 180000000 \
    --include_num_input_tokens_seen True \
    --lora_rank 100 \
    --lora_alpha 200 \
    --lora_dropout 0 \
    --lora_target all \
    --deepspeed cache/ds_z3_config.json


1moye commented on August 24, 2024

I am running the PT stage, not SFT.
I have made the changes now; BAdam is at version 1.2.2.
These are the current results.
LoRA (lora_target: all):
0%|▏ | 2/2019 [01:57<32:50:42, 58.62s/it]
BAdam, full parameters:
1%|▉ | 49/8076 [1:17:24<211:39:01, 94.92s/it]

And this is full-parameter pre-training:
0%|▏ | 3/2019 [04:01<45:43:10, 81.64s/it]
So why does BAdam take so long?
The model is Qwen1.5-7B.

Ledzy commented on August 24, 2024

Please share the command you are running.

1moye commented on August 24, 2024

The command:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 llamafactory-cli train examples/extras/badam/qwen_badam_pt.yaml

model_name_or_path: <path to qwen1.5-7B>

stage: pt
do_train: true
finetuning_type: full
use_badam: true
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2
deepspeed: examples/deepspeed/ds_z3_config_badam.json
badam_mode: layer

dataset: <the same dataset>
template: empty
cutoff_len: 4096

overwrite_cache: true
preprocessing_num_workers: 16

output_dir: saves/qwen1.5-7b/badam/pt
logging_steps: 10
save_steps: 1700
plot_loss: true
overwrite_output_dir: true

per_device_train_batch_size: 4
gradient_accumulation_steps: 2
learning_rate: 1.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1

val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 1700


Ledzy commented on August 24, 2024

  1. When comparing the time against LoRA and full fine-tuning, make sure you use the same 6 GPUs, or GPUs with an identical communication topology; topology differences change communication time. See nvidia-smi topo -m.
  2. As mentioned earlier, with badam_switch_mode set to ascending the backward cost is high at the start, so early iterations are slow. Setting it to descending lets you see the per-iteration saving right away; in general, random is the recommended value. An illustrative config change is sketched right after this list.
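
For example, in the qwen_badam_pt.yaml used above this is a one-line change (an illustration only; descending and random are both reasonable choices per the point above):

badam_switch_mode: descending  # or: random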


1moye commented on August 24, 2024

1. The same 6 GPUs were used.
2. As for the second point, I will run another test.
