
Comments (5)

pevnak commented on July 20, 2024

I do not understand this fully yet, but I finally made it through the attention, such that it now gives the same results as the one in PyTorch. I will create a PR with bits of what I have done. But I think it would require adding the right version of the rotary embedding.

I do not know if my code is worth a PR. It is a mess.

I have it here
https://github.com/pevnak/Transformers.jl#tp/phi

The important part is in test/debug, where I execute the first layer of phi in Python and in Julia to see how to code the self-attention. I have prepared a bit of implementation/phi, but it still lacks the construction with the correct self-attention and wiring.


chengchingwen commented on July 20, 2024

Sure, I'll be glad to help. Can you elaborate more on the problem? I'm not sure what the question is.


pevnak commented on July 20, 2024

Hi Peter,

Thanks a lot for the answer. I made some progress, but I got stuck because I cannot reproduce the effect of self-attention with the rotary embedding. I have written code in Julia and Python to compare the results side by side.

The Python code is as follows:

import torch
import transformers
import math
torch.manual_seed(0)

def rotate_half(x):
   """Rotates half the hidden dims of the input."""
   x1 = x[..., : x.shape[-1] // 2]
   x2 = x[..., x.shape[-1] // 2 :]
   return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
   cos = cos[position_ids].unsqueeze(unsqueeze_dim)
   sin = sin[position_ids].unsqueeze(unsqueeze_dim)
   q_embed = (q * cos) + (rotate_half(q) * sin)
   k_embed = (k * cos) + (rotate_half(k) * sin)
   return q_embed, k_embed


# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi

def repeat_kv(hidden_states, n_rep):
   """
   This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
   num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
   """
   batch, num_key_value_heads, slen, head_dim = hidden_states.shape
   if n_rep == 1:
      return hidden_states
   hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
   return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


model = transformers.AutoModelForCausalLM.from_pretrained('microsoft/phi-1', torch_dtype=torch.float32, trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained('microsoft/phi-1', trust_remote_code=True)

s = 'Tell me something about Julia?'
inputs = tokenizer(s, return_tensors='pt', return_attention_mask=False)
torch.save(inputs, '/tmp/inputs.torch')
e = model.model.embed_tokens(inputs.input_ids)
l = model.model.layers[0]
torch.save(e, '/tmp/embedding.torch')
hidden_states = l.input_layernorm(e)
torch.save(hidden_states, '/tmp/hidden_states.torch')



attn_outputs, self_attn_weights, present_key_value = l.self_attn(hidden_states)
torch.save(attn_outputs, '/tmp/attn_outputs.torch')

sa = l.self_attn
query_states = sa.q_proj(hidden_states)
torch.save(query_states, '/tmp/query_states.torch')
key_states = sa.k_proj(hidden_states)
torch.save(key_states, '/tmp/key_states.torch')
value_states = sa.v_proj(hidden_states)
torch.save(value_states, '/tmp/value_states.torch')


bsz, q_len, _ = hidden_states.size()
query_states = query_states.view(bsz, q_len, sa.num_heads, sa.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, sa.num_key_value_heads, sa.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, sa.num_key_value_heads, sa.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
cos, sin = sa.rotary_emb(value_states, seq_len=kv_seq_len)
torch.save(cos, '/tmp/cos.torch')
torch.save(sin, '/tmp/sin.torch')

# Partial rotary embedding
query_rot, query_pass = (
   query_states[..., : sa.rotary_emb.dim],
   query_states[..., sa.rotary_emb.dim :],
)

torch.save(query_rot, '/tmp/query_rot.torch')
torch.save(query_pass, '/tmp/query_pass.torch')

key_rot, key_pass = (
   key_states[..., : sa.rotary_emb.dim],
   key_states[..., sa.rotary_emb.dim :],
)

torch.save(key_rot, '/tmp/key_rot.torch')
torch.save(key_pass, '/tmp/key_pass.torch')

# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot_pos, key_rot_pos = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, None)
torch.save(query_rot_pos, '/tmp/query_rot_pos.torch')
torch.save(key_rot_pos, '/tmp/key_rot_pos.torch')


# [batch_size, seq_length, num_heads, head_dim]
query_rot_states = torch.cat((query_rot_pos, query_pass), dim=-1)
key_rot_states = torch.cat((key_rot_pos, key_pass), dim=-1)

torch.save(query_rot_states, '/tmp/query_rot_states.torch')
torch.save(key_rot_states, '/tmp/key_rot_states.torch')


key_states = repeat_kv(key_states, sa.num_key_value_groups)
value_states = repeat_kv(value_states, sa.num_key_value_groups)

# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
key_rot_states.size()
query_rot_states.size()
attn_weights = torch.matmul(
   query_rot_states.to(torch.float32), key_rot_states.to(torch.float32).transpose(2, 3)
)/ math.sqrt(sa.head_dim)
attn_weights.size()

torch.save(attn_weights, '/tmp/attn_weights.torch')

# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
torch.save(attn_weights, '/tmp/attn_probs.torch')  # saved under a new name so the pre-softmax scores above are kept
attn_weights = torch.nn.functional.dropout(attn_weights, p=sa.attention_dropout, training=sa.training)

attn_output = torch.matmul(attn_weights, value_states)
torch.save(attn_output, '/tmp/attn_output.torch')

After every operation I save the result, so that I can load it into Julia and compare.
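Just to illustrate the round trip (a minimal sketch, reusing the hidden states saved above): Pickle.Torch.THload keeps PyTorch's dimension order, so the leading batch dimension has to be dropped and the matrix permuted into the feature-by-length layout used on the Julia side.

using Pickle

h = Pickle.Torch.THload("/tmp/hidden_states.torch")  # same (1, seq_len, hidden) order as in PyTorch
hj = permutedims(h[1, :, :])                         # (hidden, seq_len), the layout used below

The corresponding Julia code is: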

using Transformers
using Transformers.Flux
using Transformers.HuggingFace
using Transformers.HuggingFace: HGFPhiPreTrainedModel, HGFPhiForCausalLM,  HGFLlamaPreTrainedModel, SelfAttention
using Transformers.HuggingFace: joinname, load_model
using Transformers.Layers: apply_on_namedtuple
using Transformers.HuggingFace: weighted_sum_mixing, gptneox_rope_multihead_qkv_attention, gptneox_rope_attention_score, generic_multihead_qkv_attention, gptneox_reorder
import Transformers.HuggingFace: one_init, zero_init, getweight
using Transformers.Layers: LayerNorm
using NeuralAttentionlib: as_collapsed, _split_and_move_head,  generic_qkv_attention, mixing, attention_score, split_head, naive_qkv_attention, naive_attention_score, scaled_dot_product_score
using TextEncodeBase
using Statistics
using StatsBase
using Pickle
using Flux
using NNlib

function load_torch_matrix(filename)
	x = Pickle.Torch.THload(filename)
	x = Matrix(transpose(x[1,:,:]))
end

"""

 x: [bs, num_attention_heads, seq_len, head_size]
"""
function compare_tensors(x, filename)
	r = Pickle.Torch.THload(filename)
	size(r,1) !=1 && error("the first dimension should be one (one sample)")
	r = r[1,:,:,:]
	size(x,1) == size(r,3) || error("dimension mismatch")
	size(x,3) == size(r,2) || error("dimension mismatch")
	size(x,2) == size(r,1) || error("dimension mismatch")
	maximum(maximum(abs.(r[:,i,:] .- transpose(x[:,:,i]))) for i in 1:size(x,3))
end

# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
function  phi_rotary_embedding(dim, max_position_embeddings=2048, base=10000)
    inv_freq = 1 ./ (base .^ (collect(0:2:(dim-1)) ./ dim))
    inv_freq = vcat(inv_freq,inv_freq)
    t = 0:(max_position_embeddings-1)
    (sin.(inv_freq .* t'), cos.(inv_freq .* t'))
end


"""
	rotate_half(x)

	Rotates half the hidden dims of the input
"""
function rotate_half(x)
	d = size(x,1) ÷ 2
	x1 = @view x[1:d,:,:]
	x2 = @view x[d+1:end, :, :]
	cat(-x2, x1, dims = 1)
end

function apply_rotary_pos_emb(q, k, _cos::AbstractMatrix, _sin::AbstractMatrix)
	_sin = reshape(_sin, size(_sin,1), 1, size(_sin,2))
	_cos = reshape(_cos, size(_cos,1), 1, size(_cos,2))
	q_embed = q .* _cos .+ rotate_half(q) .* _sin
	k_embed = k .* _cos .+ rotate_half(k) .* _sin
	return (q_embed, k_embed)
end

model_type = model_name = "microsoft/phi-1"
cfg = Transformers.HuggingFace.load_config(model_type)
state_dict = Transformers.HuggingFace.load_state_dict(model_type; config=cfg)
state_dict = Dict(filter((kv) -> contains(kv[1], "model.layers.0."), collect(state_dict)))

# textenc = Transformers.HuggingFace.load_tokenizer(model_name)
# model = Transformers.HuggingFace.load_model(Transformers.HuggingFace.HGFPhiForCausalLM, cfg, state_dict, "")

s = "Tell me something about Julia?"

# input = encode(textenc, s).token 
# input = OneHotArray(OneHot{0x0000c477}.([ 24447, 503, 1224, 547, 22301, 31]))
# input_ref = Pickle.Torch.THload("/tmp/inputs.torch")["input_ids"]
# e = model.model.embed(input) # verify embedding
e_ref = load_torch_matrix("/tmp/embedding.torch")
e = e_ref

lprefix = "model.layers.0"
# residual = e.hidden_state
residual = e
ln = load_model(HGFPhiPreTrainedModel, Layers.LayerNorm, cfg, state_dict, joinname(lprefix, "input_layernorm"))
hidden_state = ln(residual)
hidden_state_ref = load_torch_matrix("/tmp/hidden_states.torch")


# This is where we want to do the self-attention, but it does not work yet, so
# we need to learn how to use it.
sa = load_model(HGFPhiForCausalLM, SelfAttention, cfg, state_dict, joinname(lprefix, "self_attn"))
# attn_outputs = sa((;hidden_state = hidden_state_ref)).hidden_state
# attn_outputs .- load_torch_matrix("/tmp/attn_outputs.torch")

nt = (;hidden_state = hidden_state_ref)
qkv = apply_on_namedtuple(sa.qkv_proj, nt)
maximum(abs.(qkv.hidden_state[1] .- load_torch_matrix("/tmp/query_states.torch")))
maximum(abs.(qkv.hidden_state[2] .- load_torch_matrix("/tmp/key_states.torch")))
maximum(abs.(qkv.hidden_state[3] .- load_torch_matrix("/tmp/value_states.torch")))

# this part traces the computation of the attention score step by step
base, dim, head = 10000.0, 64, 32
hidden_size = 32
len = 6
_sincos = phi_rotary_embedding(32)
_sin = _sincos[1][:,1:len]
_cos = _sincos[2][:,1:len]
maximum(abs.(_sin .- Pickle.Torch.THload("/tmp/sin.torch")'))
maximum(abs.(_cos .- Pickle.Torch.THload("/tmp/cos.torch")'))

q,k,v = qkv.hidden_state
query_states = _split_and_move_head(head, q)
key_states = _split_and_move_head(head, k)
hv = _split_and_move_head(head, v)

query_rot, query_pass = (
   query_states[1:32,:, :],	# sa.rotary_emb.dim = 32
   query_states[33:end, :, :],
)

compare_tensors(query_rot, "/tmp/query_rot.torch")
compare_tensors(query_pass, "/tmp/query_pass.torch")

key_rot, key_pass = (
   key_states[1:32,:, :],	# sa.rotary_emb.dim = 32
   key_states[33:end, :, :],
)

compare_tensors(key_rot, "/tmp/key_rot.torch")
compare_tensors(key_pass, "/tmp/key_pass.torch")


query_rot_pos, key_rot_pos = apply_rotary_pos_emb(query_rot, key_rot, _cos, _sin)

compare_tensors(query_rot_pos, "/tmp/query_rot_pos.torch")
compare_tensors(key_rot_pos, "/tmp/key_rot_pos.torch")

query_rot_states = cat(query_rot_pos, query_pass, dims=1)
key_rot_states = cat(key_rot_pos, key_pass, dims=1)

compare_tensors(query_rot_states, "/tmp/query_rot_states.torch")
compare_tensors(key_rot_states, "/tmp/key_rot_states.torch")


# attn_weights = attention_score(naive_attention_score(), as_collapsed(query_rot_states), as_collapsed(key_rot_states))

attn_weights = scaled_dot_product_score(query_rot_states, key_rot_states);
compare_tensors(attn_weights, "/tmp/attn_weights.torch")

It is quite a lot of code, but I had to write it to peel away the individual layers of NeuralAttentionlib and better understand what is going on. All the intermediate results are similar to the Python version, except for attn_weights, where I get completely different dimensions, as I discussed on the Discourse forum. I do not know whether it is because NeuralAttentionlib does the self-attention very differently, or whether I need to permute dimensions. I have tried a naive permutation of dimensions and it did not work out.

Any help would be appreciated.


chengchingwen commented on July 20, 2024

This:

q,k,v = qkv.hidden_state
query_states = _split_and_move_head(head, q)
key_states = _split_and_move_head(head, k)
hv = _split_and_move_head(head, v)

is mistaking the length dimension for the batch dimension: the real batch size is omitted and the semantics of each dimension are not specified, which gives the wrong results.

It should either be:

# usually preferable
query_states = _split_and_move_head(head, as_collapsed(q))
key_states = _split_and_move_head(head, as_collapsed(k))
hv = _split_and_move_head(head, as_collapsed(v))

or:

query_states = _split_and_move_head(head, reshape(q, Val(3)))
key_states = _split_and_move_head(head, reshape(k, Val(3)))
hv = _split_and_move_head(head, reshape(v, Val(3)))

It's worth noting that in most (probably all) Python implementations, the number of tensor dimensions is fixed. NeuralAttentionlib, on the other hand, adds an abstraction layer (CollapsedDimsArray) on top of those tensors. The attention "algorithm" requires the tensor to have 3 dimensions: the feature dimension, the length dimension, and the batch dimension. CollapsedDimsArray groups the dimensions of the tensor into these 3. The whole attention interface in NeuralAttentionlib is built on top of this abstraction layer.
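For intuition, a minimal sketch of how as_collapsed assigns the three semantic dimensions (the sizes here are made up, not taken from the model):

using NeuralAttentionlib: as_collapsed

q = randn(Float32, 64, 6)   # (feature, length); the batch dimension is omitted
cq = as_collapsed(q)        # a CollapsedDimsArray viewed as (feature, length, batch)
size(cq)                    # (64, 6, 1): the length dimension is no longer taken for the batch

With a CollapsedDimsArray like this, _split_and_move_head knows which dimension is which, so the head splitting gives the shapes the rest of the attention functions expect.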


chengchingwen commented on July 20, 2024

I added it in #168. Let me know once you have tested it. If we don't find any problems, I'll merge it.

