这份代码来自于苏剑林

# -*- coding:utf-8 -*-

from keras.layers import Layer
import keras.backend as K class CRF(Layer):
"""纯Keras实现CRF层
CRF层本质上是一个带训练参数的loss计算层,因此CRF层只用来训练模型,
而预测则需要另外建立模型,但是还是要用到训练好的转移矩阵
"""
def __init__(self, ignore_last_label=False, **kwargs):
"""ignore_last_label:定义要不要忽略最后一个标签,起到mask的效果
"""
self.ignore_last_label = 1 if ignore_last_label else 0
super(CRF, self).__init__(**kwargs)
def build(self, input_shape):
self.num_labels = input_shape[-1] - self.ignore_last_label
self.trans = self.add_weight(name='crf_trans',
shape=(self.num_labels, self.num_labels),
initializer='glorot_uniform',
trainable=True)
def log_norm_step(self, inputs, states):
"""递归计算归一化因子
要点:1、递归计算;2、用logsumexp避免溢出。
技巧:通过expand_dims来对齐张量。
"""
states = K.expand_dims(states[0], 2) # previous
inputs = K.expand_dims(inputs, 2) # 这个时刻的对标签的打分值,Emission score
trans = K.expand_dims(self.trans, 0) # 转移矩阵 output = K.logsumexp(states+trans+inputs, 1) # e 指数求和,log是防止溢出
return output, [output] def path_score(self, inputs, labels):
"""计算目标路径的相对概率(还没有归一化)
要点:逐标签得分,加上转移概率得分。
技巧:用“预测”点乘“目标”的方法抽取出目标路径的得分。
"""
# 在CRF中涉及到标签得分加上转移概率,而这个point score就是相当于是标签得分(在真是标签的情况下,查看预测对于真实标签位置的总得分),因为labels的shape是[B, T, N],而在N这个维度是one-hot,
# 这里再乘以pred,相当于是对labels存在1的地方进行打分,其余地方全为0,再进行第2个维度相加表示去除0的值,再相加表示求一个总的标签得分
point_score = K.sum(K.sum(inputs*labels, 2), 1, keepdims=True) # 逐标签得分, shape [B, 1]
labels1 = K.expand_dims(labels[:, :-1], 3) # shape [B, T-1, N, 1]
labels2 = K.expand_dims(labels[:, 1:], 2) # shape [B, T-1, 1, N]
# 这里相乘的目的相当于从上一时刻转移到当前时刻,确定当前时刻是从上一时刻哪一个标签转移过来的,因为labels是one-hot的形式,所以在最后两个维度只有1个元素为1,其他全部为0,表示转移标志
labels = labels1 * labels2 # 两个错位labels,负责从转移矩阵中抽取目标转移得分 shape [B, T-1, N, N]
trans = K.expand_dims(K.expand_dims(self.trans, 0), 0)
# K.sum(trans*labels, [2, 3]),因为trans*labels的结果是[B, T-1, N, N], 而后面两个维度中只有1个有值,表示转移得分
trans_score = K.sum(K.sum(trans*labels, [2, 3]), 1, keepdims=True) # 求出所有T-1时刻的概率转移总得分,K.sum(trans*labels, [2, 3]), 表示每个时刻的转移得分
return point_score+trans_score # 两部分得分之和 def call(self, inputs): # CRF本身不改变输出,它只是一个loss
return inputs def loss(self, y_true, y_pred): # 目标y_pred需要是one hot形式
mask = 1-y_true[:, 1:, -1] if self.ignore_last_label else None
y_true, y_pred = y_true[:, :, :self.num_labels], y_pred[:, :, :self.num_labels]
init_states = [y_pred[:, 0]] # 初始状态
log_norm, _, _ = K.rnn(self.log_norm_step, y_pred[:, 1:], init_states, mask=mask) # 计算Z向量(对数) shape[batch_size, output_dim]
log_norm = K.logsumexp(log_norm, 1, keepdims=True) # 计算Z(对数)shape [batch_size, 1] 计算一个总的
path_score = self.path_score(y_pred, y_true) # 计算分子(对数)
return log_norm - path_score # 即log(分子/分母) def accuracy(self, y_true, y_pred): # 训练过程中显示逐帧准确率的函数,排除了mask的影响
mask = 1-y_true[:,:,-1] if self.ignore_last_label else None
y_true,y_pred = y_true[:,:,:self.num_labels],y_pred[:,:,:self.num_labels]
isequal = K.equal(K.argmax(y_true, 2), K.argmax(y_pred, 2))
isequal = K.cast(isequal, 'float32')
if mask == None:
return K.mean(isequal)
else:
return K.sum(isequal*mask) / K.sum(mask)

CRF keras代码实现的更多相关文章

  1. 从 python 中 axis 参数直觉解释 到 CNN 中 BatchNorm 的工作方式(Keras代码示意)

    1. python 中 axis 参数直觉解释 网络上的解释很多,有的还带图带箭头.但在高维下是画不出什么箭头的.这里阐述了 axis 参数最简洁的解释. 假设我们有矩阵a, 它的shape是(4, ...

  2. 深度学习(七)U-Net原理以及keras代码实现医学图像眼球血管分割

    原文作者:aircraft 原文链接:https://www.cnblogs.com/DOMLX/p/9780786.html DRIVE数据集下载百度云链接:链接:https://pan.baidu ...

  3. 大数据开发之keras代码框架应用

    总体来讲keras这个深度学习框架真的很“简易”,它体现在可参考的文档写的比较详细,不像caffe,装完以后都得靠技术博客,keras有它自己的官方文档(不过是英文的),这给初学者提供了很大的学习空间 ...

  4. Keras代码超详细讲解LSTM实现细节

    1.首先我们了解一下keras中的Embedding层:from keras.layers.embeddings import Embedding: Embedding参数如下: 输入尺寸:(batc ...

  5. 条件随机场CRF原理介绍 以及Keras实现

    本文是对CRF基本原理的一个简明的介绍.当然,“简明”是相对而言中,要想真的弄清楚CRF,免不了要提及一些公式,如果只关心调用的读者,可以直接移到文末. 图示# 按照之前的思路,我们依旧来对比一下普通 ...

  6. 到底该如何入门Keras、Theano呢?(浅谈)

    目前刚刚开始学习Theano,可以说是一头雾水,后来发现Keras是对Theano进行了包装,直接使用Keras可以减少很多细节程序的书写,它是模块儿化的,使用比较方便,但更为细节的内容,还没有理解, ...

  7. Keras 学习之旅(一)

    软件环境(Windows): Visual Studio Anaconda CUDA MinGW-w64 conda install -c anaconda mingw libpython CNTK ...

  8. Inception模型和Residual模型卷积操作的keras实现

    Inception模型和Residual残差模型是卷积神经网络中对卷积升级的两个操作. 一.  Inception模型(by google) 这个模型的trick是将大卷积核变成小卷积核,将多个卷积核 ...

  9. Keras官方中文文档:序贯模型

    快速开始序贯(Sequential)模型 序贯模型是多个网络层的线性堆叠,也就是"一条路走到黑". 可以通过向Sequential模型传递一个layer的list来构造该模型: f ...

随机推荐

  1. [考试反思]1112csp-s模拟测试112:二返

    连着两场... 信心赛.但是题锅了,我也锅了. 然后Day2就不用考了. T1没开够long long.(a+b+c+0ll)与(0ll+a+b+c)还是有一点区别的. T2出题人用Windows出数 ...

  2. JS Proxy(代理)

    前言 Proxy 也就是代理,可以帮助我们完成很多事情,例如对数据的处理,对构造函数的处理,对数据的验证,说白了,就是在我们访问对象前添加了一层拦截,可以过滤很多操作,而这些过滤,由你来定义. 想了解 ...

  3. java并发编程-12个原子类

    背景 多线程更新变量的值,可能得不到预期的值,当然增加syncronized关键字可以解决线程并发的问题. 这里提供另外一种解决问题的方案,即位于 java.util.concurrent.atomi ...

  4. QLineEdit限制数据类型——只能输入浮点型数

    前言 最近做了一个小的上位机,要通过串口来下发几个时间参数,为了防止误输入,产生不必要的麻烦,我把输入范围限制在0-680的浮点型数据,支持小数点后2位.学习了一下QLineEdit类是如何限制输入类 ...

  5. java获取当前年份、月份和日期字符串等

    Java获取当前年份.月份和日期是通过Calendar类的实例对象来获取的. 首先创建一个Calendar类的实例对象,Calendar类属于java.util包. Calendar calendar ...

  6. Unix 开发中的 Make 三连

    Unix 开发过程中,经常性的操作是从源码编译安装相应库文件,所以下面三个命令便是家常便饭,俗称三连: ./configure make make install 下面来看看这三步分别做了什么. co ...

  7. 在Asp.Net Core中配置使用MarkDown富文本编辑器实现图片上传和截图上传(开源代码.net core3.0)

    我们的富文本编辑器不能没有图片上传尤其是截图上传,下面我来教大家怎么实现MarkDown富文本编辑器截图上传和图片上传. 1.配置编辑器到html页 <div id="test-edi ...

  8. PlayJava Day021

    容器: Collection接口:定义了存取一组对象的方法,其子接口Set和List分别定义了存储方式 List:存储数据有序且可重复 ----> ArrayList Set:存储数据无序且不可 ...

  9. C#中获取多个对象list中对象共有的属性项

    场景 有一组数据list<TestDataList> 每一个TestDataList是一个对象,此对象可能有温度数据,也可能没有温度数据. 有温度数据的情况下,温度数据属性又是一个list ...

  10. vue jsx与render的区别及基本使用

    vue template语法简单明了,数据操作与视图分离,开发体验友好.但是在某些特定场合中,会限制一些功能的扩展,如动态使用过滤器.解析字符串类型的模板文件等.以上功能的实现可以借助vue的rend ...