Keras实现Self-Attention
本文转载自:https://blog.csdn.net/xiaosongshine/article/details/90600028
一、Self-Attention概念详解
对于self-attention来讲,Q(Query), K(Key), V(Value)三个矩阵均来自同一输入,首先我们要计算Q与K之间的点乘,然后为了防止其结果过大,会除以一个尺度标度其中 为一个query和key向量的维度。再利用Softmax操作将其结果归一化为概率分布,然后再乘以矩阵V就得到权重求和的表示。该操作可以表示为
如果将输入的所有向量合并为矩阵形式,则所有query, key, value向量也可以合并为矩阵形式表示
其中 是我们模型训练过程学习到的合适的参数。上述操作即可简化为矩阵形式
二、Self_Attention模型搭建
笔者使用Keras来实现对于Self_Attention模型的搭建,由于网络中间参数量比较多,这里采用自定义网络层的方法构建Self_Attention,关于如何自定义Keras可以参看这里:编写你自己的 Keras 层
Keras实现自定义网络层。需要实现以下三个方法:(注意input_shape是包含batch_size项的)
- build(input_shape): 这是你定义权重的地方。这个方法必须设 self.built = True,可以通过调用 super([Layer], self).build() 完成。
- call(x): 这里是编写层的功能逻辑的地方。你只需要关注传入 call 的第一个参数:输入张量,除非你希望你的层支持masking。
- compute_output_shape(input_shape): 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状
from keras.preprocessing import sequence
from keras.datasets import imdb
from matplotlib import pyplot as plt
import pandas as pd from keras import backend as K
from keras.engine.topology import Layer class Self_Attention(Layer): def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(Self_Attention, self).__init__(**kwargs) def build(self, input_shape):
# 为该层创建一个可训练的权重
#inputs.shape = (batch_size, time_steps, seq_len)
self.kernel = self.add_weight(name='kernel',
shape=(3,input_shape[2], self.output_dim),
initializer='uniform',
trainable=True) super(Self_Attention, self).build(input_shape) # 一定要在最后调用它 def call(self, x):
WQ = K.dot(x, self.kernel[0])
WK = K.dot(x, self.kernel[1])
WV = K.dot(x, self.kernel[2]) print("WQ.shape",WQ.shape) print("K.permute_dimensions(WK, [0, 2, 1]).shape",K.permute_dimensions(WK, [0, 2, 1]).shape) QK = K.batch_dot(WQ,K.permute_dimensions(WK, [0, 2, 1])) QK = QK / (64**0.5) QK = K.softmax(QK) print("QK.shape",QK.shape) V = K.batch_dot(QK,WV) return V def compute_output_shape(self, input_shape): return (input_shape[0],input_shape[1],self.output_dim)
Keras实现Self-Attention的更多相关文章
- Keras实现Hierarchical Attention Network时的一些坑
Reshape 对于的张量x,x.shape=(a, b, c, d)的情况 若调用keras.layer.Reshape(target_shape=(-1, c, d)), 处理后的张量形状为(?, ...
- LSTM/RNN中的Attention机制
一.解决的问题 采用传统编码器-解码器结构的LSTM/RNN模型存在一个问题,不论输入长短都将其编码成一个固定长度的向量表示,这使模型对于长输入序列的学习效果很差(解码效果很差). 注意下图中,ax ...
- 文本分类:Keras+RNN vs传统机器学习
摘要:本文通过Keras实现了一个RNN文本分类学习的案例,并详细介绍了循环神经网络原理知识及与机器学习对比. 本文分享自华为云社区<基于Keras+RNN的文本分类vs基于传统机器学习的文本分 ...
- Sequence Models
Sequence Models This is the fifth and final course of the deep learning specialization at Coursera w ...
- keras系列︱seq2seq系列相关实现与案例(feedback、peek、attention类型)
之前在看<Semi-supervised Sequence Learning>这篇文章的时候对seq2seq半监督的方式做文本分类的方式产生了一定兴趣,于是开始简单研究了seq2seq.先 ...
- [深度应用]·Keras极简实现Attention结构
[深度应用]·Keras极简实现Attention结构 在上篇博客中笔者讲解来Attention结构的基本概念,在这篇博客使用Keras搭建一个基于Attention结构网络加深理解.. 1.生成数据 ...
- Attention and Augmented Recurrent Neural Networks
Attention and Augmented Recurrent Neural Networks CHRIS OLAHGoogle Brain SHAN CARTERGoogle Brain Sep ...
- [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88)
[深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88) 个人主页--> https://xiaosongshine.github.io/ 项目g ...
- [深度应用]·DC竞赛轴承故障检测开源Baseline(基于Keras 1D卷积 val_acc:0.99780)
[深度应用]·DC竞赛轴承故障检测开源Baseline(基于Keras1D卷积 val_acc:0.99780) 个人网站--> http://www.yansongsong.cn/ Githu ...
- Attention Model(注意力模型)思想初探
1. Attention model简介 0x1:AM是什么 深度学习里的Attention model其实模拟的是人脑的注意力模型,举个例子来说,当我们观赏一幅画时,虽然我们可以看到整幅画的全貌,但 ...
随机推荐
- WPF外包团队:2019 WPF数据监控系统案例演示
项目版权均为客户所有,如有WPF项目外包欢迎联系我们. QQ:372900288 TEL:13911652504 WX:Liuxiang0884
- 【ARM-Linux开发】Ubuntu下的/usr目录权限,导致不能使用sudo命令的修复
刚开始运行sudo时,报了下面这个错误 sudo: must be setuid root,于是上网找解决方法,搜索出来的都是这样解决的 ls -l /usr/bin/sudochown root: ...
- Xshell设置运行自动化脚本
使用Xshell工具连接操作Linux系统,并编写运行自动化脚本示例: 这里介绍一种自动化下载日志文件的例子,下面先贴上编写的脚本,这里脚本命名为cyp-assout-log.js 如下: /* xs ...
- Python 推导式详解
各种推导式详解 推导式的套路 之前我们已经学习了最简单的列表推导式和生成器表达式.但是除此之外,其实还有字典推导式.集合推导式等等. 下面是一个以列表推导式为例的推导式详细格式,同样适用于其他推导式. ...
- [转帖]JAVA BIO与NIO、AIO的区别(这个容易理解)
JAVA BIO与NIO.AIO的区别(这个容易理解) https://blog.csdn.net/ty497122758/article/details/78979302 2018-01-05 11 ...
- Java 总结篇2
第02章:数据类型和运算符 一.概述: 1.数据类型:int.float.char.boolean 2.运算符:算术运算符.赋值运算符.关系运算符.逻辑运算符.位运算符(了解即可).条件运算符 3.基 ...
- SAS学习笔记60 统计SAS实例之T检验
单样本 H0:服从正态分布 P=0.0988>0.05不拒绝H0,服从正态分布 H0:等于140t=-2.14,P=0.0397 P<0.05,拒绝H0,差异有统计学意义 均值x=130. ...
- HTML学习--基础知识
WEB a) 什么是WEB WEB,是基于Internet上的一种应用程序(网页应用程序),WEB页面,简称WEB页(网页),就是保存在服务器端上的一个具体的页面 b) WEB ...
- uni-app项目导入第三方组件库muse-ui
你说uni-app是什么 我说,uni-app是一套基于vue.js开发跨平台应用的前端框架,可编译多个平台,比如:Android.IOS.H5.微信小程序.支付宝小程序.头条小程序.百度小程序 懂行 ...
- C#中Unity对象的注册方式与生命周期解析
1.示例代码 请详细阅读 static void Main(string[] args) { { Console.WriteLine("----------全局设置----------&qu ...