
Comments (10)

wanzhenchn commented on July 17, 2024

I have verified that the FUSED MODEL caused this problem!

(screenshot attached in the original comment)

Further investigation revealed that the issue is triggered by fuse_attention() in awq/models/llama.py

    def fuse_attention(self):
        for name, module in self.attention_modules:
            # Merge the separate q/k/v WQLinear projections into a single fused layer
            qkv_layer: WQLinear = self._fuse_qkv(module)
            # Replace the original attention module with the fused implementation
            attn = QuantLlamaAttention(
                module.hidden_size,
                module.num_heads,
                qkv_layer,
                module.o_proj,
                qkv_layer.qweight.device,
                self.model.config.max_new_tokens
            )
            set_module_name(self.model, name, attn)
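
For context, the _fuse_qkv() call above conceptually concatenates the three projections along the output dimension so that a single GEMM produces [Q | K | V]. Below is a minimal float-precision sketch of that idea using plain nn.Linear; the quantized case additionally has to concatenate qweight, qzeros and scales in their packed layout, which is exactly where a layout mistake would corrupt the attention output.

import torch
import torch.nn as nn

def fuse_qkv_float(q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear) -> nn.Linear:
    """Float sketch of QKV fusion: one Linear whose output is [Q | K | V]."""
    out_features = q_proj.out_features + k_proj.out_features + v_proj.out_features
    fused = nn.Linear(q_proj.in_features, out_features, bias=False)
    with torch.no_grad():
        # nn.Linear stores weight as (out_features, in_features), so stack along dim 0
        fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    return fused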


wanzhenchn commented on July 17, 2024

The problem also occurred in my case.

model: llama-13b, awq_model_w4_g128.pt

ids = torch.randint(0, tokenizer.vocab_size, (2, n_context)).cuda()
out = model(ids)
out.logits

I have verified that the original AWQ does NOT have this problem. What's the difference after your refactor?


casper-hansen commented on July 17, 2024

> The problem also occurred in my case.
>
> model: llama-13b, awq_model_w4_g128.pt
>
> ids = torch.randint(0, tokenizer.vocab_size, (2, n_context)).cuda()
> out = model(ids)
> out.logits
>
> I have verified that the original AWQ does NOT have this problem. What's the difference after your refactor?

In practice, there should be minimal difference in the actual quantization code. I will have to investigate this bug further, as it blocks quite a few things in terms of progress.

@wanzhenchn I would love any and all help in terms of debugging/reporting when this happens.
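
One general CUDA debugging step that can help pinpoint where the illegal memory access is actually raised (not specific to AutoAWQ) is to force synchronous kernel launches, so the Python traceback points at the offending kernel rather than at a later synchronization point:

import os

# Must be set before CUDA is initialized, i.e. before the first torch.cuda call
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
# ... then load the fused model and run the batch-size-2 forward pass as above;
# with synchronous launches, the "illegal memory access" error surfaces at the
# kernel that triggered it.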


casper-hansen commented on July 17, 2024

@wanzhenchn I have now tested the original llm-awq and it seems the bug is also present there.

import time
import torch
from awq.utils.utils import simple_dispatch_model
from awq.quantize.quantizer import real_quantize_model_weight
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_in_model

def fuse(model, device="cuda:0"):
    from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
    make_quant_attn(model, device)
    make_quant_norm(model)
    make_fused_mlp(model)

def load(model_path="vicuna-7b-v1.5-awq", checkpoint="awq_model_w4_g128.pt"):
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
    
    real_quantize_model_weight(
        model, w_bit=4, q_config={"zero_point": True, "q_group_size": 128}, init_only=True
    )
    
    model.tie_weights()
    
    # Infer device map
    kwargs = {}
    device_map = infer_auto_device_map(
        model,
        no_split_module_classes=[
            "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
        **kwargs
    )
    # Load checkpoint in the model
    load_checkpoint_in_model(
        model,
        checkpoint=f"{model_path}/{checkpoint}",
        device_map=device_map,
        offload_state_dict=True,
    )
    
    # Fuse model
    fuse(model)

    # Dispatch model
    model = simple_dispatch_model(model, device_map=device_map)

    model.eval()

    return model, tokenizer

@torch.inference_mode()
def run_speed(device="cuda:0", n_generate=128, max_new_tokens=256):
    def _timer(func):
        start = time.time()
        out = func()
        return out, time.time() - start
    
    def _generate(model, model_out, n_generate):
        past_key_values = model_out.past_key_values

        for i in range(n_generate):
            # Sample the next token from the last position of the first sequence
            logits = model_out.logits[0, -1, :]
            probs = torch.softmax(logits, dim=-1)
            token = torch.multinomial(probs, num_samples=1).unsqueeze(0)  # shape (1, 1)

            model_out = model(token, use_cache=True, past_key_values=past_key_values)
            past_key_values = model_out.past_key_values  # keep the growing KV cache

    def _warmup(device: str):
        warm_up = torch.randn((4096, 4096)).to(device)
        torch.mm(warm_up, warm_up)
    
    # Load model
    model, tokenizer = load()
    _warmup(device)

    # Generate random inputs
    n_context = max_new_tokens - n_generate
    ids = torch.randint(0, tokenizer.vocab_size, (2, n_context)).cuda()

    # Context stage
    model_out, context_time = _timer(lambda: model(ids, use_cache=True))

    # Generation stage
    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate))
    
    # Prints
    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
    context_tokens_per_second = n_context / context_time
    context_ms_per_token = (context_time*1000) / n_context
    inference_tokens_per_second = n_generate / generation_time
    inference_ms_per_token = (generation_time*1000) / n_generate

    print(f"[======] Model summary [======]")
    print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
    print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
    print(f"[*] VRAM: {memory_used:.2f} MB")

if __name__ == '__main__':
    run_speed()


wanzhenchn commented on July 17, 2024

> @wanzhenchn I have now tested the original llm-awq and it seems the bug is also present there.


The API I tested is llm-awq/awq; the StreamGenerator() from tinychat/stream_generators/stream_gen.py was renamed benchmark() and merged into awq/entry.py. The whole code is here: mit-han-lab/llm-awq#79 (comment).

The difference lies in whether the model has been fused, so I suspect that the fused model might be causing this issue.
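
If fusion really is the difference, a quick A/B test is to load the same checkpoint twice, once with and once without layer fusion, and compare logits on an identical batch. A rough sketch along those lines; the fuse_layers argument and the calling convention reflect my understanding of AutoAWQ's loading API and should be treated as assumptions:

import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "vicuna-7b-v1.5-awq"  # same checkpoint layout as the repro above

tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
ids = torch.randint(0, tokenizer.vocab_size, (2, 128)).cuda()  # batch size 2, which triggers the bug

# Load twice: with and without layer fusion (argument name assumed)
fused = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
plain = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)

with torch.inference_mode():
    diff = (fused(ids).logits - plain(ids).logits).abs().max().item()
print(f"max |logit| difference between fused and unfused: {diff:.4f}")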


casper-hansen commented on July 17, 2024

> Further investigation revealed that the issue is triggered by fuse_attention() in awq/models/llama.py

Thank you for investigating. There must be an issue somewhere when fusing the linear layers in LlamaAttention. I will have to investigate further, maybe rebuild it from scratch.
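
One way to localize a fusion bug like this is to check, per attention block, that the fused projection still reproduces the original q/k/v projections on the same input. A rough sketch of such a check; the fused_qkv handle and the use of out_features for the split sizes are illustrative assumptions:

import torch

@torch.inference_mode()
def check_fused_qkv(fused_qkv, q_proj, k_proj, v_proj, hidden_size, device="cuda:0"):
    # Random batch-size-2 input on purpose, since batch size > 1 is what breaks
    x = torch.randn(2, 16, hidden_size, dtype=torch.float16, device=device)
    q_out, k_out, v_out = fused_qkv(x).split(
        [q_proj.out_features, k_proj.out_features, v_proj.out_features], dim=-1
    )
    for name, got, ref in (("q", q_out, q_proj(x)), ("k", k_out, k_proj(x)), ("v", v_out, v_proj(x))):
        print(f"{name}_proj max abs error: {(got - ref).abs().max().item():.6f}")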


wanzhenchn commented on July 17, 2024

> Further investigation revealed that the issue is triggered by fuse_attention() in awq/models/llama.py
>
> Thank you for investigating. There must be an issue somewhere when fusing the linear layers in LlamaAttention. I will have to investigate further, maybe rebuild it from scratch.

Yeah, the problem seems to have been fixed in mit-han-lab/llm-awq#8; however, this could be another issue.


casper-hansen commented on July 17, 2024

> Yeah, the problem seems to have been fixed in mit-han-lab/llm-awq#8; however, this could be another issue.

This seems to be another issue, and this commit is already included in AutoAWQ. I have tested more and implemented #26 as a start, but I still need more testing to figure out and solve the issue.


wanzhenchn commented on July 17, 2024

> Yeah, the problem seems to have been fixed in mit-han-lab/llm-awq#8; however, this could be another issue.
>
> This seems to be another issue, and this commit is already included in AutoAWQ. I have tested more and implemented #26 as a start, but I still need more testing to figure out and solve the issue.

I have tested and found that if the quantized model fusion is disabled, its throughput performance is nearly the same as FP16.


casper-hansen commented on July 17, 2024

> I have tested and found that if the quantized model fusion is disabled, its throughput performance is nearly the same as FP16.

Yes, default LLaMa models are slow in general because of the many extra linear layers, unlike other architectures such as Falcon or MPT. We fuse the QKV operation, which works fine, and we use rotary embeddings from vLLM (which is causing this illegal memory access). I am working on fixing this issue to enable multi-GPU and batch size > 1 in #28.
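
For reference, here is roughly what a position-indexed rotary-embedding step computes, written in plain PyTorch rather than as vLLM's fused kernel; the comment marks the cache lookup where a kernel without bounds checks could produce exactly this kind of illegal memory access if the position or batch bookkeeping is off for batch size > 1 (the cache-indexing failure mode is my assumption, not a confirmed diagnosis).

import torch

def apply_rope_reference(q, k, positions, cos_sin_cache):
    # q, k:          (num_tokens, num_heads, head_dim) with batch and sequence flattened
    # positions:     (num_tokens,) absolute position of every token
    # cos_sin_cache: (max_positions, head_dim) with cos in the first half, sin in the second
    #
    # A fused CUDA kernel does this gather without bounds checks; if any entry of
    # `positions` is >= max_positions (e.g. the cache was sized for one sequence but
    # the batch packs more tokens), the read is out of bounds.
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)   # each (num_tokens, head_dim // 2)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)          # broadcast over heads

    def rotate(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    return rotate(q), rotate(k)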

