过年放了七天假,每年第一件事就是立一个flag——希望今年除了能够将技术学扎实之外,还希望能够将所学能够用来造福社会,好像flag立得有点大了。没关系,套用一句电影台词为自己开脱一下——人没有梦想,和咸鱼有什么区别。闲话至此,进入今天主题:Transformer。谷歌于2017年提出Transformer网络架构,此网络一经推出就引爆学术界。目前,在NLP领域,Transformer模型被认为是比CNN,RNN都要更强的特征提取器。

Transformer算法简介

Transformer引入了self-attention机制,同时还借鉴了CNN领域中残差机制(Residuals),由于以上原因导致transformer有如下优势:

  • 模型表达能力较强,由于self-attention机制考虑到了句子之中词与词之间的关联,
  • 抛弃了RNN的循环结构,同时借用了CNN中的残差结构加快了模型的训练速度。

接下来我们来看看transformer的一些细节:

  • 首先Scaled Dot-Product Attention步骤是transformer的精髓所在,作者引入Q,W,V参数通过点乘相识度去计算句子中词与词之间的关联重要程度。其大致过程如图所示,笔者将会在实战部分具体介绍此过程如何实现。

     
    Scaled Dot-Product Attention
  • 第二个是muti-head步骤,直白的解释就是将上面的Scaled Dot-Product Attention步骤重复执行,然后将每次执行的结果拼接起来,需要注意的是每次重复执行Scaled Dot-Product Attention步骤的参数并不共享。

     
    Multi-Head
  • 第三个步骤就是残差网络结构——将muti-head步骤的输出和原始输入之间相加。这里不明白的可以参考笔者之前介绍残差网络的文章

接下来就是实战部分,实战部分只使用了muti-head attention或者说是self-attention的向量表示作为最终特征进行文本分类。

Transformer文本分类实战

数据载入

下方代码的作用是将情感分析数据读入,格式为一句话和一个label:
sen_1 : 1
sen_2 : 0
1代表正面情绪,0代表负面情绪。

#! -*- coding: utf-8 -*-
from keras import backend as K
from keras.engine.topology import Layer
import numpy as np
from keras.preprocessing import sequence
from keras.layers import *
from keras import Model
from keras.callbacks import TensorBoard
data = np.load("imdb.npz")
x_test = data["x_test"]
x_train = data["x_train"]
y_test = data["y_test"]
y_train = data["y_train"]

数据预处理

由于文本数据长短不一,下面代码可将数据padding到相同的长度。

from itertools import chain
all_word = list(chain.from_iterable(list(x_train)))
all_word = set(all_word)
max_features = len(all_word)
data_train = sequence.pad_sequences(x_train,200)

Self-attention

这里详细介绍一下模型最关键的部分Scaled Dot-Product Attention的构建过程,如图一 Scaled Dot-Product Attention:

 
图一 self-attention
  • 1.首先申明三个待优化的参数,
  • 2.将输入X分别和进行点乘,得到,此过程可以理解成将同一句话中的词映射到三个不同的向量空间,这里笔者将三个不同的向量空间命名为Q空间,K空间和V空间,如图二 Query,Key,Value metrix
     
    图二 Query,Key,Value metrix
  • 3.然后计算Q空间的某一个词在K空间所以词向量分别点乘得分,之后将这些得分通过softmax函计算一个重要度系数。然后用计算出来的重要度系数乘上该词在V空间的词向量并加和得到该词最终的词向量表示,整个过程如图三 Softmax所示,这样就可以得到一句话经过self-attention后的向量表示
     
    图三 Softmax

上述整个过程就是Scaled Dot-Product Attention,本质上考虑到了一个句子中不同词之间的关联程度,这个过程或多或少增强了句子语义的表达。下方为keras定义的self-attention层的代码,这里加入了muti-head和mask功能的实现。

class Attention(Layer):

    def __init__(self, nb_head, size_per_head, **kwargs):
self.nb_head = nb_head
self.size_per_head = size_per_head
self.output_dim = nb_head * size_per_head
super(Attention, self).__init__(**kwargs) def build(self, input_shape):
self.WQ = self.add_weight(name='WQ',
shape=(input_shape[0][-1], self.output_dim),
initializer='glorot_uniform',
trainable=True)
self.WK = self.add_weight(name='WK',
shape=(input_shape[1][-1], self.output_dim),
initializer='glorot_uniform',
trainable=True)
self.WV = self.add_weight(name='WV',
shape=(input_shape[2][-1], self.output_dim),
initializer='glorot_uniform',
trainable=True)
super(Attention, self).build(input_shape) def Mask(self, inputs, seq_len, mode='mul'):
if seq_len == None:
return inputs
else:
mask = K.one_hot(seq_len[:, 0], K.shape(inputs)[1])
mask = 1 - K.cumsum(mask, 1)
for _ in range(len(inputs.shape) - 2):
mask = K.expand_dims(mask, 2)
if mode == 'mul':
return inputs * mask
if mode == 'add':
return inputs - (1 - mask) * 1e12 def call(self, x):
# 如果只传入Q_seq,K_seq,V_seq,那么就不做Mask
# 如果同时传入Q_seq,K_seq,V_seq,Q_len,V_len,那么对多余部分做Mask
if len(x) == 3:
Q_seq, K_seq, V_seq = x
Q_len, V_len = None, None
elif len(x) == 5:
Q_seq, K_seq, V_seq, Q_len, V_len = x
# 对Q、K、V做线性变换
Q_seq = K.dot(Q_seq, self.WQ)
Q_seq = K.reshape(Q_seq, (-1, K.shape(Q_seq)[1], self.nb_head, self.size_per_head))
Q_seq = K.permute_dimensions(Q_seq, (0, 2, 1, 3))
K_seq = K.dot(K_seq, self.WK)
K_seq = K.reshape(K_seq, (-1, K.shape(K_seq)[1], self.nb_head, self.size_per_head))
K_seq = K.permute_dimensions(K_seq, (0, 2, 1, 3))
V_seq = K.dot(V_seq, self.WV)
V_seq = K.reshape(V_seq, (-1, K.shape(V_seq)[1], self.nb_head, self.size_per_head))
V_seq = K.permute_dimensions(V_seq, (0, 2, 1, 3))
# 计算内积,然后mask,然后softmax
A = K.batch_dot(Q_seq, K_seq, axes=[3, 3]) / self.size_per_head ** 0.5
A = K.permute_dimensions(A, (0, 3, 2, 1))
A = self.Mask(A, V_len, 'add')
A = K.permute_dimensions(A, (0, 3, 2, 1))
A = K.softmax(A)
# 输出并mask
O_seq = K.batch_dot(A, V_seq, axes=[3, 2])
O_seq = K.permute_dimensions(O_seq, (0, 2, 1, 3))
O_seq = K.reshape(O_seq, (-1, K.shape(O_seq)[1], self.output_dim))
O_seq = self.Mask(O_seq, Q_len, 'mul')
return O_seq def compute_output_shape(self, input_shape):
return (input_shape[0][0], input_shape[0][1], self.output_dim)

位置编码

接下来定义一个位置编码层,由于是输入是句子属于一个序列,加入位置编码会使得语义表达更准确。

class Position_Embedding(Layer):

    def __init__(self, size=None, mode='sum', **kwargs):
self.size = size # 必须为偶数
self.mode = mode
super(Position_Embedding, self).__init__(**kwargs) def call(self, x):
if (self.size == None) or (self.mode == 'sum'):
self.size = int(x.shape[-1])
batch_size, seq_len = K.shape(x)[0], K.shape(x)[1]
position_j = 1. / K.pow(10000., 2 * K.arange(self.size / 2, dtype='float32') / self.size)
position_j = K.expand_dims(position_j, 0)
position_i = K.cumsum(K.ones_like(x[:, :, 0]), 1) - 1 # K.arange不支持变长,只好用这种方法生成
position_i = K.expand_dims(position_i, 2)
position_ij = K.dot(position_i, position_j)
position_ij = K.concatenate([K.cos(position_ij), K.sin(position_ij)], 2)
if self.mode == 'sum':
return position_ij + x
elif self.mode == 'concat':
return K.concatenate([position_ij, x], 2) def compute_output_shape(self, input_shape):
if self.mode == 'sum':
return input_shape
elif self.mode == 'concat':
return (input_shape[0], input_shape[1], input_shape[2] + self.size)

而谷歌的论文直接给出了position embedding 层的公式,如下图所示。

 
position embeding

此公式的含义是将

 

 

的位置映射为一个

 

维的位置向量,此向量的第

 

个元素的值就是通过上述公式算出来的

 

。position embeding背后的物理意义参考于参考文献第一篇:由于在数学上有以及,这表明位置的向量可以表示成位置的向量的线性变换,这提供了表达相对位置信息的可能性。

模型构建

接下来使用上方定义好的的self-attention层和position embedding层进行模型构建,这里设置的8个head,意味着将self-attention流程重复做8次,这里的代码实现不是讲8个head向量拼接,而是通过keras自带的 GlobalAveragePooling1D函数将8个head的向量求和平均一下。

K.clear_session()
callbacks = [TensorBoard("log/")]
S_inputs = Input(shape=(None,), dtype='int32')
embeddings = Embedding(max_features, 128)(S_inputs)
embeddings = Position_Embedding()(embeddings) # Position_Embedding
O_seq = Attention(8, 16)([embeddings, embeddings, embeddings])# Self Attention
O_seq = GlobalAveragePooling1D()(O_seq)
O_seq = Dropout(0.5)(O_seq)
outputs = Dense(1, activation='sigmoid')(O_seq)
model = Model(inputs=S_inputs, outputs=outputs)
model.summary()

模型的网络结构可视化输出如下:

 
model

模型训练

将之前预处理好的数据喂给模型,同时设置好batch size 和 epoch就可以跑起来了。由于笔者是使用的是笔记本的cpu,所以只跑一个epoch。

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(data_train, y_train,
batch_size=2,
epochs=1,
callbacks=callbacks,
validation_split=0.2)
 
train

结语

Transformer在各方面性能上都超过了RNN和CNN,但是其最主要的思想还是引入了self-attention,使得模型可以考虑到句子中词与词之间的相互联系,这个思想在NLP很多领域,如机器阅读(R-Net)中也曾出现。所以如何在embeding时的更好挖掘句子的语义,才是深度学习在nlp领域最需要解决的难题。

参考文献

https://spaces.ac.cn/archives/4765
https://blog.csdn.net/qq_41664845/article/details/84969266
Attention Is All You Need

作者:王鹏你妹
链接:https://www.jianshu.com/p/704893b996f9
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

Attention is all your need 谷歌的超强特征提取网络——Transformer的更多相关文章

  1. android.os.NetworkOnMainThreadException 在4.0之后谷歌强制要求连接网络不能在主线程进行访问

    谷歌在4.0系统以后就禁止在主线程中进行网络访问了,原因是: 主线程是负责UI的响应,如果在主线程进行网络访问,超过5秒的话就会引发强制关闭, 所以这种耗时的操作不能放在主线程里.放在子线程里,而子线 ...

  2. 《Attention is All You Need》

    https://www.jianshu.com/p/25fc600de9fb 谷歌最近的一篇BERT取得了卓越的效果,为了研究BERT的论文,我先找出了<Attention is All You ...

  3. 2. Attention Is All You Need(Transformer)算法原理解析

    1. 语言模型 2. Attention Is All You Need(Transformer)算法原理解析 3. ELMo算法原理解析 4. OpenAI GPT算法原理解析 5. BERT算法原 ...

  4. 从Seq2seq到Attention模型到Self Attention

    Seq2seq Seq2seq全名是Sequence-to-sequence,也就是从序列到序列的过程,是近年当红的模型之一.Seq2seq被广泛应用在机器翻译.聊天机器人甚至是图像生成文字等情境. ...

  5. Attention模型

    李宏毅深度学习 https://www.bilibili.com/video/av9770302/?p=8 Generation 生成模型基本结构是这样的, 这个生成模型有个问题是我不能干预数据生成, ...

  6. 【转载】细粒度图像识别Object-Part Attention Driven Discriminative Localization for Fine-grained Image Classification

    细粒度图像识别Object-Part Attention Driven Discriminative Localization for Fine-grained Image Classificatio ...

  7. Deep attention tracking via Reciprocative Learning

    文章:Deep attention tracking via Reciprocative Learning 出自NIPS2018 文章链接:https://arxiv.org/pdf/1810.038 ...

  8. [深度概念]·Attention Model(注意力模型)学习笔记

    此文源自一个博客,笔者用黑体做了注释与解读,方便自己和大家深入理解Attention model,写的不对地方欢迎批评指正.. 1.Attention Model 概述 深度学习里的Attention ...

  9. Sequence Model-week3编程题1-Neural Machine Translation with Attention

    1. Neural Machine Translation 下面将构建一个神经机器翻译(NMT)模型,将人类可读日期 ("25th of June, 2009") 转换为机器可读日 ...

随机推荐

  1. LINUX下C++编程如何获得某进程的ID

    #include <stdio.h> #include <stdlib.h> #include <unistd.h> using namespace std; pi ...

  2. 解决CSDN博客插入代码出现的问题

    我在写CSDN博客的时候有时候会在插入代码之后继续编辑,然后保存之后经常会出现一些多余的符号<p 例如<pre></pre>,这样的标记,其实这是html的一个元素,pr ...

  3. win下在虚拟机安装CentOS 7 Linux系统

    准备: CentOS 7下载地址(我下的是everthing版本):https://www.centos.org/download/ 一.首先下载安装虚拟机VMware 地址官网下载即可. 二.安装操 ...

  4. mysql的三种连接方式

    SQL的三种连接方式分为:左外连接.右外连接.内连接,专业术语分别为:LEFT JOIN.RIGHT JOING.INNER JOIN 内连接INNER JOIN:使用比较运算符来根据指定的连接的每个 ...

  5. C++ 之手写strcat

    char *strcat(char* strDest, const char*strSrc){ assert(strDest != NULL&&strSrc != NULL); cha ...

  6. git pull 提示错误,Your local changes to the following files would be overwritten by merge

    error: Your local changes to the following files would be overwritten by merge: Please commit your c ...

  7. PHPCMS网站迁移过程后,添加内容 报500错误解决方案

    问题出现原因:1.网站迁移过程中,上传下载文件时文件丢失  2.PHPCMS源码更新升级 解决方法 1.可以到官方下载最新版源码,替换过去.如果对源码有改动,需要先保存改动过的文件,替换过去之后,再替 ...

  8. jq处理JSON数据, jq Manual (development version)

    jq 允许你直接在命令行下对 JSON 进行操作,包括分片.过滤.转换等等.让我们通过几个例子来说明 jq 的功能:一.输出格式化,漂亮的打印效果如果我们用文本编辑器打开 JSON,有时候可能看起来会 ...

  9. kubernetes1.4新特性:支持两种新的卷插件

    背景介绍 在Kubernetes中卷的作用在于提供给POD持久化存储,这些持久化存储可以挂载到POD中的容器上,进而给容器提供持久化存储. 从图中可以看到结构体PodSpec有个属性是Volumes, ...

  10. Oracle 定义变量总结

    首先,当在cmd里办入scott密码提示错误时,可以这样改一下,scott的解锁命令是: 以system用户登录:cmdsqlplus system/tigertigeralter user scot ...