lora-vit's Introduction

MeLo: Low-rank Adaptation is Better than Fine-tuning for Medical Image Diagnosis

Intro

Useful links

[Homepage] [arXiv]

Feature

  • Supported DeepLab segmentation for lukemelas/PyTorch-Pretrained-ViT. 2023-03-15
  • Supported timm. 2023-03-16
  • Supported multi-lora. 2023-11-15
  • Repo clean up.

Installation

Git clone this repository. It was developed with torch.__version__==1.13.0; versions newer than torch.__version__==1.10.0 should also work, although this has not been verified. You may also need safetensors from Hugging Face to load and save weights.

Examples

You may find examples in examples.ipynb

Usage

You may use Vision Transformer from timm:

import timm
import torch
from lora import LoRA_ViT_timm
img = torch.randn(2, 3, 224, 224)
model = timm.create_model('vit_base_patch16_224', pretrained=True)
lora_vit = LoRA_ViT_timm(vit_model=model, r=4, alpha=4, num_classes=10)
pred = lora_vit(img)
print(pred.shape)
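
For intuition about what the r and alpha arguments control, here is a minimal sketch of a LoRA-wrapped linear layer in the standard formulation from the LoRA paper. It is not taken from this repository: the class name LoRALinearSketch, the attribute names, and the alpha / r scaling are assumptions made for illustration.

```python
import math

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Hypothetical illustration: a frozen linear layer plus a low-rank update."""

    def __init__(self, base: nn.Linear, r: int, alpha: float):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # pretrained weights stay frozen
        # Low-rank factors: A maps in_features -> r, B maps r -> out_features.
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)  # the update starts at zero
        self.scaling = alpha / r  # assumed scaling, as in the LoRA paper

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


layer = LoRALinearSketch(nn.Linear(768, 768), r=4, alpha=4)
x = torch.randn(2, 197, 768)
out = layer(x)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(out.shape, trainable)  # torch.Size([2, 197, 768]) 6144
```

Because B is initialised to zero in this sketch, the wrapped layer initially reproduces the frozen layer exactly, and only the 2 * 768 * r adapter weights receive gradients.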

If timm is too complicated, you can use a simpler implementation of ViT from lukemelas/PyTorch-Pretrained-ViT. Wrap your ViT using LoRA-ViT. This is a simple example of a classifier:

from base_vit import ViT
import torch
from lora import LoRA_ViT

model = ViT('B_16_imagenet1k')
model.load_state_dict(torch.load('B_16_imagenet1k.pth'))
# Assumed input size: 384 matches the image_size used in the segmentation example
img = torch.randn(1, 3, 384, 384)
preds = model(img)  # preds.shape = torch.Size([1, 1000])

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params}")  # trainable parameters: 86859496


lora_model = LoRA_ViT(model, r=4, alpha=4, num_classes=10)
num_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params}") # trainable parameters: 147456
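
The figure of 147456 can be checked by hand. The sketch below assumes the usual ViT-B/16 shape (12 transformer blocks, hidden size 768) and that LoRA is attached only to the query and value projections of each block, as the issue further down describes. The variable names are illustrative.

```python
# Assumed ViT-B/16 shape: 12 transformer blocks, hidden size 768.
num_blocks = 12
dim = 768
r = 4
adapted_per_block = 2  # query and value projections (assumption)

# Each adapted projection gets two factors: A (dim x r) and B (r x dim).
params_per_projection = dim * r + r * dim
lora_params = num_blocks * adapted_per_block * params_per_projection

print(lora_params)  # 147456
```

Under these assumptions the adapters alone account for the reported number, which is roughly 0.17% of the 86859496 parameters of the full model.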

This is an example for segmentation tasks, using DeepLabv3:

# SegWrapForViT must also be imported from this repository
model = ViT('B_16_imagenet1k')
model.load_state_dict(torch.load('B_16_imagenet1k.pth'))
lora_model = LoRA_ViT(model, r=4, alpha=4)
seg_lora_model = SegWrapForViT(vit_model=lora_model, image_size=384,
                               patches=16, dim=768, n_classes=10)

num_params = sum(p.numel() for p in seg_lora_model.parameters() if p.requires_grad)
# The count is printed in units of 2**20 parameters
print(f"Number of trainable parameters: {num_params/2**20:.3f}")  # prints 6.459

Save and load LoRA:

lora_model.save_lora_parameters('mytask.lora.safetensors') # save
lora_model.load_lora_parameters('mytask.lora.safetensors') # load

Performance

On an M1 Pro, LoRA is about 1.8x to 1.9x faster. Running python performance_profile.py performs the timing profile. More tests will come soon.

Citation

Use this bibtex to cite this repository:

@misc{zhu2023melo,
      title={MeLo: Low-rank Adaptation is Better than Fine-tuning for Medical Image Diagnosis}, 
      author={Yitao Zhu and Zhenrong Shen and Zihao Zhao and Sheng Wang and Xin Wang and Xiangyu Zhao and Dinggang Shen and Qian Wang},
      year={2023},
      eprint={2311.08236},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Credit

The ViT code and ImageNet-pretrained weights come from lukemelas/PyTorch-Pretrained-ViT.

lora-vit's People

Contributors

absterzhu, jamesqfreeman, zhaozh10, zixuzhuang

lora-vit's Issues

Why only tune query and value in every attention block?

Hello @JamesQFreeman,

Thanks for your great code for easy re-implementation of LoRA in ViT. It's very useful for us to adapt to our own task.

However, I have observed that only the query and value weights in the attention block are being tuned. I'm curious about the rationale behind this. In the original LoRA paper, the authors recommend tuning all four weights in the attention block (query, key, value, and projection linear layer) for better performance.

Is there a trade-off between computational efficiency and final performance that led to this decision? What's more, have you seen an improvement in results by including the key and projection linear layer weights for tuning in your own task?

Thank you, and I look forward to hearing from you.

About the difference when evaluating after training with the normal method

Excuse me, I only know a little about LoRA. I used your method and found it powerful in training. In my workflow, I save the model as ".pth" after training and use the ".pth" file to evaluate the model, but I find the result worse than full-parameter fine-tuning. Should I also call "save_lora_parameters" after training and evaluate by loading both the saved LoRA parameters and the ".pth" file? Besides, do the "r" and "alpha" parameters have an effect?

How can I set up alpha?

Thank you for your work and making it available!
I am not familiar with the technical details of LoRA, but I know it accepts the rank and alpha as inputs. I see from the example you provide that you specify r as input, but I don't see alpha being passed anywhere. So my question is: what is its current value? Is it the same as the rank? And how can I set it to a different value?

CUDA OOM even though < 1 million trainable parameters and no OOM with forward pass through ViT

Hi, I am able to successfully do forward passes through my ViT with no issues, but trying to train LoRA instantly crashes with

CUDA out of memory. Tried to allocate 1.50 GiB (GPU 0; 15.78 GiB total capacity; 13.19 GiB already allocated; 1.41 GiB free; 13.22 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Is it expected that training LoRA adapter weights would use significantly more resources than independently making predictions with the model and training 900,000 parameters? Is it because running a model with torch.no_grad() saves more memory than shutting off requires_grad on each module?

Thank you for any advice
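
One plausible factor behind the question above, offered as a sketch rather than a diagnosis of this report: freezing weights with requires_grad=False does not stop autograd from recording a graph as long as some trainable parameter contributes to the output, whereas torch.no_grad() disables recording entirely. The two layers below are stand-ins chosen for illustration, not modules from this repository.

```python
import torch
import torch.nn as nn

frozen = nn.Linear(768, 768)
for p in frozen.parameters():
    p.requires_grad = False  # frozen backbone weights

adapter = nn.Linear(768, 768, bias=False)  # stand-in for a trainable LoRA branch

x = torch.randn(8, 197, 768)

# Inference: no graph is recorded, so intermediate activations can be freed.
with torch.no_grad():
    y_inference = frozen(x) + adapter(x)

# Training: the adapter is trainable, so autograd records a graph and keeps
# the activations it needs to compute the adapter's gradients.
y_train = frozen(x) + adapter(x)

print(y_inference.requires_grad, y_inference.grad_fn is None)  # False True
print(y_train.requires_grad, y_train.grad_fn is not None)      # True True
```

In a full ViT with adapters in every block, gradients for the early adapters have to flow back through all later blocks, so activations from the frozen layers are retained as well. Training memory therefore scales with batch size and sequence length rather than with the number of trainable parameters.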
