Giter Club home page Giter Club logo

how-to-use-transformers's Issues

question about attention mask

关于第九章这一部分:
为了简化数据处理,这里我们并没有将 [CLS]、[SEP]、[PAD] 等特殊 token 对应的标签设为 -100,而是维持原始的 0 值,然后在计算损失时借助 Attention Mask 来排除填充位置。

attention mask对于cls的位置是1。“active_loss = attention_mask.view(-1) == 1”会包括cls。是否需要mask掉?

多卡训练

你好,在摘要提取中,进行多卡训练的时候,出现这样的问题AttributeError: 'DataParallel' object has no attribute 'prepare_decoder_input_ids_from_labels'是模型本身不能进行多卡训练吗?

摘要评估用的Rouge,是不是有点问题?每次只取一个样本计算Rouge的值。

`import numpy as np
from rouge import Rouge

rouge = Rouge()

def test_loop(dataloader, model):
preds, labels = [], []
model.eval()
for batch_data in tqdm(dataloader):
batch_data = batch_data.to(device)
with torch.no_grad():
generated_tokens = model.generate(
batch_data["input_ids"],
attention_mask=batch_data["attention_mask"],
max_length=max_target_length,
num_beams=4,
no_repeat_ngram_size=2,
).cpu().numpy()
if isinstance(generated_tokens, tuple):
generated_tokens = generated_tokens[0]
label_tokens = batch_data["labels"].cpu().numpy()

    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

    preds += [' '.join(pred.strip()) for pred in decoded_preds]
    labels += [' '.join(label.strip()) for label in decoded_labels]
scores = rouge.get_scores(hyps=preds, refs=labels)[0]
result = {key: value['f'] * 100 for key, value in scores.items()}
result['avg'] = np.mean(list(result.values()))
print(f"Rouge1: {result['rouge-1']:>0.2f} Rouge2: {result['rouge-2']:>0.2f} RougeL: {result['rouge-l']:>0.2f}\n")
return result`

2,3章疑问

第2章
image
输入这里应该是input embeddings吧?

第3章
image

公式5的求和上标是不是应该为n?

关于第三章: 注意力机制 实现的问题

在下面的代码中, 我觉得应该表明为什么 Q, K, V 向量序列是等于 inputs_embeds 的, 我理解的是注意力机制中的 QKV 是 embedding 与 W_Q 和 W_K , W_V 这三个矩阵相乘得到的, 这三个矩阵也是超参数, 而下面的代码是好像默认 这三个矩阵是单位矩阵.
`import torch
from math import sqrt

Q = K = V = inputs_embeds
dim_k = K.size(-1)
scores = torch.bmm(Q, K.transpose(1,2)) / sqrt(dim_k)
print(scores.size())`

此外 dim_k = K.size(-1) 和下面封装的函数中不一致, 上面的 dim_k = K.size(-1), 而下面的 dim_k = query.size(-1)

`import torch
import torch.nn.functional as F
from math import sqrt

def scaled_dot_product_attention(query, key, value, query_mask=None, key_mask=None, mask=None):
dim_k = query.size(-1)
scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
if query_mask is not None and key_mask is not None:
mask = torch.bmm(query_mask.unsqueeze(-1), key_mask.unsqueeze(1))
if mask is not None:
scores = scores.masked_fill(mask == 0, -float("inf"))
weights = F.softmax(scores, dim=-1)
return torch.bmm(weights, value)`

如何基于transformers库自定义模型?

image
我试了一下您这个写法,确实可以从from_pretrained方法中加载一些已经预训练过的一些模型,但是前提是这个地方的参数名称一定叫做self.bert
如果我把这一行的名称改写为self.bert_model等等其他的名称,就统统无法用from_pretrained方法成功加载权重了。这样是不是太死板了,有什么别的方法解决吗?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.