rfsfreitas / rwkv-lm-lora

This project forked from picocreator/rwkv-lm-lora

RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), so it combines the best of RNNs and transformers: great performance, fast inference, VRAM savings, fast training, "infinite" ctx_len, and free sentence embeddings.

License: Apache License 2.0

C++ 1.59% Python 91.53% Cuda 6.88%

rwkv-lm-lora's Introduction

LoRA fork of RWKV-LM

A fork of RWKV-LM with added LoRA finetuning support. Currently only RWKV-v4neo is supported. The LoRA module is self-implemented so that it works with the TorchScript JIT, and existing RWKV-v4neo models/checkpoints should work out of the box. During training, only the LoRA-finetuned weights are checkpointed: this gives much smaller checkpoints, but you now also need to specify the base model when using them. See args.MODEL_LOAD and args.MODEL_LORA in RWKV-v4neo/chat.py.
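For readers new to LoRA, here is a minimal plain-Python sketch of the standard LoRA formulation, not this fork's TorchScript module: a frozen weight W is combined with a trainable low-rank update, y = W x + (alpha / r) * B (A dropout(x)). The names LoRALinear, lora_A, lora_B and lora_param_count are illustrative assumptions. The sketch also shows why LoRA-only checkpoints are small: each adapted n_out x n_in matrix adds only r * (n_in + n_out) stored values.

```python
import random


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_param_count(n_in, n_out, r):
    """Extra values LoRA stores for one n_out x n_in weight: A is r x n_in, B is n_out x r."""
    return r * (n_in + n_out)


class LoRALinear:
    """Hypothetical LoRA-wrapped linear layer: y = W x + (alpha / r) * B (A dropout(x))."""

    def __init__(self, weight, r, alpha, dropout=0.0, seed=0):
        n_out, n_in = len(weight), len(weight[0])
        self.rng = random.Random(seed)
        self.weight = weight      # frozen base weight, never updated
        self.scale = alpha / r    # standard LoRA scaling factor
        self.dropout = dropout
        self.training = False
        # A gets a small random init and B starts at zero, so the update is initially 0
        self.lora_A = [[self.rng.gauss(0.0, 0.01) for _ in range(n_in)] for _ in range(r)]
        self.lora_B = [[0.0] * r for _ in range(n_out)]

    def _dropout(self, x):
        # inverted dropout on the LoRA input path, active only in training mode
        if not self.training or self.dropout == 0.0:
            return x
        keep = 1.0 - self.dropout
        return [xi / keep if self.rng.random() < keep else 0.0 for xi in x]

    def __call__(self, x):
        base = matvec(self.weight, x)
        update = matvec(self.lora_B, matvec(self.lora_A, self._dropout(x)))
        return [b + self.scale * u for b, u in zip(base, update)]


# With r = 8 and n_embd = 1024, one square projection stores 64x fewer values.
full = 1024 * 1024
lora = lora_param_count(1024, 1024, 8)
print(full, lora, full // lora)  # 1048576 16384 64
```

Zero-initialising B means a freshly wrapped model behaves exactly like the base model until training moves B away from zero.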

To finetune an existing model with LoRA, run train.py from the RWKV-v4neo directory just as you would for full finetuning, adding the LoRA options:

python3 train.py \
  --load_model <pretrained base model> \
  --proj_dir <place to save checkpoints> \
  --data_file <data for finetune> \
  --data_type <data type for finetune> \
  --vocab_size 50277 --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 --micro_bsz 2 --n_layer 24 --n_embd 1024 --pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 \
  --lora --lora_r 8 --lora_alpha 16 --lora_dropout 0.01 \
  --lora_load <lora checkpoint to continue training> \
  --lora_parts=att,ffn,time,ln

The --vocab_size ... --grad_cp 0 line holds all your familiar full-finetuning options. --lora_load is optional: pass it only to continue training from an existing LoRA checkpoint, and drop that line otherwise. --lora_parts configures which parts to finetune.

The values of r, alpha and dropout are up to you. The att, ffn, time and ln parts refer to the TimeMix, ChannelMix, time decay/first/mix parameters, and layernorm parameters respectively; DON'T FORGET to list every set of parameters you want finetuned here. I'm still experimenting with different configurations, so reports of your own experience are welcome!
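One way to picture --lora_parts is as a filter over parameter names. The sketch below is hypothetical: it assumes RWKV-v4neo-style names such as blocks.0.att.time_decay or blocks.0.ln1.weight, and LoRA matrices stored under names ending in lora_A or lora_B; the fork's actual selection code may differ. Anything not matched stays frozen, which is why a part left out of --lora_parts is never finetuned.

```python
def select_trainable(param_names, lora_parts):
    """Return the parameter names trained under a --lora_parts string (hypothetical naming)."""
    parts = {p.strip() for p in lora_parts.split(",") if p.strip()}
    selected = []
    for name in param_names:
        segments = name.split(".")
        leaf = segments[-1]
        if leaf in ("lora_A", "lora_B"):
            # low-rank matrices inside TimeMix (att) or ChannelMix (ffn) linear layers
            if ("att" in segments and "att" in parts) or ("ffn" in segments and "ffn" in parts):
                selected.append(name)
        elif leaf.startswith("time_"):
            # time_decay, time_first and time_mix_* parameters
            if "time" in parts:
                selected.append(name)
        elif any(s.startswith("ln") for s in segments):
            # layernorm weights and biases (ln0, ln1, ln2, ln_out)
            if "ln" in parts:
                selected.append(name)
    return selected


names = ["emb.weight", "blocks.0.att.time_decay", "blocks.0.att.key.weight", "blocks.0.att.key.lora_A"]
print(select_trainable(names, "att,time"))  # ['blocks.0.att.time_decay', 'blocks.0.att.key.lora_A']
```

Note that the full base weights (here blocks.0.att.key.weight and emb.weight) are never selected; only the LoRA matrices and the small time/layernorm parameters are trained.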

To use the finetuned model, run chat.py as usual with the checkpoints in your specified proj_dir, but remember to set the LoRA-related options to the same values you used during training!

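args.MODEL_LOAD = '<pretrained base model>'  # the base model passed as --load_model
args.MODEL_LORA = 'your_lora_checkpoint.pth'
args.lora_r = 8        # must match --lora_r used in training
args.lora_alpha = 16   # must match --lora_alpha used in training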
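Why the values must match: under the usual LoRA convention (assumed here), the learned update is multiplied by lora_alpha / lora_r. A different lora_r changes the expected shapes of the saved A and B matrices, while a different lora_alpha silently rescales the learned update. The hypothetical helper below (LoraOptions and check_lora_options are illustrative names, not part of chat.py) checks chat-time options against the training ones.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LoraOptions:
    r: int
    alpha: float

    @property
    def scale(self) -> float:
        # the learned update B @ A is multiplied by alpha / r (standard LoRA convention)
        return self.alpha / self.r


def check_lora_options(train: LoraOptions, chat: LoraOptions) -> None:
    """Raise if chat-time LoRA options differ from the ones used during training."""
    if chat.r != train.r:
        # A is r x n_in and B is n_out x r, so the checkpoint tensors would not fit
        raise ValueError(f"lora_r mismatch: trained with {train.r}, chat uses {chat.r}")
    if chat.scale != train.scale:
        factor = chat.scale / train.scale
        raise ValueError(f"lora_alpha mismatch: the LoRA update would be scaled by {factor:g}x")


train = LoraOptions(r=8, alpha=16)  # --lora_r 8 --lora_alpha 16
check_lora_options(train, LoraOptions(r=8, alpha=16))  # passes silently
try:
    check_lora_options(train, LoraOptions(r=8, alpha=32))
except ValueError as err:
    print(err)  # lora_alpha mismatch: the LoRA update would be scaled by 2x
```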

TODOs

  • Add a separate model-merging step so that LoRA-finetuned models can be used with other RWKV inference implementations (especially ChatRWKV)
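Merging would fold the LoRA update back into the base weights. Under the standard LoRA formulation (assumed here, not taken from this fork's code), each adapted weight becomes W' = W + (alpha / r) * B A, an ordinary matrix with the base weight's shape that needs no LoRA code at inference. A plain-Python sketch, with the hypothetical function name merge_lora:

```python
def merge_lora(weight, lora_A, lora_B, r, alpha):
    """Return W + (alpha / r) * B @ A for W (n_out x n_in), A (r x n_in), B (n_out x r)."""
    scale = alpha / r
    merged = []
    for i, w_row in enumerate(weight):
        row = []
        for j, w in enumerate(w_row):
            # (B @ A)[i][j] is a sum over the rank dimension k
            delta = sum(lora_B[i][k] * lora_A[k][j] for k in range(r))
            row.append(w + scale * delta)
        merged.append(row)
    return merged


W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 0.0]]  # r = 1
B = [[1.0], [0.0]]
print(merge_lora(W, A, B, r=1, alpha=2.0))  # [[3.0, 0.0], [0.0, 1.0]]
```

Because the merged matrix keeps the base model's shape, a merged checkpoint could in principle be loaded wherever the base model can; the alpha and r used for merging must be the ones used in training.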