colab上基于tensorflow2.0的BERT中文多分类
bert模型在tensorflow1.x版本时,也是先发布的命令行版本,随后又发布了bert-tensorflow包,本质上就是把相关bert实现封装起来了。
tensorflow2.0刚刚在2019年10月左右发布,谷歌也在积极地将之前基于tf1.0的bert实现迁移到2.0上,但近期看还没有完全迁移完成,所以目前还没有基于tf2.0的bert安装包面世,因为近期想基于现有发布的模型做一个中文多分类的事情,所以干脆就弄了个基于命令行版本的。过程中有一些坑,随之记录下来。
1. colab:因为想用谷歌免费的GPU(暂时没研究TPU怎么用),所以直接在colab上弄。
2. 中文多分类:
2.1. 训练数据来源:在百度百科上找了大概100多个词条的数据,自己随便标注成大概8个类别吧。把百科的概述、正文、属性等信息进行清洗后连到一起,类似于这种格式:
label info
1 词条1的描述balabala
0 词条2的描述balabala
然后再按照9:1分成两个集合,分别明明为train.tsv 和 dev.tsv,作为训练集和测试集。注意每个训练集的第一行都是标题行,这个是我的数据解析器里面这么定义的。
2.2. 预训练模型:谷歌已经提供了基于tf2.0 keras网络结构的中文预训练模型,页面地址是:https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/1。直接使用就可以了。注意,基于tf1.0的中文预训练模型(bert_chinese_L-12_H-768_A-12)不能在tf2.0里使用,要用脚本转换,但我们既然已经有了最新的模型,就直接用啦。
2.3. bert代码:在github上直接下载:https://github.com/tensorflow/models.git。注意bert已经被谷歌从tensorflow中分离出来,放在models目录下当成第三方独立代码了,所以需要自己下载配置。
2.4. 数据预处理脚本:在2.3步骤下载下来的bert代码里,位置是:models/official/nlp/bert/create_finetuning_data.py,其中已经定义好了支持若干种数据格式的类及实现,因为我需要处理上面自己定义的那种格式的数据,所以自己写了一个处理百度百科的类放到里面了,如果大家有自己的数据格式,修改后覆盖原来的文件就ok了,具体需要改的是:classifier_data_lib.py和create_finetuning_data.py 这两个文件
2. tf-nightly和bert代码下载:目前这个时间段基于tf2.0的bert只能在tf-nightly下面使用(看社区里的留言,应该在tf2.1正式发布的时候就会提供bert正式版了),所以要安装tf-nightly并且在这下面运行后面的代码。
3. 数据预处理脚本的执行:这个就按照命令行的模式在colab里调用脚本create_finetuning_data.py就可以了,没什么难的,有个坑是目前tf2.0的中文预训练模型没提供基于gs的存储位置,而预处理脚本中需要vocab.txt来分词,所以要先离线把模型下载下来,解压缩后,把里面的vocab.txt拿出来并上传到colab上,然后在预训练脚本里制定文件位置就ok(我把vocab.txt放到我的github上了,可以直接调用获取,但如果想获取最新的vocab.txt,最好自己下载然后加压获取。后续谷歌应该会提供在线模型地址,就不用这么麻烦了)
4. finetune:直接调用脚本models/official/nlp/bert/run_classifier.py,这里有个坑是脚本参数里需要bert_config.json,但上面的中文预处理模型没提供这个模型配置文件,所以干脆从其他tf1.0的模型里copy了一个过来(我用的是uncased_L-12_H-768_A-12的bert_config.json)
代码我都放到github上了,大家自己取用即可,欢迎拍砖、吐槽、交流!
https://github.com/liloi/bert-tf2/blob/master/bert-tf2-zh-demo.ipynb
colab上基于tensorflow2.0的BERT中文多分类的更多相关文章
- 基于tensorflow2.0 使用tf.keras实现Fashion MNIST
本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...
- 基于tensorflow2.0和cifar100的VGG13网络训练
VGG是2014年ILSVRC图像分类竞赛的第二名,相比当年的冠军GoogleNet在可扩展性方面更胜一筹,此外,它也是从图像中提取特征的CNN首选算法,VGG的各种网络模型结构如下: 今天代码的原型 ...
- 【tensorflow2.0】处理图片数据-cifar2分类
1.准备数据 cifar2数据集为cifar10数据集的子集,只包括前两种类别airplane和automobile. 训练集有airplane和automobile图片各5000张,测试集有airp ...
- 推荐模型DeepCrossing: 原理介绍与TensorFlow2.0实现
DeepCrossing是在AutoRec之后,微软完整的将深度学习应用在推荐系统的模型.其应用场景是搜索推荐广告中,解决了特征工程,稀疏向量稠密化,多层神经网路的优化拟合等问题.所使用的特征在论文中 ...
- 编译可在Nexus5上运行的CyanogenMod13.0 ROM(基于Android6.0)
编译可在Nexus5上运行的CyanogenMod13.0 ROM (基于Android6.0) 作者:寻禹@阿里聚安全 前言 下文中无特殊说明时CM代表CyanogenMod的缩写. 下文中说的“设 ...
- Servlet3.0学习总结——基于Servlet3.0的文件上传
Servlet3.0学习总结(三)——基于Servlet3.0的文件上传 在Servlet2.5中,我们要实现文件上传功能时,一般都需要借助第三方开源组件,例如Apache的commons-fileu ...
- 一文上手Tensorflow2.0之tf.keras(三)
系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...
- 基于AFNetworking3.0网络封装
概述 对于开发人员来说,学习网络层知识是必备的,任何一款App的开发,都需要到网络请求接口.很多朋友都还在使用原生的NSURLConnection一行一行地写,代码到处是,这样维护起来更困难了. 对于 ...
- iOS_SN_基于AFNetworking3.0网络封装
转发文章,原地址:http://www.henishuo.com/base-on-afnetworking3-0-wrapper/?utm_source=tuicool&utm_medium= ...
随机推荐
- tf.reduce_sum()
#axis 表示在哪个维度进行sum操作,不写代表所有维 #keep_dims 是否保留原始数据维度 reduce_sum( input_tensor, axis=None, keep_dims=Fa ...
- linux 自动检测 IRQ 号
驱动在初始化时最有挑战性的问题中的一个是如何决定设备要使用哪个 IRQ 线. 驱动需 要信息来正确安装处理. 尽管程序员可用请求用户在加载时指定中断号, 这是个坏做法, 因为大部分时间用户不知道这个号 ...
- 【50.40%】【BZOJ 4553】[Tjoi2016&Heoi2016]序列
Time Limit: 20 Sec Memory Limit: 128 MB Submit: 371 Solved: 187 [Submit][Status][Discuss] Descript ...
- 2018-2-13-WPF-异常-NativeWPFDLLLoader.LoadNativeWPFDLL
title author date CreateTime categories WPF 异常 NativeWPFDLLLoader.LoadNativeWPFDLL lindexi 2018-2-13 ...
- apache WEB服务器安装(包括虚拟主机)
一.apache下载编译安装 yum install apr apr-devel apr-util apr-util-devel gcc-c++ wget tar -y cd /usr/src wge ...
- 29(30).socket网络基础
转载:https://www.cnblogs.com/linhaifeng/articles/6129246.html 一 客户端/服务器架构 1.硬件C/S架构(打印机) 2.软件C/S架构 互联网 ...
- 22.XML
转载:https://www.cnblogs.com/yuanchenqi/article/5732581.html xml是实现不同语言或程序之间进行数据交换的协议,跟json差不多,但json使用 ...
- ajax异步发送时遇到的问题
问题原因是:controller中方法名与url中的名字不一样造成的 解决办法:找到错误的方法名,将其与url中的方法名统一:
- 选择合适的最短路--hdu3499
[题目链接](http://acm.hdu.edu.cn/showproblem.php?pid=3499) 刚看见题目,哦?不就是个最短路么,来,跑一下dijkstra记录最长路除个二就完事了 ,但 ...
- DRF 06
目录 视图家族 views视图类 mixin视图工具类 generics工具视图类 viewsets视图集 路由配置 视图家族 views视图类 APIView """ ...