
flap's Introduction


FLAP

[AAAI 2024] Fluctuation-based Adaptive Structured Pruning for Large Language Models


Introduction

Fluctuation-based Adaptive Structured Pruning for Large Language Models [arXiv]
Yongqi An, Xu Zhao, Tao Yu, Ming Tang, Jinqiao Wang
Institute of Automation, Chinese Academy of Sciences

Why FLAP:

  • No training required: FLAP produces a strong compressed LLM without any retraining or fine-tuning.
  • Adaptive compression structure: Each module and layer is assigned its own adaptive pruning ratio.
  • Efficient compression: Pruning completes in 3 to 5 minutes on a single GPU, with no additional training time required.
  • Better performance: Stronger results across a variety of language benchmarks, with additional gains on specific task datasets.

Supported LLMs:

Table of Contents

  • Quick Start (Installation, Minimal Example)
  • Configuration Instruction (Pruning)
  • Language Modeling Evaluation
  • Zero-shot Evaluation
  • Acknowledgement
  • Citation

Quick Start

Installation

Installation instructions can be found in INSTALL.md.

Minimal Example

bash script/llama_7b.sh $GPU_ID

This script compresses the LLaMA-7B model with FLAP, pruning roughly 20% of its parameters. The pre-trained model and the calibration dataset are downloaded automatically, so there is no need to fetch them manually. The first run takes some extra time to download the model and the dataset.

Configuration Instruction

Pruning

LLaMA-7B with ~20% of parameters pruned:

python main.py \
    --model decapoda-research/llama-7b-hf \
    --prune_method flap \
    --pruning_ratio 0.2 \
    --remove_heads -1 \
    --metrics WIFV \
    --structure AL-AM \
    --nsamples 1024 \
    --save_model "llm_weights/flap_p0.2_WIFV_ALAM_llama_7b/" \
    --eval

Arguments:

  • --model: The identifier of the LLaMA model on the Hugging Face model hub. It is passed to AutoModelForCausalLM.from_pretrained to load the pre-trained LLM. For example, to use the 7-billion-parameter LLaMA, pass decapoda-research/llama-7b-hf to --model.
  • --cache_dir: Directory for loading or storing LLM weights. The default is llm_weights.
  • --prune_method: The pruning method. Three methods from the paper are implemented: [flap, wanda_sp, mag_sp]. The default is flap.
  • --pruning_ratio: The ratio of parameters to prune, e.g. 0.2 for roughly 20%.
  • --remove_heads: The number of attention heads to remove. Only used by the UL-MM and AL-MM structures to manually set the ratio between self-attention and MLP pruning.
  • --metrics: The pruning metric, as defined in the paper: [IFV, WIFV, WIFN, N/A]. The default is WIFV.
  • --structure: The global structure of the compressed model, as defined in the paper: [UL-UM, UL-MM, AL-MM, AL-AM]. The default is AL-AM.
  • --unstr: If set, only mask the pruned weights instead of actually removing them; the default is False (the model is truly pruned).
  • --eval: If set, evaluate the pruned model on WikiText-2 and report perplexity; the default is False.
  • --save_model: Specifies the directory where the pruned model will be stored.
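
For example, a hypothetical invocation that combines the manual UL-MM structure with --remove_heads (the head count and output path below are purely illustrative, not recommended settings):

python main.py \
    --model decapoda-research/llama-7b-hf \
    --prune_method flap \
    --pruning_ratio 0.2 \
    --remove_heads 256 \
    --metrics WIFV \
    --structure UL-MM \
    --nsamples 1024 \
    --save_model "llm_weights/flap_p0.2_WIFV_ULMM_llama_7b/" \
    --eval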

After pruning (and any optional post-training), we follow lm-evaluation-harness for evaluation.
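
A minimal sketch of how such an evaluation might be wired up through the lm-evaluation-harness Python API (this assumes a v0.4+ harness whose HFLM wrapper accepts a preloaded model; the task list and batch size are illustrative):

import lm_eval
from lm_eval.models.huggingface import HFLM

# `model` and `tokenizer` are the pruned causal LM and its tokenizer,
# already loaded in memory (e.g. right after pruning in main.py).
lm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)
results = lm_eval.simple_evaluate(model=lm, tasks=["boolq", "piqa", "hellaswag"])
print(results["results"])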

Language Modeling Evaluation

Brief quantitative language modeling results for the LLaMA family:


Zero-shot Evaluation

Brief quantitative zero-shot performance results for LLaMA-7B:


More results can be found in the paper.

Acknowledgement

Citation

If you find this project useful, please cite

@misc{an2023fluctuationbased,
      title={Fluctuation-based Adaptive Structured Pruning for Large Language Models}, 
      author={Yongqi An and Xu Zhao and Tao Yu and Ming Tang and Jinqiao Wang},
      year={2023},
      eprint={2312.11983},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}


flap's Issues

Question on the calculation of W_metric for 'self_attn.o_proj' in prune_flap()

Thanks for your inspiring work!
I have a small question about the W_metric for self_attn.o_proj in prune_flap(): there is a squaring operation, whereas the W_metric for mlp.down_proj is computed without it.

        for name in subset:
            if name == 'self_attn.o_proj':
                W_metric = metrics[args.metrics](wrapped_layers, subset, name) ** 2    # square is needed
                if args.structure == "UL-UM":
                    W_metric = W_metric.reshape(-1, 128).sum(dim=1)
                    thresh = torch.sort(W_metric.cuda())[0][int(args.pruning_ratio*layer.self_attn.num_heads)].cpu()
                    W_mask = (W_metric>=thresh)
                    attn_mask.append(W_mask)
                elif args.structure == "UL-MM":
                    W_metric = W_metric.reshape(-1, 128).sum(dim=1)
                    thresh = torch.sort(W_metric.cuda())[0][args.remove_heads // len(layers)].cpu()
                    W_mask = (W_metric>=thresh)
                    attn_mask.append(W_mask)
                else:
                    attn_metric_list.append(W_metric.cpu())
                attn_baseline_inp_list.append(wrapped_layers[name].baseline_inp.type(torch.half))
            else:
                W_metric = metrics[args.metrics](wrapped_layers, subset, name)    # no square
                if args.structure == "UL-UM":
                    thresh = torch.sort(W_metric.cuda())[0][int(W_metric.numel()*args.pruning_ratio)].cpu()
                    W_mask = (W_metric>=thresh)
                    mlp_mask.append(W_mask)
                elif args.structure == "UL-MM":
                    thresh = torch.sort(W_metric.cuda())[0][cal_remove_neuron(args, model)].cpu()
                    W_mask = (W_metric>=thresh)
                    mlp_mask.append(W_mask)
                else:
                    mlp_metric_list.append(W_metric.cpu())
                mlp_baseline_inp_list.append(wrapped_layers[name].baseline_inp.type(torch.half))
            wrapped_layers[name].free()

I'm really confused. Could you help me out?

Question about the number of parameters

Hi, thank you for sharing your impressive work.

I think the line below [link] might need to be modified:
print(f"model parameter {sum(p.numel() for p in model.parameters()) / 1024 ** 3:.2f}B")

Instead of dividing the number of parameters by 1024 ** 3 to calculate the parameters in billions, it might be more accurate to use 1000 ** 3. Specifically, when I checked the number of parameters of the following models, the results are:

  • LLaMA-7B: 6738415616 = 6.74B
  • LLM-Pruner’s code (20% pruning) [link]: 5422977024 = 5.42B
  • FLAP’s code (20% pruning) [link]: 5442514944 = 5.44B (not 5.07B)

I would appreciate it if you could share your opinion. Thank you for your time and consideration.
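
For reference, a quick check of the two divisors on the LLaMA-7B count quoted above:

# LLaMA-7B parameter count from the list above.
n_params = 6738415616
print(f"{n_params / 1000 ** 3:.2f}B")  # 6.74B (decimal billions)
print(f"{n_params / 1024 ** 3:.2f}B")  # 6.28B with the 1024 ** 3 divisor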

How much data is used in the pruning process?

I would like to know how much data is used in the pruning process. Is it as in the example code, where nsamples=1024, meaning that only 1,024 calibration samples are used to determine the pruning result?

Question about the wikitext2 data loader

Thanks for your nice work. When running the sample script, I'm getting the following error message regarding the wikitext2 loader. Would you kindly check it?

File "/ssd2/bkkim/FLAP/lib/data.py", line 81, in get_wikitext2
    traindata = load_dataset('text', data_files='datasets/wikitext/wiki.train.raw', split="train")
FileNotFoundError: Unable to find '/ssd2/bkkim/FLAP/datasets/wikitext/wiki.train.raw'

By the way, I found a workaround by using the commented-out lines in the script. Is this an appropriate solution?

# traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
# testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

Thank you for your time.
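
For what it's worth, a minimal sketch of a fallback that prefers the local raw files but downloads wikitext-2-raw-v1 from the Hub when they are missing (the function name and local path are illustrative, not the repo's actual code):

import os
from datasets import load_dataset

def get_wikitext2_splits(local_dir="datasets/wikitext"):
    # Prefer the local raw files; otherwise fall back to the Hub download
    # shown in the commented lines above.
    train_path = os.path.join(local_dir, "wiki.train.raw")
    test_path = os.path.join(local_dir, "wiki.test.raw")
    if os.path.exists(train_path) and os.path.exists(test_path):
        traindata = load_dataset("text", data_files=train_path, split="train")
        testdata = load_dataset("text", data_files=test_path, split="train")
    else:
        traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
        testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    return traindata, testdata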

Request for the inference speed test script

I saw the inference-speed results before and after pruning in the paper. Could you share a copy of the test script? I would like to reproduce the result. Thank you!
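
Not the authors' script, but a minimal sketch of how one might measure generation throughput before and after pruning (the prompt, token counts, and model/tokenizer loading are placeholders):

import time
import torch

@torch.no_grad()
def tokens_per_second(model, tokenizer, prompt="Hello, my name is", new_tokens=128, runs=5):
    # Rough decoding throughput; generation may stop early at EOS,
    # so treat the number as an approximation.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    model.generate(**inputs, max_new_tokens=new_tokens)  # warm-up
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(runs):
        model.generate(**inputs, max_new_tokens=new_tokens)
    torch.cuda.synchronize()
    return runs * new_tokens / (time.time() - start)

# Usage (model and tokenizer loaded elsewhere):
# print(f"{tokens_per_second(model, tokenizer):.1f} tokens/s")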

Cannot load the model after pruning

I tried to load the pruned model with
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
but it raises the following error:

Traceback (most recent call last):
  File "/home/ubuntu/test_scripts/benchmark_r.py", line 154, in <module>
    main()
  File "/home/ubuntu/test_scripts/benchmark_r.py", line 63, in main
    model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True,
  File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 556, in from_pretrained
    return model_class.from_pretrained(
  File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3502, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3926, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 805, in _load_state_dict_into_meta_model
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 348, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([2048, 2785]) in "weight" (which has shape torch.Size([2048, 5504])), this look incorrect.

I also tried adding ignore_mismatched_sizes=True when loading:
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True, ignore_mismatched_sizes=True)

but this also fails:
Some weights of QWenLMHeadModel were not initialized from the model checkpoint at /data/xxxx and are newly initialized because the shapes did not match:

  • transformer.h.10.mlp.c_proj.weight: found shape torch.Size([2048, 2785]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.10.mlp.w1.weight: found shape torch.Size([2785, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.10.mlp.w2.weight: found shape torch.Size([2785, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.11.mlp.c_proj.weight: found shape torch.Size([2048, 2518]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.11.mlp.w1.weight: found shape torch.Size([2518, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.11.mlp.w2.weight: found shape torch.Size([2518, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.12.attn.c_attn.weight: found shape torch.Size([3840, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.12.attn.c_proj.weight: found shape torch.Size([2048, 1280]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.12.mlp.c_proj.weight: found shape torch.Size([2048, 2393]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.12.mlp.w1.weight: found shape torch.Size([2393, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.12.mlp.w2.weight: found shape torch.Size([2393, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.13.attn.c_attn.weight: found shape torch.Size([3072, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.13.attn.c_proj.weight: found shape torch.Size([2048, 1024]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.13.mlp.c_proj.weight: found shape torch.Size([2048, 3776]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.13.mlp.w1.weight: found shape torch.Size([3776, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.13.mlp.w2.weight: found shape torch.Size([3776, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.14.attn.c_attn.weight: found shape torch.Size([2688, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.14.attn.c_proj.weight: found shape torch.Size([2048, 896]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.14.mlp.c_proj.weight: found shape torch.Size([2048, 3594]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.14.mlp.w1.weight: found shape torch.Size([3594, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.14.mlp.w2.weight: found shape torch.Size([3594, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.15.attn.c_attn.weight: found shape torch.Size([3072, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.15.attn.c_proj.weight: found shape torch.Size([2048, 1024]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.15.mlp.c_proj.weight: found shape torch.Size([2048, 4113]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.15.mlp.w1.weight: found shape torch.Size([4113, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.15.mlp.w2.weight: found shape torch.Size([4113, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.16.attn.c_attn.weight: found shape torch.Size([3072, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.16.attn.c_proj.weight: found shape torch.Size([2048, 1024]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.17.attn.c_attn.weight: found shape torch.Size([2688, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.17.attn.c_proj.weight: found shape torch.Size([2048, 896]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.17.mlp.c_proj.weight: found shape torch.Size([2048, 3263]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.17.mlp.w1.weight: found shape torch.Size([3263, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.17.mlp.w2.weight: found shape torch.Size([3263, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.18.mlp.c_proj.weight: found shape torch.Size([2048, 3861]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.18.mlp.w2.weight: found shape torch.Size([3861, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.18.attn.c_attn.weight: found shape torch.Size([1536, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.18.attn.c_proj.weight: found shape torch.Size([2048, 512]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.18.mlp.w1.weight: found shape torch.Size([3861, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.19.attn.c_attn.weight: found shape torch.Size([2688, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.19.attn.c_proj.weight: found shape torch.Size([2048, 896]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.20.attn.c_attn.weight: found shape torch.Size([2688, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.20.attn.c_proj.weight: found shape torch.Size([2048, 896]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.20.mlp.c_proj.weight: found shape torch.Size([2048, 3291]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.20.mlp.w1.weight: found shape torch.Size([3291, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.20.mlp.w2.weight: found shape torch.Size([3291, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.21.attn.c_attn.weight: found shape torch.Size([1536, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.21.attn.c_proj.weight: found shape torch.Size([2048, 512]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.22.attn.c_attn.weight: found shape torch.Size([3072, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.22.attn.c_proj.weight: found shape torch.Size([2048, 1024]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.9.mlp.c_proj.weight: found shape torch.Size([2048, 2630]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.9.mlp.w1.weight: found shape torch.Size([2630, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.9.mlp.w2.weight: found shape torch.Size([2630, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    Traceback (most recent call last):
    File "/home/ubuntu/test_scripts/benchmark_r.py", line 152, in
    main()
    File "/home/ubuntu/test_scripts/benchmark_r.py", line 63, in main
    model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True, ignore_mismatched_sizes=True)
    File "/home/ubuntu/miniconda3/envs/qwen/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 556, in from_pretrained
    return model_class.from_pretrained(
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3558, in from_pretrained
    dispatch_model(model, **device_map_kwargs)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/accelerate/big_modeling.py", line 474, in dispatch_model
    model.to(device)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2556, in to
    return super().to(*args, **kwargs)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1152, in to
    return self._apply(convert)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
    [Previous line repeated 2 more times]
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 825, in _apply
    param_applied = fn(param)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1150, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
    NotImplementedError: Cannot copy out of meta tensor; no data!

How do you load the model after pruning?
I'm eager to try FLAP and look forward to your reply. Thank you.
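
For reference, since main.py saves the entire pruned module with torch.save, a minimal sketch of reloading it with torch.load instead of from_pretrained (the path is a placeholder; whether this matches the authors' intended workflow is an assumption):

import torch

# Load the whole pruned module pickled by main.py; from_pretrained cannot be
# used directly because the pruned per-layer shapes no longer match the
# shapes declared in the original config.
model = torch.load("llm_weights/flap_p0.2_WIFV_ALAM_llama_7b/pruned_model.pt",
                   map_location="cpu")
model.eval()
# Note: recent PyTorch versions default torch.load to weights_only=True;
# loading a full pickled module there requires weights_only=False.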

Support for other models

Are there any parallel efforts ongoing to add support for other models in FLAP, such as Phi, Gemma, or Mistral?

If so, what is the expected timeline?

AttributeError: Can't pickle local object 'add_hook_to_module.<locals>.new_forward'

Traceback (most recent call last):
File "/home/shwu/LABS/FLAP/main.py", line 147, in
main()
File "/home/shwu/LABS/FLAP/main.py", line 141, in main
torch.save(model, f'{args.save_model}/pruned_model.pt')
File "/home/shwu/LABS/FLAP/venv/lib/python3.10/site-packages/torch/serialization.py", line 629, in save
_save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
File "/home/shwu/LABS/FLAP/venv/lib/python3.10/site-packages/torch/serialization.py", line 841, in _save
pickler.dump(obj)
AttributeError: Can't pickle local object 'add_hook_to_module.<locals>.new_forward'


This occurs after pruning on multiple GPUs and saving with torch.save.
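
A possible workaround (an assumption, not the authors' fix): the multi-GPU device_map path attaches accelerate hooks whose new_forward closure cannot be pickled, so strip those hooks (or save only the state_dict) before calling torch.save:

import torch
from accelerate.hooks import remove_hook_from_submodules

# Remove the accelerate dispatch hooks that wrap each module's forward with a
# local closure pickle cannot serialize, then gather weights on one device.
remove_hook_from_submodules(model)
model = model.cpu()
torch.save(model, f"{args.save_model}/pruned_model.pt")

# Alternative: saving only the weights avoids pickling module objects.
# torch.save(model.state_dict(), f"{args.save_model}/pruned_model_state_dict.pt")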
