https://blog.csdn.net/left_think/article/details/76370453

1. 背景介绍
  在传统的语音识别的模型中,我们对语音模型进行训练之前,往往都要将文本与语音进行严格的对齐操作。这样就有两点不太好:

严格对齐要花费人力、时间。
严格对齐之后,模型预测出的label只是局部分类的结果,而无法给出整个序列的输出结果,往往要对预测出的label做一些后处理才可以得到我们最终想要的结果。
  虽然现在已经有了一些比较成熟的开源对齐工具供大家使用,但是随着deep learning越来越火,有人就会想,能不能让我们的网络自己去学习对齐方式呢?因此CTC(Connectionist temporal classification)就应运而生啦。

  想一想,为什么CTC就不需要去对齐语音和文本呢?因为CTC它允许我们的神经网络在任意一个时间段预测label,只有一个要求:就是输出的序列顺序只要是正确的就ok啦~这样我们就不在需要让文本和语音严格对齐了,而且CTC输出的是整个序列标签,因此也不需要我们再去做一些后处理操作。

  对一段音频使用CTC和使用文本对齐的例子如下图所示:

2. 从输出到标签
2.1符号的表示
  接下来,我们要对一些符号的定义进行介绍。由于水平有限,看这部分定义介绍的时候绕在里面很久,可能有些理解有误,还恳请各位大大及时指出~

ytkykt:代表输出序列在第t步的输出为k的概率。举个简单的例子:当输出的序列为(a-ab-)时,y3aya3 代表了在第3步输出的字母为a的概率;

p(π∣x)p(π∣x):代表了给定输入x,输出路径为 ππ 的概率;

由于假设在每一个时间步输出的label的概率都是相互独立的,那么 p(π∣x)p(π∣x) 用公式来表示为 p(π∣x)=∏Tt=1(ytk)p(π∣x)=∏t=1T(ykt),可以理解为每一个时间步输出路径 ππ 的相应label的概率的乘积。

FF:代表一种多对一的映射,将输出路径 ππ 映射到 标签序列 ll 的一种变换

举个简单的例子 F(a−ab−)=F(−aa−−abb)=aabF(a−ab−)=F(−aa−−abb)=aab (其中-代表了空格)

p(l∣x)p(l∣x) :代表给定输入x,输出为序列 ll 的概率。

因此输出的序列为 ll 的概率可以表示为所有输出的路径 ππ 映射后的序列为 ll 的概率之和,用公式表示为 p(l∣x)=∑π∈F−1(l)p(π∣x)p(l∣x)=∑π∈F−1(l)p(π∣x)
2.2 空格的作用
  在最开始的CTC设定中是没有空格的,FF 只是简单的移除了连续的相同字母。但是这样会产生两个问题:

无法预测出连续两个相同的字母的单词了,比如说hello这个单词,在CTC中会删除掉连续相同的字母,因此CTC最后预测出的label应该是helo;
无法预测出一句完整的话,而只能预测单个的单词。因为缺乏空格,CTC无法表示出单词与单词之间停顿的部分,因此只能预测出单个单词,或者将一句话中的单词全部连接起来了;
因此,空格在CTC中的作用还是十分重要的。

3. 前向传播与反向传播
3.1前向传播
  在对符号做了一些定义之后,我们接下来看看CTC的前向传播的过程。我们前向传播就是要去计算 p(l∣x)p(l∣x)。由于一个序列 ll 通常可以有多条路径经过映射后得到,而随着序列 ll 长度的增加,相对应的路径的数目是成指数增加的,因此我们需要一种高效的算法来计算它。

  有一种类似于HMM的前向传播的算法可以帮助我们来解决这个问题。它的key就是那些与序列 ll 对应的路径概率都可以通过迭代来计算得出。

  在进行计算之前,我们需要对序列 ll 做一些预处理,在序列 ll 的开头与结尾分别加上空格,并且在字母与字母之间都添加上空格。如果原来序列 ll 的长度为U,那么预处理之后,序列 l′l′ 的长度为2U+1 。

  对于一个特定的序列 ll ,我们定义前向变量 α(t,u)α(t,u) 为输出所有长度为 tt ,且经过 FF 映射之后为序列 ll 的路径的概率之和,用公式表达如下所示:

α(t,u)=∑π∈V(t,u)∏ti=1yiπiα(t,u)=∑π∈V(t,u)∏i=1tyπii
其中,V(t,u)={π∈A′t:F(π)=l1:u/2,πt=l′u}V(t,u)={π∈A′t:F(π)=l1:u/2,πt=l′u} 代表了所有满足经过 FF 映射之后为序列 ll ,长度为t的路径集合,且在第t时间步的输出为label: l′ul′u。

  所有正确路径的开头必须是空格或者label l1l1,因此存在着初始化的约束条件:
α(1,1)=y1bα(1,1)=yb1
α(1,2)=y1l1α(1,2)=yl11
α(1,u)=0,∀u>2α(1,u)=0,∀u>2
也就是当路径长度为1时,它只可能对应到空格或者序列 ll 的第一个label,不可能对应到序列 ll 第一个之后的label中。

  因此,p(l∣x)p(l∣x) 可以由前向变量来表示,即为
p(l∣x)=α(T,U′)+α(T,U′−1)p(l∣x)=α(T,U′)+α(T,U′−1)
其中α(T,U′)α(T,U′)可以理解为所有路径长度为T,经过 FF 映射之后为序列 ll ,且第T时刻的输出的label为:l′Ul′U 或者 l′U−1l′U−1。也就是路径的最后一个是否包括了空格。

  怎么去理解它呢?我们不妨先看看它的递归图

上图中,白色的点表示一个label,黑色的点表示空格,纵向每一列表示的是路径的长度T(或者时刻T?),箭头代表了路径下一个时刻可以输出到哪个label去。如果在时刻 1 的 label 为空格,那么路径在下一时刻只有两个选择,第一个还是输出空格,第二个就是输出序列 ll 中对应的空格的下一个label:C;如果在时刻2的 label 为 C,那么在时刻3,它可以有三种选择:第一种就是输出还是 C,第二种是输出为空格,第三种是直接输出A。

  从上图可以看出长度为T的输出路径映射到序列 l:catl:cat, 可以由第T步为label:T的所有路径和第T步为空格的所有路径的概率之和来表示(注意:并不是所有以空格或者T结束的路径都是对的,这里路径是有限制的,不要忘了初始条件的限制哦)。

  现在我们要来引出它的递推公式啦,具体公式如下所示:
α(t,u)=ytl′u∑ui=f(u)α(t−1,i)α(t,u)=yl′ut∑i=f(u)uα(t−1,i)
其中
f(u)={u−1,u−2,if l′u=blank or l′u−2=l′uotherwisef(u)={u−1,if l′u=blank or l′u−2=l′uu−2,otherwise
  如何理解这个递推公式呢,很简单,我们可以看上面递推图,就以时刻T为空格的前向变量为例,由于我们之前讲过了如果当前时刻的输出为空格,下一时刻路径输出只有两种可能性,而如果我们当前时刻是空格,上一时刻的输出从图中可以看出也是由两种可能性,一种是在T-1时刻输出为空格,另外一种是在T-1时刻输出为T。因此我们只要计算出T-1时刻输出为空格的所有正确路径的概率之和以及在T-1时刻输出为T的所有路径的概率之和,再乘上T时刻输出为空格的概率 yTl′uyl′uT,就可以得到前向变量 α(t,u)α(t,u) 啦。时刻T为label:T的前向变量的求法和空格的类似,只是它由三种可能情况求和再乘上 yTl′uyl′uT 得到的。

3.2反向传播
  与前向传播类似,我们首先定义一个反向变量 β(t,u)β(t,u),它的含义是从t+1时刻开始,在前向变量 α(t,u)α(t,u) 上添加路径 π′π′,使得最后通过 FF 映射之后为序列 ll 的概率之和,用公式表示为:
β(t,u)=∑π∈W(t,u)∏T−ti=1yt+iπiβ(t,u)=∑π∈W(t,u)∏i=1T−tyπit+i
其中W(t,u)={π∈A′T−t:F(π′+π)=l,∀π′∈V(t,u)}W(t,u)={π∈A′T−t:F(π′+π)=l,∀π′∈V(t,u)}
  按照前向传播的图举例说明:假设我们在T-2时刻路径输出为label:A,那么此时的反向变量的求法就是在T-2时刻开始,所有能到达T时刻输出为空格或者label:T的“剩余”路径 π′π′ 的概率之和。

反向传播也有相对应的初始化条件:
β(T,U′)=β(T,U′−1)=1β(T,U′)=β(T,U′−1)=1
β(T,u′)=0,∀u′<U′−1β(T,u′)=0,∀u′<U′−1
它的递推公式如下所示
β(t,u)=∑g(u)i=uβ(t+1,i)yt+1l′iβ(t,u)=∑i=ug(u)β(t+1,i)yli′t+1
其中
g(u)={u−1,u−2,if l′u=blank or l′u−2=l′uotherwiseg(u)={u−1,if l′u=blank or l′u−2=l′uu−2,otherwise
3.3对数运算
  不论是在计算前向变量还是反向变量时,都涉及到了大量的概率的乘积。由于这些乘积都是小于1的,在大量的小数相乘时,最后得到的结果往往都会趋向于0,更严重的是产生underflow。因此在计算时对其做了取对数的处理,这样乘法就会转化为加法了,不仅避免了underflow,还简化了计算。但是,原来的加法计算就不是太方便了。不过这里有一个数学的trick:
ln(a+b)=lna+ln(1+elnb−lna)ln(a+b)=lna+ln(1+elnb−lna)
4.损失函数
  CTC的损失函数定义如下所示
L(S)=−ln∏(x,z)∈Sp(z|x)=−∑(x,z)∈Slnp(z|x)L(S)=−ln∏(x,z)∈Sp(z|x)=−∑(x,z)∈Slnp(z|x)
其中 p(z|x)p(z|x) 代表给定输入x,输出序列 zz 的概率,S为训练集。损失函数可以解释为:给定样本后输出正确label的概率的乘积(这里个人不理解为啥要做乘积运算,求和的话不应该好解释一点么?可能是因为要取对数运算,求和可能不太方便,所以是做乘积运算),再取负对数就是损失函数了。取负号之后我们通过最小化损失函数,就可以使输出正确的label的概率达到最大了。

  由于上述定义的损失函数是可微的,因此我们可以求出它对每一个权重的导数,然后就可以使用什么梯度下降、Adam之类的算法来进行优化求解啦~

  下面我们就要把上一节定义的前向变量与反向变量用到我们的损失函数中去,让序列 l=zl=z,定义一个新的集合 X(t,u)={π∈A′T:F(π)=z,πt=z′u}X(t,u)={π∈A′T:F(π)=z,πt=zu′} , X(t,u)X(t,u) 代表了在时刻t经过label:l′ulu′ 的所有路径的集合,这样由之前对前向变量与反向变量的定义,它俩的乘积就可以写成:
α(t,u)β(t,u)=∑π∈X(t,u)∏Tt=1ytπtα(t,u)β(t,u)=∑π∈X(t,u)∏t=1Tyπtt
而 p(π∣x)=∏Tt=1(ytk)p(π∣x)=∏t=1T(ykt),因此进一步转化可以得到
α(t,u)β(t,u)=∑π∈X(t,u)p(π|x)α(t,u)β(t,u)=∑π∈X(t,u)p(π|x)
因此,对于任意的时刻t,我们给定输入x,输出序列 zz 的概率可以表示成
p(z∣x)=∑|z′|u=1α(t,u)β(t,u)p(z∣x)=∑u=1|z′|α(t,u)β(t,u)
也就是在任意一个时刻分开,前向变量与反向变量的乘积为在该时刻经过label:l′ulu′ 的所有概率之和,然后再遍历了序列 l′l′ 的每一个label,因此就得到了所有输出为序列 l′l′ 的概率之和。

  损失函数就可以进一步转化为
L(x,z)=−ln∑|z′|u=1α(t,u)β(t,u)L(x,z)=−ln∑u=1|z′|α(t,u)β(t,u)
4.1损失函数梯度计算
  损失函数关于网络输出 ytkykt 的偏导数为:
∂L(x,z)∂ytk=−∂lnp(x|z)∂ytk=−1p(x|z)∂p(x|z)∂ytk∂L(x,z)∂ykt=−∂lnp(x|z)∂ykt=−1p(x|z)∂p(x|z)∂ykt
而 p(z∣x)=∑|z′|u=1α(t,u)β(t,u)=∑π∈X(t,u)∏Tt=1ytπtp(z∣x)=∑u=1|z′|α(t,u)β(t,u)=∑π∈X(t,u)∏t=1Tyπtt,我们记label:k出现在序列 z′z′ 的所有路径的集合为B(z,k)={u:z′u=k}B(z,k)={u:zu′=k},因此可以得出
∂α(t,u)β(t,u)∂ytk={α(t,u)β(t,u)ytk,0,if k occurs in z'otherwise∂α(t,u)β(t,u)∂ykt={α(t,u)β(t,u)ykt,if k occurs in z'0,otherwise
因此损失函数关于输出的偏导数可以写为
∂L(x,z)∂ytk=−1p(x|z)∂p(x|z)∂ytk=−1p(x|z)ytk∑u∈B(z,k)α(t,u)β(t,u)∂L(x,z)∂ykt=−1p(x|z)∂p(x|z)∂ykt=−1p(x|z)ykt∑u∈B(z,k)α(t,u)β(t,u)
最后,我们可以通过链式法则,得到损失函数对未经过sofmax层的网络输出的 atkakt 的偏导数:
∂L(x,z)∂atk=−∑k′∂L(x,z)∂ytk′∂ytk′∂atk′∂L(x,z)∂akt=−∑k′∂L(x,z)∂yk′t∂yk′t∂ak′t
又有
ytk=eatk∑k′eatkykt=eakt∑k′eakt
因此可以得到损失函数对未经过sofmax层的网络输出的 atkakt 的偏导数:
∂L(x,z)∂atk=ytk−1p(x|z)∑u∈B(z,k)α(t,u)β(t,u)∂L(x,z)∂akt=ykt−1p(x|z)∑u∈B(z,k)α(t,u)β(t,u)
5.参考文献
1.《Supervised Sequence Labelling with Recurrent Neural Networks》 chapter7

2. http://blog.csdn.net/xmdxcsj/article/details/51763886

CTC Loss原理的更多相关文章

  1. CTC loss 理解

    参考文献 CTC学习笔记(一) 简介:https://blog.csdn.net/xmdxcsj/article/details/51763868 CTC学习笔记(二) 训练和公式推导 很详细的公示推 ...

  2. 以lstm+ctc对汉字识别为例对tensorflow 中的lstm,ctc loss的调试

    #-*-coding:utf8-*- __author = "buyizhiyou" __date = "2017-11-21" ''' 单步调试,结合汉字的识 ...

  3. 记CTC原理

    CTC,Connectionist temporal classification.从字面上理解它是用来解决时序类数据的分类问题.语音识别端到端解决方案中应用的技术.主要是解决以下两个问题 解决语音输 ...

  4. CTC 的工作原理

    CTC 的工作原理     Fig. 1. How CTC  combine a word (source: https://distill.pub/2017/ctc/) 这篇文章主要解释CTC 的工 ...

  5. 怎样在caffe中添加layer以及caffe中triplet loss layer的实现

    关于triplet loss的原理.目标函数和梯度推导在上一篇博客中已经讲过了.详细见:triplet loss原理以及梯度推导.这篇博文主要是讲caffe下实现triplet loss.编程菜鸟.假 ...

  6. tensorflow源码分析——CTC

    CTC是2006年的论文Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurren ...

  7. [免费下载应用]iNeuKernel.Ocr 图像数据识别与采集原理和产品化应用

    目       录 1..... 应用概述... 2 2..... 免费下载试用... 2 3..... 视频介绍... 2 4..... iNeuLink.Ocr图像数据采集应用... 2 5... ...

  8. CRNN网络结构详解

    目录 一. CRNN概论 简介 网络 二. CRNN局部之特征提取 三. CRNN局部之BLSTM 四. CRNN局部之CTC 关于CTC是什么东西? CTC理论基础 五. 参考文献 一. CRNN概 ...

  9. 【OCR技术系列之八】端到端不定长文本识别CRNN代码实现

    CRNN是OCR领域非常经典且被广泛使用的识别算法,其理论基础可以参考我上一篇文章,本文将着重讲解CRNN代码实现过程以及识别效果. 数据处理 利用图像处理技术我们手工大批量生成文字图像,一共360万 ...

随机推荐

  1. Linux升级nodejs及多版本管理

    最近要用到开发要用到nodejs,于是跑到开发机运行了下node,已经安装了,深感欣慰,是啥版本呢?再次运行了下node -v,原来是0.6.x的.估计是早先什么时候谁弄的.那么来升级下node吧. ...

  2. B树、B-树、B+树、B*树都是什么

    B树.B-树.B+树.B*树都是什么 B树 即二叉搜索树: 1.所有非叶子结点至多拥有两个儿子(Left和Right): 2.所有结点存储一个关键字: 3.非叶子结点的左指针指向小于其关键字的子树,右 ...

  3. 《DSP using MATLAB》Problem 2.19

    代码: %% ------------------------------------------------------------------------ %% Output Info about ...

  4. TCP/IP详解与OSI七层模型

    TCP/IP协议 包含了一系列构成互联网基础的网络协议,是Internet的核心协议.基于TCP/IP的参考模型将协议分成四个层次,它们分别是链路层.网络层.传输层和应用层.下图表示TCP/IP模型与 ...

  5. (转)性能分析之-- JAVA Thread Dump 分析综述

    原文链接:http://blog.csdn.net/rachel_luo/article/details/8920596 最近在做性能测试,需要对线程堆栈进行分析,在网上收集了一些资料,学习完后,将相 ...

  6. 【转】Notepad++中Windows,Unix,Mac三种格式之间的转换

    原文网址:http://www.crifan.com/files/doc/docbook/rec_soft_npp/release/htmls/npp_func_windows_unix_mac.ht ...

  7. rsync 通过密码文件实现远程同步

    https://my.oschina.net/yyping/blog/91964 1.源文件服务器:192.168.10.203 2.备份服务器:192.168.10.88 配置备份服务器(192.1 ...

  8. 刷新SQL Server所有视图、函数、存储过程 更多 sql 此脚本用于在删除或添加字段时刷新相关视图,并检查视图、函数、存储过程有效性。 [SQL]代码 --视图、存储过程、函数名称 DECLARE @NAME NVARCHAR(255); --局部游标 DECLARE @CUR CURSOR --自动修改未上状态为旷课 SET @CUR=CURSOR SCROLL DYNAMIC FO

    刷新SQL Server所有视图.函数.存储过程 更多   sql   此脚本用于在删除或添加字段时刷新相关视图,并检查视图.函数.存储过程有效性. [SQL]代码 --视图.存储过程.函数名称 DE ...

  9. android 点击返回键 以及 加载activity 生命周期 记录。。。,一目了然

    简叙 Activity 生命周期及android 返回按钮捕捉   @Override protected void onPostCreate(Bundle savedInstanceState) { ...

  10. mac下安装wxPython2.8.12.1方法

    搭建robot_framework 环境 找不到 wxPython2.8.12.1的解决方法 1.mac终端pip安装robotframework-ride后 启动ride.py报: wxPython ...