在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子。但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题。因此,我们考虑用新出来的预训练模型来加快模型预测速度。

  本文将介绍如何利用ALBERT来实现文本二分类。

关于ALBERT

  ALBERT的提出时间大约是在2019年10月,其第一作者为谷歌科学家蓝振忠博士。ALBERT的论文地址为:https://openreview.net/pdf?id=H1eA7AEtvS , Github项目地址为: https://github.com/brightmart/albert_zh

  简单说来,ALBERT是BERT的一个精简版,它在BERT模型的基础上进行改造,减少了大量参数,使得其在模型训练和模型预测的速度上有很大提升,而模型的效果只会有微小幅度的下降,具体的效果和速度方面的说明可以参考Github项目。

  ALBERT相对于BERT的改进如下:

  • 对Embedding因式分解(Factorized embedding parameterization);
  • 跨层的参数共享(Cross-layer parameter sharing);
  • 句间连贯(Inter-sentence coherence loss);
  • 移除dropout 。

  笔者在北京的时候也写过ALBERT在提升序列标注算法的预测速度方面的一篇文章:NLP(十八)利用ALBERT提升模型预测速度的一次尝试 ,该项目的Github地址为:https://github.com/percent4/ALBERT_4_Time_Recognition

项目说明

  本项目的数据和代码主要参考笔者的文章NLP(二十)利用BERT实现文本二分类,该项目是想判别输入的句子是否属于政治上的出访类事件。笔者一共收集了340条数据,其中280条用作训练集,60条用作测试集。

  项目结构如下图:

  在这里我们使用ALBERT已经训练好的文件albert_tiny,借鉴BERT的调用方法,我们在这里给出albert_zh模块,能够让ALBERT提取文本的特征,具体代码不在这里给出,有兴趣的读者可以访问该项目的Github地址:。

  注意,albert_tiny给出的向量维度为312,我们的模型训练代码(model_train.py)如下:

  1. # -*- coding: utf-8 -*-
  2. # author: Jclian91
  3. # place: Pudong Shanghai
  4. # time: 2020-03-04 13:37
  5. import os
  6. import numpy as np
  7. from load_data import train_df, test_df
  8. from keras.utils import to_categorical
  9. from keras.models import Model
  10. from keras.optimizers import Adam
  11. from keras.layers import Input, BatchNormalization, Dense
  12. import matplotlib.pyplot as plt
  13. from albert_zh.extract_feature import BertVector
  14. # 读取文件并进行转换
  15. bert_model = BertVector(pooling_strategy="REDUCE_MEAN", max_seq_len=100)
  16. print('begin encoding')
  17. f = lambda text: bert_model.encode([text])["encodes"][0]
  18. train_df['x'] = train_df['text'].apply(f)
  19. test_df['x'] = test_df['text'].apply(f)
  20. print('end encoding')
  21. x_train = np.array([vec for vec in train_df['x']])
  22. x_test = np.array([vec for vec in test_df['x']])
  23. y_train = np.array([vec for vec in train_df['label']])
  24. y_test = np.array([vec for vec in test_df['label']])
  25. print('x_train: ', x_train.shape)
  26. # Convert class vectors to binary class matrices.
  27. num_classes = 2
  28. y_train = to_categorical(y_train, num_classes)
  29. y_test = to_categorical(y_test, num_classes)
  30. # 创建模型
  31. x_in = Input(shape=(312, ))
  32. x_out = Dense(32, activation="relu")(x_in)
  33. x_out = BatchNormalization()(x_out)
  34. x_out = Dense(num_classes, activation="softmax")(x_out)
  35. model = Model(inputs=x_in, outputs=x_out)
  36. print(model.summary())
  37. model.compile(loss='categorical_crossentropy',
  38. optimizer=Adam(),
  39. metrics=['accuracy'])
  40. # 模型训练以及评估
  41. history = model.fit(x_train, y_train, validation_data=(x_test, y_test), batch_size=8, epochs=20)
  42. model.save('visit_classify.h5')
  43. print(model.evaluate(x_test, y_test))
  44. # 绘制loss和acc图像
  45. plt.subplot(2, 1, 1)
  46. epochs = len(history.history['loss'])
  47. plt.plot(range(epochs), history.history['loss'], label='loss')
  48. plt.plot(range(epochs), history.history['val_loss'], label='val_loss')
  49. plt.legend()
  50. plt.subplot(2, 1, 2)
  51. epochs = len(history.history['acc'])
  52. plt.plot(range(epochs), history.history['acc'], label='acc')
  53. plt.plot(range(epochs), history.history['val_acc'], label='val_acc')
  54. plt.legend()
  55. plt.savefig("loss_acc.png")

  模型训练的效果很不错,在训练集的acc为0.9857,在测试集上的acc为0.9500,具体如下:

与BERT的预测对比

  接下来我们在模型预测上的时间,与BERT的文本二分类模型预测时间做一个对比,这样有助于提升我们对ALBERT的印象。

  BERT的文本二分类模型预测可以参考文章NLP(二十)利用BERT实现文本二分类,本文给出的代码与BERT实现的模型预测代码基本一致,只不过BERT提取特征改成ALBERT提取特征。

  本文的模型预测代码(model_predict.py)如下:

  1. # -*- coding: utf-8 -*-
  2. # author: Jclian91
  3. # place: Pudong Shanghai
  4. # time: 2020-03-04 17:33
  5. import time
  6. import pandas as pd
  7. import numpy as np
  8. from albert_zh.extract_feature import BertVector
  9. from keras.models import load_model
  10. load_model = load_model("visit_classify.h5")
  11. # 预测语句
  12. texts = ['在访问限制中,用户可以选择禁用iPhone的功能,包括Siri、iTunes购买功能、安装/删除应用等,甚至还可以让iPhone变成一台功能手机。以下是访问限制具体可以实现的一些功能',
  13. 'IT之家4月23日消息 近日,谷歌在其官方论坛发布消息表示,他们为Android Auto添加了一项新功能:可以访问完整联系人列表。用户现在可以通过在Auto的电话拨号界面中打开左上角的菜单访问完整的联系人列表。值得注意的是,这一功能仅支持在车辆停止时使用。',
  14. '要通过telnet 访问路由器,需要先通过console 口对路由器进行基本配置,例如:IP地址、密码等。',
  15. 'IT之家3月26日消息 近日反盗版的国际咨询公司MUSO发布了2017年的年度报告,其中的数据显示,去年盗版资源网站访问量达到了3000亿次,比前一年(2016年)提高了1.6%。美国是访问盗版站点次数最多的国家,共有279亿次访问;其后分别是俄罗斯、印度和巴西,中国位列第18。',
  16. '应葡萄牙议会邀请,全国人大常委会副委员长吉炳轩率团于12月14日至16日访问葡萄牙,会见副议长费利佩、社会党副总书记卡内罗。',
  17. '2月26日至3月2日,应香港特区政府“内地贵宾访港计划”邀请,省委常委、常务副省长陈向群赴港考察访问,重点围绕“香港所长、湖南所需”,与特区政府相关部门和机构深入交流,推动湖南与香港交流合作取得新进展。',
  18. '目前A站已经恢复了访问,可以直接登录,网页加载正常,视频已经可以正常播放。',
  19. '难民署特使安吉丽娜·朱莉6月8日结束了对哥伦比亚和委内瑞拉边境地区的难民营地为期两天的访问,她对哥伦比亚人民展现的人道主义和勇气表示赞扬。',
  20. '据《南德意志报》报道,德国总理默克尔计划明年1月就前往安卡拉,和土耳其总统埃尔多安进行会谈。',
  21. '自9月14日至18日,由越共中央政治局委员、中央书记处书记、中央经济部部长阮文平率领工作代表团对希腊进行工作访问。',
  22. 'Win7电脑提示无线适配器或访问点有问题怎么办?很多用户在使用无线网连接上网时,发现无线网显示已连接,但旁边却出现了一个黄色感叹号,无法进行网络操作,通过诊断提示电脑无线适配器或访问点有问题,且处于未修复状态,这该怎么办呢?下面小编就和大家分享下Win7电脑提示无线适配器或访问点有问题的解决方法。',
  23. '2019年10月13日至14日,外交部副部长马朝旭访问智利,会见智利外长里韦拉,同智利总统外事顾问萨拉斯举行会谈,就智利举办亚太经合组织(APEC)第二十七次领导人非正式会议等深入交换意见。',
  24. '未开发所有安全组之前访问,FTP可以链接上,但是打开会很慢,需要1-2分钟才能链接上',
  25. 'win7系统电脑的用户,在连接WIFI网络网上时,有时候会遇到突然上不了网,查看连接的WIFI出现“有限的访问权限”的文字提示。',
  26. '联合国秘书长潘基文8日访问了日本福岛县,与当地灾民交流并访问了一所高中。',
  27. '国务院总理温家宝当地时间23日下午乘专机抵达布宜诺斯艾利斯,开始对阿根廷进行正式访问。',
  28. '正在中国访问的巴巴多斯总理斯图尔特15日在陕西西安参观访问。',
  29. '据外媒报道,当地时间10日,美国白宫发声明称,美国总统特朗普将于2月底访问印度,与印度总理莫迪进行战略对话。',
  30. '2月28日,唐山曹妃甸蓝色海洋科技有限公司董事长赵力军等一行5人到黄海水产研究所交流访问。黄海水产研究所副所长辛福言及相关部门负责人、专家等参加了会议。',
  31. '2018年7月2日,莫斯科孔子文化促进会会长姜彦彬,常务副会长陈国建,在中国著名留俄油画大师牟克教授的陪同下,访问了莫斯科国立苏里科夫美术学院,受到第一副校长伊戈尔·戈尔巴秋克先生接待。'
  32. '据外媒报道,当地时间26日晚,阿尔及利亚总统特本抵达沙特阿拉伯,进行为期三天的访问。两国领导人预计将就国家间合作和地区发展进行磋商。',
  33. '与标准Mozy一样,Stash文件夹为用户提供了对其备份文件的基于云的访问,但是它们还使他们可以随时,跨多个设备(包括所有计算机,智能手机和平板电脑)访问它们。换句话说,使用浏览器的任何人都可以同时查看文件(如果需要)。操作系统和设备品牌无关。',
  34. '研究表明,每个网页的平均预期寿命为44至100天。当用户通过浏览器访问已消失的网页时,就会看到「Page Not Found」的错误信息。对于这种情况,相信大多数人也只能不了了之。不过有责任心的组织——互联网档案馆为了提供更可靠的Web服务,它联手Brave浏览器专门针对此类网页提供了一键加载存档页面的功能。',
  35. '据外媒报道,土耳其总统府于当地时间2日表示,土耳其总统埃尔多安计划于5日对俄罗斯进行为期一天的访问。',
  36. '3日,根据三星电子的消息,李在镕副会长这天访问了位于韩国庆尚北道龟尾市的三星电子工厂。'] * 10
  37. labels = []
  38. bert_model = BertVector(pooling_strategy="REDUCE_MEAN", max_seq_len=100)
  39. init_time = time.time()
  40. # 对上述句子进行预测
  41. for text in texts:
  42. # 将句子转换成向量
  43. vec = bert_model.encode([text])["encodes"][0]
  44. x_train = np.array([vec])
  45. # 模型预测
  46. predicted = load_model.predict(x_train)
  47. y = np.argmax(predicted[0])
  48. label = 'Y' if y else 'N'
  49. labels.append(label)
  50. cost_time = time.time() - init_time
  51. print("Average cost time: %s." % (cost_time/len(texts)))
  52. for text, label in zip(texts, labels):
  53. print('%s\t%s' % (label, text))
  54. df = pd.DataFrame({'句子':texts, "是否属于出访类事件": labels})
  55. df.to_excel('./result.xlsx', index=False)

输出的平均预测时长为:16.98ms,而BERT版的平均预测时间为:257.31ms

  我们将模型预测写成HTTP服务,代码(server.py)如下:

  1. # -*- coding: utf-8 -*-
  2. # author: Jclian91
  3. # place: Pudong Shanghai
  4. # time: 2020-03-04 20:13
  5. import tornado.httpserver
  6. import tornado.ioloop
  7. import tornado.options
  8. import tornado.web
  9. from tornado.options import define, options
  10. import json
  11. import numpy as np
  12. from albert_zh.extract_feature import BertVector
  13. from keras.models import load_model
  14. # 定义端口为10008
  15. define("port", default=10008, help="run on the given port", type=int)
  16. # 加载ALBERT
  17. bert_model = BertVector(pooling_strategy="REDUCE_MEAN", max_seq_len=100)
  18. # 加载已经训练好的模型
  19. load_model = load_model("visit_classify.h5")
  20. # 对句子进行预测
  21. class PredictHandler(tornado.web.RequestHandler):
  22. def post(self):
  23. text = self.get_argument("text")
  24. # 将句子转换成向量
  25. vec = bert_model.encode([text])["encodes"][0]
  26. x_train = np.array([vec])
  27. # 模型预测
  28. predicted = load_model.predict(x_train)
  29. y = np.argmax(predicted[0])
  30. label = '是' if y else "否"
  31. # 返回结果
  32. result = {"原文": text, "是否属于出访类事件?": label}
  33. self.write(json.dumps(result, ensure_ascii=False, indent=2))
  34. # 主函数
  35. def main():
  36. # 开启tornado服务
  37. tornado.options.parse_command_line()
  38. # 定义app
  39. app = tornado.web.Application(
  40. handlers=[(r'/predict', PredictHandler)] #网页路径控制
  41. )
  42. http_server = tornado.httpserver.HTTPServer(app)
  43. http_server.listen(options.port)
  44. tornado.ioloop.IOLoop.instance().start()
  45. main()

用Postman进行测试,如下图:

  实践证明,用ALBERT做文本特征提取,模型训练的效果基本与BERT差别微小,模型训练速度明显提升,更重要的是,模型预测的速度只有BERT版本的6.6%(不同情况下可能有略微差异),这在生产上是十分有帮助的。

参考网址

  1. 中文预训练ALBERT模型来了:小模型登顶GLUE,Base版模型小10倍速度快1倍: https://zhuanlan.zhihu.com/p/85037097
  2. ALBERT一作蓝振忠:预训练模型应用已成熟,ChineseGLUE要对标GLUE基准:https://tech.sina.com.cn/roll/2019-11-17/doc-iihnzhfy9804802.shtml
  3. 解读ALBERT:https://blog.csdn.net/weixin_37947156/article/details/101529943
  4. ALBERT的Github项目地址:https://github.com/brightmart/albert_zh

NLP(二十二)利用ALBERT实现文本二分类的更多相关文章

  1. NLP(二十)利用BERT实现文本二分类

      在我们进行事件抽取的时候,我们需要触发词来确定是否属于某个特定的事件类型,比如我们以政治上的出访类事件为例,这类事件往往会出现"访问"这个词语,但是仅仅通过"访问&q ...

  2. NLP(二十八)多标签文本分类

      本文将会讲述如何实现多标签文本分类. 什么是多标签分类?   在分类问题中,我们已经接触过二分类和多分类问题了.所谓二(多)分类问题,指的是y值一共有两(多)个类别,每个样本的y值只能属于其中的一 ...

  3. NLP(二十) 利用词向量实现高维词在二维空间的可视化

    准备 Alice in Wonderland数据集可用于单词抽取,结合稠密网络可实现其单词的可视化,这与编码器-解码器架构类似. 代码 from __future__ import print_fun ...

  4. 小小知识点(二十)利用MATLAB计算定积分

    一重定积分 1. Z = trapz(X,Y,dim) 梯形数值积分,通过已知参数x,y按dim维使用梯形公式进行积分 %举例说明1 clc clear all % int(sin(x),0,pi) ...

  5. 从零开始学安全(二十六)●利用Nmap目标的本版进行探测

    通过对对方电脑的服务探测 对本版较低的服务 或者无补丁的服务 可以直入侵 版本探测 version  后边就是版本

  6. NLP(十六)轻松上手文本分类

    背景介绍   文本分类是NLP中的常见的重要任务之一,它的主要功能就是将输入的文本以及文本的类别训练出一个模型,使之具有一定的泛化能力,能够对新文本进行较好地预测.它的应用很广泛,在很多领域发挥着重要 ...

  7. Dynamic CRM 2013学习笔记(二十)字段改变事件的二种实现方法

    CRM里有二种方式实现字段change事件,一种是在form里,一种完全通过js来实现.本文介绍下二者的用途及区别. 1. Form里用法 这种方式估计其实也是添加一个js的function. 这种方 ...

  8. javaweb学习总结二十六(response对象的用法二 下载文件)

    一:浏览器打开服务器上的文件 1:读取服务器上面的资源,如果在web层,可以直接使用servletContext,如果在非web层 可以使用类加载器读取文件 2:向浏览器写数据,实际上是把数据封装到r ...

  9. java 面向对象(二十九):异常(二)异常的处理

    1.java异常处理的抓抛模型过程一:"抛":程序在正常执行的过程中,一旦出现异常,就会在异常代码处生成一个对应异常类的对象. * 并将此对象抛出. * 一旦抛出对象以后,其后的代 ...

随机推荐

  1. Gitbook在 Mac 环境上的安装及使用

    一.在 Mac 环境上搭建 gitbook #.安装node.js,在node.js官网下载,直接安装稳定版本. https://nodejs.org/en/ #.检测 node.js 是否安装成功 ...

  2. 二十四、SSH介绍

    1.ssh介绍: SSH先对联机数据包通过加密技术进行加密处理,加密后在进行数据传输,确保了传递的数据安全.(运维的一大重视点就是要对安全敏感) 在当前的生产环境运维工作中,绝大多数企业都是SSH协议 ...

  3. L3-016 二叉搜索树的结构 (30 分)

    二叉搜索树或者是一棵空树,或者是具有下列性质的二叉树: 若它的左子树不空,则左子树上所有结点的值均小于它的根结点的值:若它的右子树不空,则右子树上所有结点的值均大于它的根结点的值:它的左.右子树也分别 ...

  4. ckeditor+ckfinder添加水印。

    1.修改ckfinder文件下面的config.php:添加一句include_once "plugins/watermark/plugin.php";//水印配置文件 2.修改p ...

  5. wareshark判断一个http请求链接是否断开

    使用curl -v www.baidu.com发送一个请求 使用wareshark的过滤器表达式显示这个完整请求 TCP HTTP协议 , 其中192.168.1.4是本地ip 可以看到84 85两个 ...

  6. [LC] 78. Subsets

    Given a set of distinct integers, nums, return all possible subsets (the power set). Note: The solut ...

  7. cs231n spring 2017 lecture12 Visualizing and Understanding

    这一节课很零碎. 1. 神经网络到底在干嘛? 浅层的是具体的特征(比如边.角.色块等),高层的更抽象,最后的全连接层是把图片编码成一维向量然后和每一类标签作比较.如果直接把图片和标签做像素级的最近领域 ...

  8. mysql 优化2 慢查询

    默认情况下mysql不记录慢查询日志,需要在启动的时候指定 bin\mysqld.exe - -slow-query-log 通过慢查询日志定位执行效率较低的SQL语句.慢查询日志记录了所有执行时间超 ...

  9. Golang Slice 总结

    数组 Go的切片是在数组之上的抽象数据类型,因此在了解切片之前必须要要理解数组.数组类型由指定和长度和元素类型定义.数组不需要显式的初始化:数组元素会自动初始化为零值:Go的数组是值语义.一个数组变量 ...

  10. 如何升级gcc

    https://blog.csdn.net/zhaomax/article/details/87807711 1.环境:arm架构的centos6.5系统服务器 2.查看当前的gcc版本:gcc  - ...