本文转载自:https://blog.csdn.net/xiaosongshine/article/details/90600028

一、Self-Attention概念详解

对于self-attention来讲,Q(Query), K(Key), V(Value)三个矩阵均来自同一输入,首先我们要计算Q与K之间的点乘,然后为了防止其结果过大,会除以一个尺度标度其中  为一个query和key向量的维度。再利用Softmax操作将其结果归一化为概率分布,然后再乘以矩阵V就得到权重求和的表示。该操作可以表示为

如果将输入的所有向量合并为矩阵形式,则所有query, key, value向量也可以合并为矩阵形式表示

其中  是我们模型训练过程学习到的合适的参数。上述操作即可简化为矩阵形式

二、Self_Attention模型搭建

笔者使用Keras来实现对于Self_Attention模型的搭建,由于网络中间参数量比较多,这里采用自定义网络层的方法构建Self_Attention,关于如何自定义Keras可以参看这里:编写你自己的 Keras 层

Keras实现自定义网络层。需要实现以下三个方法:(注意input_shape是包含batch_size项的)

  • build(input_shape): 这是你定义权重的地方。这个方法必须设 self.built = True,可以通过调用 super([Layer], self).build() 完成。
  • call(x): 这里是编写层的功能逻辑的地方。你只需要关注传入 call 的第一个参数:输入张量,除非你希望你的层支持masking。
  • compute_output_shape(input_shape): 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状
from keras.preprocessing import sequence
from keras.datasets import imdb
from matplotlib import pyplot as plt
import pandas as pd from keras import backend as K
from keras.engine.topology import Layer class Self_Attention(Layer): def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(Self_Attention, self).__init__(**kwargs) def build(self, input_shape):
# 为该层创建一个可训练的权重
#inputs.shape = (batch_size, time_steps, seq_len)
self.kernel = self.add_weight(name='kernel',
shape=(3,input_shape[2], self.output_dim),
initializer='uniform',
trainable=True) super(Self_Attention, self).build(input_shape) # 一定要在最后调用它 def call(self, x):
WQ = K.dot(x, self.kernel[0])
WK = K.dot(x, self.kernel[1])
WV = K.dot(x, self.kernel[2]) print("WQ.shape",WQ.shape) print("K.permute_dimensions(WK, [0, 2, 1]).shape",K.permute_dimensions(WK, [0, 2, 1]).shape) QK = K.batch_dot(WQ,K.permute_dimensions(WK, [0, 2, 1])) QK = QK / (64**0.5) QK = K.softmax(QK) print("QK.shape",QK.shape) V = K.batch_dot(QK,WV) return V def compute_output_shape(self, input_shape): return (input_shape[0],input_shape[1],self.output_dim)

  

Keras实现Self-Attention的更多相关文章

  1. Keras实现Hierarchical Attention Network时的一些坑

    Reshape 对于的张量x,x.shape=(a, b, c, d)的情况 若调用keras.layer.Reshape(target_shape=(-1, c, d)), 处理后的张量形状为(?, ...

  2. LSTM/RNN中的Attention机制

    一.解决的问题 采用传统编码器-解码器结构的LSTM/RNN模型存在一个问题,不论输入长短都将其编码成一个固定长度的向量表示,这使模型对于长输入序列的学习效果很差(解码效果很差). 注意下图中,ax ...

  3. 文本分类:Keras+RNN vs传统机器学习

    摘要:本文通过Keras实现了一个RNN文本分类学习的案例,并详细介绍了循环神经网络原理知识及与机器学习对比. 本文分享自华为云社区<基于Keras+RNN的文本分类vs基于传统机器学习的文本分 ...

  4. Sequence Models

    Sequence Models This is the fifth and final course of the deep learning specialization at Coursera w ...

  5. keras系列︱seq2seq系列相关实现与案例(feedback、peek、attention类型)

    之前在看<Semi-supervised Sequence Learning>这篇文章的时候对seq2seq半监督的方式做文本分类的方式产生了一定兴趣,于是开始简单研究了seq2seq.先 ...

  6. [深度应用]·Keras极简实现Attention结构

    [深度应用]·Keras极简实现Attention结构 在上篇博客中笔者讲解来Attention结构的基本概念,在这篇博客使用Keras搭建一个基于Attention结构网络加深理解.. 1.生成数据 ...

  7. Attention and Augmented Recurrent Neural Networks

    Attention and Augmented Recurrent Neural Networks CHRIS OLAHGoogle Brain SHAN CARTERGoogle Brain Sep ...

  8. [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88)

    [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88) 个人主页--> https://xiaosongshine.github.io/ 项目g ...

  9. [深度应用]·DC竞赛轴承故障检测开源Baseline(基于Keras 1D卷积 val_acc:0.99780)

    [深度应用]·DC竞赛轴承故障检测开源Baseline(基于Keras1D卷积 val_acc:0.99780) 个人网站--> http://www.yansongsong.cn/ Githu ...

  10. Attention Model(注意力模型)思想初探

    1. Attention model简介 0x1:AM是什么 深度学习里的Attention model其实模拟的是人脑的注意力模型,举个例子来说,当我们观赏一幅画时,虽然我们可以看到整幅画的全貌,但 ...

随机推荐

  1. ThinkPHP5最新URL访问:PATH_INFO和兼容模式

    https://www.jianshu.com/p/c43fb5817ae1 http://tp5.com/index.php?s=USER/manger_user/add&n=2000&am ...

  2. JVM 线上故障排查基本操作--内容问题排查

    内存问题排查 说完了 CPU 的问题排查,再说说内存的排查,通常,内存的问题就是 GC 的问题,因为 Java 的内存由 GC 管理.有2种情况,一种是内存溢出了,一种是内存没有溢出,但 GC 不健康 ...

  3. 关于LPC MUD的关键字及其它重要术语

    关于LPMUD的关键字及其它重要术语 前面的内容中对LPC语言和 lpmud 做了介绍,也完成了学习开发的准备工作,为了更好的学习,这里先对基本术语做一个说明. 关键字(Keywords):LPC语言 ...

  4. LODOP中无规律无法还原偶尔出现问题排查

    一些问题无法还原且偶尔出现,没法通过做例子来展示问题,为了找到问题在哪里,就需要排查定位问题 .由于这些问题偶尔出现,且无规律,出现频率低,所以只能不断通过各种对比测试,定位排查到问题和什么有关.如果 ...

  5. Excel如何输入负数

    一般红字发票很少开,以前都是单独把红字发票摘出来放到一行里,然后加减一下,前段时间有个客户因为普票无法报销,改要了专票,因为是电子发票,无法作废,开了张红字.虽然红字很少开,但是想着百度一下如何在ex ...

  6. java 多线程 面试

    1.多线程有什么用? (1)发挥多核CPU的优势: 当前,应用服务器至少也都是双核的,4核.8核甚至16核的也都不少见,如果是单线程的程序,那么在双核CPU上就浪费了50%,在4核CPU上就浪费了75 ...

  7. 【转】Spring中@Async

    Spring中@Async 在Java应用中,绝大多数情况下都是通过同步的方式来实现交互处理的:但是在处理与第三方系统交互的时候,容易造成响应迟缓的情况,之前大部分都是使用多线程来完成此类任务,其实, ...

  8. [CMD] Jenkins上执行robot命令如果出现fail不往下走其他的CMD命令了

    需要在后面加上||exit 0 robot -o %disSection%.xml --include %disSection% -v ENV:%envBmk% .||exit 0

  9. C++类的组合、前向引用声明

    3.5类的组合 Part1.应用背景 对于复杂的问题,往往可以逐步划分为一系列稍微简单的子问题. 解决复杂问题的有效方法是将其层层分解为简单的问题组合,首先解决简单问题复杂问题也就迎刃而解了. 在面向 ...

  10. memcached源码分析二-lru

    在前一篇文章中介绍了memcached中的内存管理策略slab,那么需要缓存的数据是如何使用slab的呢? 1.    缓存对象item内存分布 在memcached,每一个缓存的对象都使用一个ite ...