Giter Club home page Giter Club logo

cpm-lm-tf2's Introduction

TensorFlow 2.x CPM-Generate

本Repo将模型转换为TensorFlow版本,原Repo https://github.com/TsinghuaAI/CPM-Generate

原项目首页:https://cpm.baai.ac.cn/

原项目介绍文章:https://mp.weixin.qq.com/s/oI2Ak-M57MSuycLVpVEiHw

如果你只想大概看一下结果,请直接打开`prediction_v2.ipynb`文件预览

感谢智源研究院的工作!

Q:为什么你要把源代码转换为TensorFlow版本?

A:1. 我装不上Nvidia apex;2. 按照原Repo说明需要两张显卡的主机运行,我也没有; 3. TensorFlow模型很方便。

^_^ 如果你喜欢我的工作,请给智源研究院的原Repo打星,如果能顺便给我的Repo也打一个就更好了。

使用方法

  1. Clone本Repo

  2. 下载模型:

打开下面分享下载CPM目录的cpm-lm-tf2

也可以下载cpm-lm-tf2_v2,是第二个版本,如果下载这个,后面的也要改成主要参考prediction_v2.ipynb,v2版本的区别是增加了top_p和temperature两个参数
链接: https://pan.baidu.com/s/1tjbWty2hkbmtCrvV9Qh_SQ  密码: n0nt
--来自百度网盘超级会员V7的分享
另一个目录`cpm-lm-tf2-fp16`是fp16版本的模型,但是除非你有显卡,并且确认支持float16,并且确认正确安装CUDA,否则`请使用非fp16`的版本,因为在不支持float16的设备上,会非常慢!

下载cpm-lm-tf2到Clone好的Repo目录,结构大概是这样:

- cpm-tf2
  - cpm-lm-tf2  (从网盘下载好的TF2版本模型)
    - assets
    - saved_model.pb
    - variables
  - CPM-Generate
    - bpe_3w_new (词表所在目录)
  - prediction.ipynb  (预测demo,主程序)
  - gpt2_tokenizer.py  (分词文件,这个里面引入了jieba,和huggingface那一系列的不能简单互换)
运行所需的代码其实就大概以上的几个文件和目录就够了,其他的主要是模型转换等代码
  1. 安装依赖
# 依赖:
pip install sentencepiece
pip install jieba
pip install regex
pip install tensorflow
pip install tensorflow-hub
  1. 参考prediction.ipynb中的代码运行

    如果你下载的是_v2版本的模型,请参考prediction_v2.ipynb

其中所需大概代码就这么几行:

import tensorflow_hub as hub
import tensorflow as tf

from gpt2_tokenizer import GPT2Tokenizer

tokenizer = GPT2Tokenizer(
    'CPM-Generate/bpe_3w_new/vocab.json',
    'CPM-Generate/bpe_3w_new/merges.txt',
    model_file='CPM-Generate/bpe_3w_new/chinese_vocab.model')

gpt = hub.load('./cpm-lm-tf2/')

def sample(tokenizer, gpt, sentence, number=1, length=20):
    inputs = tf.constant([tokenizer.encode(sentence)] * number, dtype=tf.int64)
    length = tf.constant(length, dtype=tf.int64)
    ret = gpt.signatures['serving_default'](inp=inputs, length=length)['output_0']
    return [
        tokenizer.decode(s).replace(' ', '')
        for s in ret.numpy()
    ]

ret = sample(tokenizer, gpt, '书写英文:\n狗dog\n猫cat\n鸟', 3, 10)
for x in ret:
    print(x)
    print('-' * 20)

一些额外的闲聊

模型的转换参考

模型的具体转换代码在load_pytorch.ipynb文件中,希望有类似torch to tensorflow的同学可以参考

原模型的一些细节

模型的训练基础是英伟达的https://github.com/NVIDIA/Megatron-LM,这大概算是一个英伟达魔改的PyTorch上的高级API,论文在https://arxiv.org/pdf/1909.08053.pdf

这个应该主要是为了能把一个很大的模型在很多张显卡上更好的并行训练而设计的,原模型分了两个文件,也提到了需要两张显卡,应该是在每张显卡上分别载入这两个文件。

这两个文件中大概各有一半的模型参数,有些层,例如全连接层(Dense)的参数会平均到两个模型中。

比如一个256到256的Dense层,按道理来说有一个256x256的kernel和一个256的bias,平分之后会在每个文件分别有一个128x256的kernel和一个128的bias。

因为有些层无法平分,例如LayerNormalization层,所以是在两个文件中有重复的。

fp32和fp16

fp16在笔者的CPU上几乎和龟速一样(Macbook Pro 2020),比fp32的慢了好多倍。

猜测这应该是由于现代cpu上实际上不具备物理的fp16运算器导致的,也就是每次进行fp16的前后其实是把fp16转换为了32再运行的,所以非常浪费CPU。

fp16的模型相比fp32的其实是有一些区别的,主要是原来的attention mask的问题,因为1e10这个数字在fp32上是合法的,但是在fp16上是inf,所以笔者把这部分mask的1e10的超参改为了1e4,才跑起来fp16的模型。

TensorFlow版本和原版本的区别

道理来讲应该没有什么太大区别,而且也载入了原来的参数,不过毕竟还是有GPU -> CPU,PyTorch -> TensorFlow这样的转换,所以可能和原模型结果有一定出入,不过笔者估计这个出入不会很大,顶多1%左右。

引用

参考原Repo

@article{cpm-v1,
  title={CPM: A Large-scale Generative Chinese Pre-trained Language Model},
  author={Zhang, Zhengyan and Han, Xu, and Zhou, Hao, and Ke, Pei, and Gu, Yuxian and Ye, Deming and Qin, Yujia and Su, Yusheng and Ji, Haozhe and Guan, Jian and Qi, Fanchao and Wang, Xiaozhi and Zheng, Yanan and Cao, Jiannan and Zeng, Guoyang and Cao, Huanqi and Chen, Shengqi and Li, Daixuan and Sun, Zhenbo and Liu, Zhiyuan and Huang, Minlie and Han, Wentao and Tang, Jie and Li, Juanzi and Sun, Maosong},
  year={2020}
}

cpm-lm-tf2's People

Contributors

qhduan avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.