Comments (9)

lvhan028 commented on August 16, 2024

"Not only is it inconsistent with the original model's output": lmdeploy's kernels differ from transformers' kernels, so it is normal for the output logits to differ.
"The first half of the logits, when decoded, is not the prompt's input_ids either":
We'll investigate this.
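
A small illustration of why two correct kernels can produce different logits: floating-point addition is not associative, so changing the order in which terms are accumulated can change the result. This is a generic sketch in plain Python, not lmdeploy's or transformers' actual kernel code, and `sequential_sum` is a hypothetical helper name.

```python
def sequential_sum(values):
    """Add the values left to right with ordinary float addition."""
    total = 0.0
    for v in values:
        total += v
    return total


# The same three numbers, accumulated in two different orders.
terms = [1e16, 1.0, -1e16]
reordered = [1e16, -1e16, 1.0]

print(sequential_sum(terms))      # 0.0: the 1.0 is lost when added to 1e16
print(sequential_sum(reordered))  # 1.0: the large terms cancel first
```

Real kernels differ in tiling, reduction order, and accumulation precision, so small element-wise differences in the logits are expected even when both implementations are correct.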

from lmdeploy.

lvhan028 commented on August 16, 2024

We will look into this issue after the Dragon Boat Festival holiday.

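irexyc commented on August 16, 2024

"The first half of the logits, when decoded, is not the prompt's input_ids either."

I don't quite understand your question. Judging from your code, you construct the input_ids yourself and then look at the probabilities of certain ids. Is the result not what you expected?


Also, I compared the logits from turbomind and transformers and did not see a large difference. Here is the comparison code: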

from transformers import AutoModel, AutoTokenizer
import torch

tok = AutoTokenizer.from_pretrained('/mnt/140/InternLM/internlm2-chat-1_8b', trust_remote_code=True)
ids = tok.encode('请根据给定的文本信息,判断是否存在的违规行为。回复:')
# [1, 60836, 68420, 60562, 68942, 69684, 68347, 60353, 69844, 68553, 72279, 77302, 68891, 60355, 70396, 334]


# transformers
m = AutoModel.from_pretrained('/mnt/140/InternLM/internlm2-chat-1_8b', trust_remote_code=True).half().eval().cuda()
input_ids = torch.LongTensor(ids).reshape(1, -1)
with torch.inference_mode():
  out = m(input_ids.cuda())
  logits1 = out.logits.cpu()


# turbomind
from lmdeploy import pipeline
pipe = pipeline('/mnt/140/InternLM/internlm2-chat-1_8b', log_level='INFO')
g = pipe.engine.create_instance()
tok = pipe.tokenizer
inputs = '请根据给定的文本信息,判断是否存在的违规行为。回复:'
input_ids = tok.encode(inputs) 
logits2 = g.decode(input_ids).cpu()

# diff
torch.max(torch.abs(logits1 - logits2))
# tensor(0.0301)
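
A single maximum absolute difference says little about whether the two engines would pick the same tokens. As a hedged sketch, the comparison could also report how often the top-scoring token agrees per position. The helper names `max_abs_diff` and `argmax_agreement` are hypothetical, and the toy lists below stand in for something like `logits1[0].tolist()` and `logits2[0].tolist()`, assuming both have shape [seq_len][vocab_size].

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two [seq_len][vocab] lists."""
    return max(
        abs(x - y)
        for row_a, row_b in zip(a, b)
        for x, y in zip(row_a, row_b)
    )


def argmax_agreement(a, b):
    """Fraction of positions where both logit rows pick the same top token."""
    def argmax(row):
        return max(range(len(row)), key=row.__getitem__)

    same = sum(argmax(row_a) == argmax(row_b) for row_a, row_b in zip(a, b))
    return same / len(a)


# Toy data (made up for illustration): 2 positions, vocabulary of 3 tokens.
ref = [[2.0, 0.5, -1.0], [0.1, 3.0, 0.2]]
other = [[2.02, 0.49, -1.01], [0.12, 2.97, 0.21]]

print(max_abs_diff(ref, other))
print(argmax_agreement(ref, other))
```

If the agreement stays at or near 1.0 while the maximum difference is small relative to the logit magnitudes, the two outputs are likely equivalent for decoding purposes.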

dafu-wu commented on August 16, 2024

@irexyc Hi, I used your code to compare Llama 2 and got this error:
AttributeError: 'BaseModelOutputWithPast' object has no attribute 'logits'

lvhan028 commented on August 16, 2024

Are you using the PyTorch engine? Could you provide a reproducible demo?

dafu-wu commented on August 16, 2024

@lvhan028 Sure.

lmdeploy==0.5.1
transformers==4.42.4

from transformers import AutoModel, AutoTokenizer
import torch
import os
from transformers import LlamaForCausalLM, LlamaTokenizer
model_path = "NousResearch/Llama-2-7b-hf"

# Load the tokenizer
tokenizer = LlamaTokenizer.from_pretrained(model_path)

# Load the model
model = LlamaForCausalLM.from_pretrained(model_path)

# Input text
input_text = "请根据给定的文本信息,判断是否存在的违规行为。回复:"
input_ids = tokenizer(input_text, return_tensors='pt').input_ids

# Get the logits
with torch.no_grad():
    outputs = model(input_ids)
logits1 = outputs.logits

del model
torch.cuda.empty_cache()


# turbomind
from lmdeploy import pipeline
pipe = pipeline(model_path, log_level='INFO')
g = pipe.engine.create_instance()
tok = pipe.tokenizer
inputs = '请根据给定的文本信息,判断是否存在的违规行为。回复:'
input_ids = tok.encode(inputs) 
logits2 = g.decode(input_ids).cpu()
print(logits2.shape)
# diff
rc = torch.max(torch.abs(logits1 - logits2))
print(rc)
#tensor(0.4661)

dafu-wu commented on August 16, 2024

@lvhan028 Any update?

lvhan028 commented on August 16, 2024

"@irexyc Hi, I used your code to compare Llama 2 and got this error: AttributeError: 'BaseModelOutputWithPast' object has no attribute 'logits'"

@irexyc This error comes from the following code. I am using transformers 4.41.1.

m = AutoModel.from_pretrained('/workspace/models-140/llama2/huggingface/llama-2-7b-chat', trust_remote_code=True).half().eval().cuda()
input_ids = torch.LongTensor(ids).reshape(1, -1)
with torch.inference_mode():
  out = m(input_ids.cuda())
  logits1 = out.logits.cpu()
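
For a Llama checkpoint, `AutoModel` resolves to the bare `LlamaModel`, whose output carries `last_hidden_state` but no `logits`; the language-modeling head is only attached by classes such as `AutoModelForCausalLM` or `LlamaForCausalLM`. A minimal sketch of a guard that turns the bare AttributeError into a more explicit message follows. `get_logits` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins for real model outputs so that the snippet runs without downloading a model.

```python
from types import SimpleNamespace


def get_logits(output):
    """Return output.logits, or raise an error that explains the likely cause."""
    logits = getattr(output, "logits", None)
    if logits is None:
        raise TypeError(
            f"{type(output).__name__} has no logits; the model was probably "
            "loaded without a language-modeling head (for example via AutoModel "
            "instead of AutoModelForCausalLM or LlamaForCausalLM)."
        )
    return logits


# Stand-ins for real model outputs (hypothetical, for illustration only).
base_output = SimpleNamespace(last_hidden_state=[[0.1, 0.2]])
causal_output = SimpleNamespace(logits=[[1.0, 2.0]])

print(get_logits(causal_output))
try:
    get_logits(base_output)
except TypeError as err:
    print(err)
```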

dafu-wu commented on August 16, 2024

@lvhan028 I replaced it with LlamaForCausalLM.from_pretrained(model_path) and it runs normally now, but the logits difference, tensor(0.4661), is very large.
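
One possible contributor, which the thread does not confirm: `LlamaForCausalLM.from_pretrained(model_path)` without a `torch_dtype` argument loads the weights in float32, while the earlier InternLM comparison called `.half()` on the reference model. Comparing a float32 reference against a half-precision engine would be expected to widen the gap. The toy sketch below only shows how half-precision rounding can shift a long dot product; it is deliberately pessimistic (it also accumulates in half precision, which real kernels often avoid) and does not measure lmdeploy. The helper names are hypothetical.

```python
import struct


def to_half(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('e', struct.pack('e', x))[0]


def dot_fp64(a, b):
    """Dot product with ordinary double-precision arithmetic."""
    total = 0.0
    for x, y in zip(a, b):
        total += x * y
    return total


def dot_fp16(a, b):
    """Dot product where inputs, products, and the running sum are rounded to half."""
    total = 0.0
    for x, y in zip(a, b):
        total = to_half(total + to_half(to_half(x) * to_half(y)))
    return total


# Made-up data: 1000 identical weight/activation pairs.
weights = [0.1] * 1000
activations = [0.3] * 1000

exact = dot_fp64(weights, activations)
approx = dot_fp16(weights, activations)
print(exact, approx, abs(exact - approx))
```

To make the comparison like for like, the reference model would need to run in the same dtype as the engine, for example by loading it with `torch_dtype=torch.float16`.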
