本文记录使用BERT预训练模型,修改最顶层softmax层,微调几个epoch,进行文本分类任务。

BERT源码

首先BERT源码来自谷歌官方tensorflow版:https://github.com/google-research/bert

注意,这是tensorflow 1.x 版本的。

BERT预训练模型

预训练模型采用哈工大讯飞联合实验室推出的WWM(Whole Word Masking)全词覆盖预训练模型,主要考量是BERT对于中文模型来说,是按照字符进行切割,但是注意到BERT随机mask掉15%的词,这里是完全随机的,对于中文来说,很有可能一个词的某些字被mask掉了,比如说让我预测这样一句话:

原话: ”我今天早上去打羽毛球了,然后又去蒸了桑拿,感觉身心愉悦“

MASK:”我[MASK]天早上去打[MASK]毛球了,然后[MASK]去蒸了[MASK]拿,感觉身心[MASK]悦“

虽然说从统计学意义上来讲这样做依然可以学得其特征,但这样实际上破坏了中文特有的词结构,那么全词覆盖主要就是针对这个问题,提出一种机制保证在MASK的时候要么整个词都不MASK,要么MASK掉整个词。

WWM MASK:”我今天早上去打[MASK][MASK][MASK]了,然后又去蒸了[MASK][MASK],感觉身心愉悦“

例子可能举得不是很恰当,但大概是这个意思,可以参考这篇文章:

https://www.jiqizhixin.com/articles/2019-06-21-01

修改源码

首先看到下下来的项目结构:

可以看到run_classifier.py文件,这个是我们需要用的。另外,chinese开头的文件是我们的模型地址,data文件是我们的数据地址,这个每个人可以自己设置。

在run_classifier.py文件中,有一个基类DataProcessor类,这个是我们需要继承并重写的:

class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError() def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError() def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError() def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError() @classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with tf.gfile.Open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines

  

可以看到我们需要实现获得训练、验证、测试数据接口,以及获得标签的接口。

这里我自己用的一个类。注释比较详细,就不解释了,主要体现了只要能获得数据,不论我们的文件格式是什么样的,都可以,所以不需要专门为了这个项目去改自己的输入数据格式。

class StatutesProcessor(DataProcessor):

    def _read_txt_(self, data_dir, x_file_name, y_file_name):
# 定义我们的读取方式,我的工程中已经将x文本和y文本分别存入txt文件中,没有分隔符
# 用gfile读取,打开一个没有线程锁的的文件IO Wrapper
# 基本上和python原生的open是一样的,只是在某些方面更高效一点
with tf.gfile.Open(data_dir + x_file_name, 'r') as f:
lines_x = [x.strip() for x in f.readlines()]
with tf.gfile.Open(data_dir + y_file_name, 'r') as f:
lines_y = [x.strip() for x in f.readlines()]
return lines_x, lines_y def get_train_examples(self, data_dir):
lines_x, lines_y = self._read_txt_(data_dir, 'train_x.txt', 'train_y.txt')
examples = []
for (i, line) in enumerate(zip(lines_x, lines_y)):
guid = 'train-%d' % i
# 规范输入编码
text_a = tokenization.convert_to_unicode(line[0])
label = tokenization.convert_to_unicode(line[1])
# 这里有一些特殊的任务,一般任务直接用上面的就行,下面的label操作可以注释掉
# 这里因为y会有多个标签,这里按单标签来做
label = label.strip().split()[0] # 这里不做匹配任务,text_b为None
examples.append(
InputExample(guid=guid, text_a=text_a, label=label)
)
return examples def get_dev_examples(self, data_dir):
lines_x, lines_y = self._read_txt_(data_dir, 'val_x.txt', 'val_y.txt')
examples = []
for (i, line) in enumerate(zip(lines_x, lines_y)):
guid = 'train-%d' % i
# 规范输入编码
text_a = tokenization.convert_to_unicode(line[0])
label = tokenization.convert_to_unicode(line[1])
label = label.strip().split()[0] # 这里不做匹配任务,text_b为None
examples.append(
InputExample(guid=guid, text_a=text_a, label=label)
)
return examples def get_test_examples(self, data_dir):
lines_x, lines_y = self._read_txt_(data_dir, 'test_x.txt', 'test_y.txt')
examples = []
for (i, line) in enumerate(zip(lines_x, lines_y)):
guid = 'train-%d' % i
# 规范输入编码
text_a = tokenization.convert_to_unicode(line[0])
label = tokenization.convert_to_unicode(line[1])
label = label.strip().split()[0] # 这里不做匹配任务,text_b为None
examples.append(
InputExample(guid=guid, text_a=text_a, label=label)
)
return examples def get_labels(self):
# 我事先统计了所有出现的y值,放在了vocab_y.txt里
# 因为这里没有原生的接口,这里暂时这么做了,只要保证能读到所有的类别就行了
with tf.gfile.Open('data/statutes_small/vocab_y.txt', 'r') as f:
vocab_y = [x.strip() for x in f.readlines()] return vocab_y

  

写好了之后需要更新一下processors列表,在main函数中,最下面一条就是我新加的。

执行训练微调

python run_classifier.py --data_dir=data/statutes_small/ --task_name=cail2018 --vocab_file=chinese_wwm_ext_L-12_H-768_A-12/vocab.txt --bert_config_file=chinese_wwm_ext_L-12_H-768_A-12/bert_config.json --output_dir=output/ --do_train=true --do_eval=true --init_checkpoint=chinese_wwm_ext_L-12_H-768_A-12/bert_model.ckpt --max_seq_length=200 --train_batch_size=16 --learning_rate=5e-5 --num_train_epoch=3

相信我,写在一行,这个会有很多小问题,在centos服务器上如果不能按上返回上一条命令,将会很痛苦。。具体参数含义就和参数名是一致的,不需要解释。

另外,可以稍稍修改一些东西来动态输入训练集上的loss,因为BERT源码封装的太高了,所以只能按照这篇文章:https://www.cnblogs.com/jiangxinyang/p/10241243.html里面讲的方法,每100个step输出一次train loss(就是100个batch),这样做虽然意义不大,但是可以看在你的数据集上模型是不是在收敛,方便调整学习率。

在测试集上进行测试

默认test_batch_size = 8

python run_classifier.py --data_dir=data/statutes_small/ --task_name=cail2018 --vocab_file=chinese_wwm_ext_L-12_H-768_A-12/vocab.txt --bert_config_file=chinese_wwm_ext_L-12_H-768_A-12/bert_config.json --output_dir=output/ --do_predict=true --max_seq_length=200

需要注意的是,调用测试接口会在out路径中生成一个test_results.tsv,这是一个以’\t’为分隔符的文件,记录了每一条输入测试样例,输出的每一个维度的值(维度数就是类别数目),需要手动做一点操作来得到最终分类结果,以及计算指标等等。

# 计算测试结果
# 因为原生的predict生成一个test_results.tsv文件,给出了每一个sample的每一个维度的值
# 却并没有给出具体的类别预测以及指标,这里再对这个“中间结果手动转化一下” def cal_accuracy(rst_file_dir, y_test_dir):
rst_contents = pd.read_csv(rst_file_dir, sep='\t', header=None)
# value_list: ndarray
value_list = rst_contents.values
pred = value_list.argmax(axis=1)
labels = [] # 这一步是获取y标签到id,id到标签的对应dict,每个人获取的方式应该不一致
y2id, id2y = get_y_to_id(vocab_y_dir='../data/statutes_small/vocab_y.txt')
with open(y_test_dir, 'r', encoding='utf-8') as f:
line = f.readline()
while line:
# 这里因为y有多个标签,我要取第一个标签,所以要单独做操作
label = line.strip().split()[0]
labels.append(y2id[label])
line = f.readline()
labels = np.asarray(labels) # 预测,pred,真实标签,labels
accuracy = metrics.accuracy_score(y_true=labels, y_pred=pred)
# 这里只举例了accuracy,其他的指标也类似计算
print(accuracy) def get_y_to_id(vocab_y_dir):
# 这里把所有的y标签值存在了文件中
y_vocab = open(vocab_y_dir, 'r', encoding='utf-8').read().splitlines()
y2idx = {token: idx for idx, token in enumerate(y_vocab)}
idx2y = {idx: token for idx, token in enumerate(y_vocab)}
return y2idx, idx2y

  

这部分代码在classifier/cal_test_matrix.py中。

我的代码地址:

点击这里

参考:

https://github.com/google-research/bert

https://www.cnblogs.com/jiangxinyang/p/10241243.html

https://www.jiqizhixin.com/articles/2019-06-21-01

https://arxiv.org/abs/1906.08101

使用BERT预训练模型+微调进行文本分类的更多相关文章

  1. Pytorch——BERT 预训练模型及文本分类

    BERT 预训练模型及文本分类 介绍 如果你关注自然语言处理技术的发展,那你一定听说过 BERT,它的诞生对自然语言处理领域具有着里程碑式的意义.本次试验将介绍 BERT 的模型结构,以及将其应用于文 ...

  2. 文本分类实战(十)—— BERT 预训练模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

  3. BERT预训练模型的演进过程!(附代码)

    1. 什么是BERT BERT的全称是Bidirectional Encoder Representation from Transformers,是Google2018年提出的预训练模型,即双向Tr ...

  4. bert 预训练模型路径

    google的bert预训练模型: BERT-Large, Uncased (Whole Word Masking): 24-layer, 1024-hidden, 16-heads, 340M pa ...

  5. BERT的通俗理解 预训练模型 微调

    1.预训练模型      BERT是一个预训练的模型,那么什么是预训练呢?举例子进行简单的介绍      假设已有A训练集,先用A对网络进行预训练,在A任务上学会网络参数,然后保存以备后用,当来一个新 ...

  6. NLP与深度学习(五)BERT预训练模型

    1. BERT简介 Transformer架构的出现,是NLP界的一个重要的里程碑.它激发了很多基于此架构的模型,其中一个非常重要的模型就是BERT. BERT的全称是Bidirectional En ...

  7. BERT和ULMFIT embedding比较文本分类结果

    Instructions [THIS REPOSITORY IS UNDER DEVELOPMENT AND MOER DATASETS AND MODELS WILL BE ADDED] [FEEL ...

  8. Paddle预训练模型应用工具PaddleHub

    Paddle预训练模型应用工具PaddleHub 本文主要介绍如何使用飞桨预训练模型管理工具PaddleHub,快速体验模型以及实现迁移学习.建议使用GPU环境运行相关程序,可以在启动环境时,如下图所 ...

  9. 文本分类实战(九)—— ELMO 预训练模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

随机推荐

  1. 从一个案例窥探ORACLE的PASSWORD_VERSIONS

    1.环境说明 ORACLE 客户端版本 11.2.0.1 ORACLE 服务端版本 12.2.0.1 2.异常现象 客户端(下文也称为Cp)访问服务端(Sp),报了一个错误: Figure 1 以错误 ...

  2. java并发-ReentrantLock的lock和lockInterruptibly的区别

    ReentrantLock的加锁方法Lock()提供了无条件地轮询获取锁的方式,lockInterruptibly()提供了可中断的锁获取方式.这两个方法的区别在哪里呢?通过分析源码可以知道lock方 ...

  3. Qt信号-槽原理剖析--(1)信号槽简介

    唯有创造才是快乐.只有创造的生灵才是生灵.--罗曼·罗兰 信号槽是观察者模式的一种实现,特性如下: A.一个信号就是一个能够被观察的事件,或者至少是事件已经发生的一种通知: B.一个槽就是一个观察者, ...

  4. idea常用的快捷键(黑马培训班)

    idea常用的快捷键 Alt+回车 导入包,自动修正 Ctrl+N 查找类 Ctrl+Shift+N 查找文件 Ctrl+Alt+L 格式化代码 Ctrl+Alt+O 优化导入的类和包 Alt+Ins ...

  5. 12 Spring JdbcTemplate的使用

    1.项目搭建 <1>数据库表account对应的账户实体类 package domain; import java.io.Serializable; /** * 账户实体类 */ publ ...

  6. Apache Kafka Connect - 2019完整指南

    今天,我们将讨论Apache Kafka Connect.此Kafka Connect文章包含有关Kafka Connector类型的信息,Kafka Connect的功能和限制.此外,我们将了解Ka ...

  7. maven dependency中provided和compile的区别

    重点:这个项目打成war包时,scope=provided的jar包,不会出现在WEB-INFO/lib目录下,而scope=compile的jar包,会放到WEB-INFO/lib目录 scope= ...

  8. 链表习题(2)-一个集合用带头结点的单链表L表示,编写算法删除其值最大的结点。

    /*一个集合用带头结点的单链表L表示,编写算法删除其值最大的结点.*/ /* 算法思想:使用pre,p,premax,max四个指针,pre和p进行比较,premax和max进行最后的删除操作 通过遍 ...

  9. centos发布 7.7.1908版本了,怎么把老版本更新到新版本了?

    CENTOS升级 7.6 升级到7.7.1908 0.查看目前版本 cat /etc/issue cat /etc/redhat-release 1.下载系统镜像文件 https://www.cent ...

  10. 模块 logging random

    模块logging logging模块的主要功能是记录软件调试.操作过程中的各种日志. 默认情况下Python的logging模块将日志打印到了标准输出中,且只显示了大于等于WARNING级别的日志, ...