Giter Club home page Giter Club logo

bert-chinese-text-classification-pytorch's Introduction

Bert-Chinese-Text-Classification-Pytorch

LICENSE

中文文本分类,Bert,ERNIE,基于pytorch,开箱即用。

介绍

模型介绍、数据流动过程:还没写完,写好之后再贴博客地址。 工作忙,懒得写了,类似文章有很多。

机器:一块2080Ti , 训练时间:30分钟。

环境

python 3.7
pytorch 1.1
tqdm
sklearn
tensorboardX
pytorch_pretrained_bert(预训练代码也上传了, 不需要这个库了)

中文数据集

我从THUCNews中抽取了20万条新闻标题,已上传至github,文本长度在20到30之间。一共10个类别,每类2万条。数据以字为单位输入模型。

类别:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐。

数据集划分:

数据集 数据量
训练集 18万
验证集 1万
测试集 1万

更换自己的数据集

  • 按照我数据集的格式来格式化你的中文数据集。

效果

模型 acc 备注
bert 94.83% 单纯的bert
ERNIE 94.61% 说好的中文碾压bert呢
bert_CNN 94.44% bert + CNN
bert_RNN 94.57% bert + RNN
bert_RCNN 94.51% bert + RCNN
bert_DPCNN 94.47% bert + DPCNN

原始的bert效果就很好了,把bert当作embedding层送入其它模型,效果反而降了,之后会尝试长文本的效果对比。

CNN、RNN、DPCNN、RCNN、RNN+Attention、FastText等模型效果,请见我另外一个仓库

预训练语言模型

bert模型放在 bert_pretain目录下,ERNIE模型放在ERNIE_pretrain目录下,每个目录下都是三个文件:

  • pytorch_model.bin
  • bert_config.json
  • vocab.txt

预训练模型下载地址:
bert_Chinese: 模型 https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz
词表 https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt
来自这里
备用:模型的网盘地址:https://pan.baidu.com/s/1qSAD5gwClq7xlgzl_4W3Pw

ERNIE_Chinese: http://image.nghuyong.top/ERNIE.zip
来自这里
备用:网盘地址:https://pan.baidu.com/s/1lEPdDN1-YQJmKEd_g9rLgw

解压后,按照上面说的放在对应目录下,文件名称确认无误即可。

使用说明

下载好预训练模型就可以跑了。

# 训练并测试:
# bert
python run.py --model bert

# bert + 其它
python run.py --model bert_CNN

# ERNIE
python run.py --model ERNIE

参数

模型都在models目录下,超参定义和模型定义在同一文件中。

未完待续

  • 封装预测功能

对应论文

[1] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
[2] ERNIE: Enhanced Representation through Knowledge Integration

bert-chinese-text-classification-pytorch's People

Contributors

649453932 avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

bert-chinese-text-classification-pytorch's Issues

模型为什么保存的ckpt

您好 pytorch 训练保存的模型为什么是.ckpt 不应该 .bin 文件吗, ckpt 不是tensorflow 的模型保存形式吗

预训练代码

请问之后会提高bert chinese pytorch版本的预训练(init_checkpoint=bert.bin)文件么?

ERNIE.py与bert.py

您好,请问ERNIE.py体现ERNIE模型的地方在哪里呢?为什么我看不出来和bert.py有什么大的区别?

RuntimeError: CUDA out of memory.

问一下,为什么batch_size 设置过大或者pad_size设置过大就会出现这个问题呢
比如batch_size设置为128或pad_size设置为200,就会报CUDA内存不够的错误,设置较小就能正常运行...
RuntimeError: CUDA out of memory. Tried to allocate 46.00 MiB (GPU 0; 7.43 GiB total capacity; 6.87 GiB already allocated; 14.94 MiB free; 83.05 MiB cached)

该如何解决???

支持并行训练嘛

支持多GPU并行嘛,我有3张卡,希望把句子长度和batchsize都放大点

单机多GPU运行

我想在单机上使用多个GPU来训练,在run.py文件的最开始加上了
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
并且注释掉了
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
torch.backends.cudnn.deterministic = True
但是还是只使用的GPU:4这一个GPU

这条语句在tensorflow中可以正常运行的,但是在pytroch中为什么会失效了呢?
在线求解~

求大佬帮助一下,项目架构问题

大佬,作为一个新人,对于代码的理解力还不是么强!所以相结合博客来学习一下!比如:项目中的BERT+CNN进行训练,那么此处的CNN是不是TextCNN那一套?BERT最终训练出来的是字向量还是句子向量?而且模型的整体输入流输出流是什么!还希望能帮忙解答一下谢谢

关于pre-trained model中vocabulary的问题

想问下是不是vocabulary中的内容对训练是没有影响的。感觉训练的时候vocabulary是由build_vocab训练出来的。想求证下我的想法是不是正确的~

如何训练英文数据

作者你好,我现在有一个英文数据集想做多分类,数据格式一样,是否只需将bert预训练模型、词表换成英文的就行呢

您好,这个训练好的模型就是那个model下的.ckpt文件吧,这个模型就是生成词向量的模型吗,我如果想训练自己的词向量模型是不是就用自己的预料就行了?还有,这个模型怎么去用呢?希望不要嫌弃我,我刚开始学这个。。。。期待您的信息回复

您好,这个训练好的模型就是那个model下的.ckpt文件吧,这个模型就是生成词向量的模型吗,我如果想训练自己的词向量模型是不是就用自己的预料就行了?还有,这个模型怎么去用呢?希望不要嫌弃我,我刚开始学这个。。。。期待您的信息回复

> > 为什么我每次跑完都要一两天的时间呢?求大佬们帮忙

为什么我每次跑完都要一两天的时间呢?求大佬们帮忙
因为evaluation是在cpu上做的,试一下把evaluation放在gpu上以tensor的形式来做

谢谢大佬的回答。具体改法是不是把train_eval.py中.cpu()的地方改为.gpu()就可以了(gpu环境以及有了)?原文件里只说了训练时间:30分钟,也不知道怎么训练这么快的,哼😕(吐槽一下嘻嘻)

实际上这时候这个张量本身就是在GPU里面了,只要把cpu()去掉即可,并且要把原本允许ndarry作为参数的sklearn.metrics换为能够张量计算的方式。其他的都不用改。下面是我这边修改的这块代码,可以参考一下。
if total_batch % 100 == 0:
# 每多少轮输出在训练集和验证集上的效果
true = (labels).data
predic = torch.max(outputs.data, 1)[1]
total = true.size(0)
correct = (predic == true).sum().item()
train_acc = correct / total
dev_acc, dev_loss = evaluate(config, model, dev_iter)
if dev_loss < dev_best_loss:
dev_best_loss = dev_loss
torch.save(model.state_dict(), config.save_path)
improve = '*'
last_improve = total_batch
else:
improve = ''
time_dif = get_time_dif(start_time)
msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
model.train()

No such file or directory: './bert_pretrain\\bert_config.json'

我感觉是我的在文件夹bert_pretrain中放的东西不对,README中写了要放:
pytorch_model.bin
bert_config.json
vocab.txt
我下载了链接中的bert-base-chiness并直接解压到了bert_pretrain文件夹中,但是没有看到要求中的文件。
(好像下错了,欸嘿

我写了一个predict

image

Connected to pydev debugger (build 181.4668.75)
打印texts:
(tensor([[ 101, 872, 1962, 1435, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[ 101, 872, 1962, 1004, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0]]), tensor([4, 4]), tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0]]))
Traceback (most recent call last):
File "D:\Software\pycharm\PyCharm 2018.1.2\helpers\pydev\pydevd.py", line 1664, in
main()
File "D:\Software\pycharm\PyCharm 2018.1.2\helpers\pydev\pydevd.py", line 1658, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "D:\Software\pycharm\PyCharm 2018.1.2\helpers\pydev\pydevd.py", line 1068, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "D:\Software\pycharm\PyCharm 2018.1.2\helpers\pydev_pydev_imps_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "D:/Program/Python/bert_classification_pytorch/predict_bak.py", line 48, in
label = bert_model.predict(test_demo)
File "D:/Program/Python/bert_classification_pytorch/predict_bak.py", line 37, in predict
outputs = self.model(texts)
File "D:\Program\Python\bert_classification_pytorch\venv\lib\site-packages\torch\nn\modules\module.py", line 477, in call
result = self.forward(*input, **kwargs)
File "D:\Program\Python\bert_classification_pytorch\models\bert.py", line 42, in forward
mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0]
IndexError: tuple index out of range

我这是哪里出错了呢?模型没有动过……
请各位大佬帮忙看下,非常感谢~

关于分类任务的效果

使用ERNIE预训练模型
我发现用thucnews数据集,效果和您的项目的readme是一样的
但是换一个数据集,比如chnsenticorp数据集,无法复现到百度ERNIE报告的效果。百度报告的结果超过94%,但是这里只能得到90。请问您有试验过thucnews之外的数据集么?

utils.py文件有个小bug

DatasetIterater中的__next__(self)函数
将elif self.index > self.n_batches 修改为elif self.index >= self.n_batches
如果self.residue=False且self.index = self.n_batches时应该就迭代完成了,但是按该代码还会再取一次

关于中文的vocab.txt

我从bert_pretrain文件夹中给出的链接中下载的文件解压出的文件里面并不包含vocab.txt文件,我是否需要另外的下载呢?能否给出下载链接?或者是需要自己构建?

bug: DatasetIterater的residue属性错误

`

class DatasetIterater(object):

def __init__(self, batches, batch_size, device):

    self.batch_size = batch_size

    self.batches = batches

    self.n_batches = len(batches) // batch_size

    self.residue = False  # 记录batch数量是否为整数

    if len(batches) % self.n_batches != 0:

        self.residue = True

    self.index = 0

    self.device = device

`
self.residue = True if len(batches) % self.batch_size != 0 else False

len(batches)=58, batch_size = 32, n_batches = 1,此时self.residue应该为True,但是 len(batches) % self.n_batches == 0

ModuleNotFoundError: No module named 'pytorch_pretrained_bert'

python3 run.py --model bert

Traceback (most recent call last):
File "run.py", line 5, in
from train_eval import train, init_network
File "/models/Bert-Chinese-Text-Classification-Pytorch/train_eval.py", line 9, in
from pytorch_pretrained_bert.optimization import BertAdam
ModuleNotFoundError: No module named 'pytorch_pretrained_bert'

train_eval.py依然需要pytorch_pretrained_bert

求求大佬帮助!关于预测的文件

请问预测应该如何写呢😂,想要将保存的最好的参数预测自己的语料,文本情感分析用,毕设急求,非常希望大佬解答!!!

关于长文本

我看到您的代码是直接把长文本截断,但是这样是否会降低performance呢?有没有更好的方案呢?

torch 1.5.0 报错:add被弃用

..\torch\csrc\utils\python_arg_parser.cpp:756: UserWarning: This overload of add_ is deprecated:
        add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
        add_(Tensor other, *, Number alpha)

请问如何解决?

训练好的模型如何进行要存

你好,请问训练好的模型应该是bert.ckpt文件吧,这个文件请问要怎么用来进行预测呢,这个是pytorch保存的模型文件吗

CUDA out of memory

您好,请问下您,跑ERNIE模型时,报CUDA out of memory的错误。我的显卡是1050 Ti

您好!问下下面代码对不同的层设置不同的weight_decay作用是什么呢?

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

no_decay三种层和其他的层相比有什么区别嘛?谢谢!

> > 我也是bert跑出来准确率50左右

我也是bert跑出来准确率50左右
为什么我按照源代码 原数据集训练出来的准确率 和你的差很多呢 而且时间也很长
我检查了下是词表出现了乱码,你可以看看bert里vocab.txt这个文件打开是不是正常的汉字,erine是正常的
我看了一下 我bert里vocab.txt这个文件中的文字跟作者放的链接里面的是一样的 但是我觉得这个中文有点儿奇怪

https://blog.csdn.net/m0_38133212/article/details/88614153

谢谢您 十分感谢

Originally posted by @ZMM6128 in #46 (comment)

No such file or directory: 'THUCNews/saved_dict/bert.ckpt'

按照流程clone项目,下载bert预训练集与vocab装入对应目录,运行run.py,抛出如下错误,请问如何修改?

python3 run.py --model bert   
Loading data...
180000it [00:28, 6288.85it/s]
10000it [00:01, 6258.25it/s]
10000it [00:01, 6435.75it/s]
Time usage: 0:00:32
Epoch [1/3]
Traceback (most recent call last):
  File "run.py", line 37, in <module>
    train(config, model, train_iter, dev_iter, test_iter)
  File "/Users/temco/Downloads/Bert-Chinese-Text-Classification-Pytorch-master/train_eval.py", line 65, in train
    torch.save(model.state_dict(), config.save_path)
  File "/Users/temco/Library/Python/3.7/lib/python/site-packages/torch/serialization.py", line 260, in save
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/Users/temco/Library/Python/3.7/lib/python/site-packages/torch/serialization.py", line 183, in _with_file_like
    f = open(f, mode)
FileNotFoundError: [Errno 2] No such file or directory: 'THUCNews/saved_dict/bert.ckpt'

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.