CTCLoss如何使用
CTCLoss如何使用
什么是CTC
CTC全称为Connectionist Temporal Classification,中文翻译不好类似“联结主义按时间分类”。
CTCLoss是一类损失函数,用于计算模型输出\(y\)和标签\(label\)的损失。
\]
神经网络在训练过程中,是让\(loss\)减少的过程。常用于图片文字识别OCR和语音识别项目,因为CTCLoss计算过程中不需要\(y\)和\(label\)对齐,这样做的好处就是大幅的减轻了数据对齐标注的工作量,极大的提高了效率。
架构介绍
本文主要是介绍CTCLoss,这里介绍模型架构是为了更好的理解CTCLoss函数在整体的做用。现有一段原始数据,它可以是一张带文字的图片或一段说话的音频。
如图所示原始的声音通过DFT(离散傅立叶变化)得到一张具有时频特性的特征图,将特征图通过网络\(\mathcal{N}_w\)后输出结果\(y\)(\(y\in\mathbb{R}^{K \times T}\),\(K\)维是在每一时间点预测的词的概率,\(T\)是时间维度)。
一个简单的例子
现在有一段语音,是一个人在拼写英文单词“CAT”,语音内容是“C”、“A”、“T”这三个字母。这个人读完这三个字母用了5s的时间。我们想通过语音识别这三个字母。
首先我们需要一个26个字母的词表,我们用序号1-26,分别来表示字母A-Z这26个字母,我们用序号0表示blank。blank是用来区分那些不属于这26字母的部分。然后是假设这个模型每秒会给出一个识别字母表的概率分布,
音频持续了5s,因此有5列这样的概率分布。
\]
下表就是\(y\)的概率分布,每一列是当前时刻输入数据所对应的概率分布。
\(y_t^k\) | t=1 | t=2 | t=3 | t=4 | t=5 |
---|---|---|---|---|---|
k=0 (-) | 0.031953 | 0.044296 | 0.038297 | 0.038320 | 0.027464 |
k=1 (A) | 0.026221 | 0.030363 | 0.031878 | 0.027295 | 0.029824 |
k=2 (B) | 0.040555 | 0.025838 | 0.023487 | 0.041529 | 0.028116 |
k=3 (C) | 0.029333 | 0.045889 | 0.031872 | 0.023184 | 0.029338 |
k=4 (D) | 0.023595 | 0.053792 | 0.022519 | 0.039882 | 0.025342 |
k=5 (E) | 0.048014 | 0.028887 | 0.020526 | 0.041302 | 0.045833 |
k=6 (F) | 0.028770 | 0.040735 | 0.045488 | 0.044244 | 0.032191 |
k=7 (G) | 0.035127 | 0.032281 | 0.034032 | 0.051973 | 0.041613 |
k=8 (H) | 0.044897 | 0.047910 | 0.049222 | 0.056956 | 0.048665 |
k=9 (I) | 0.032323 | 0.044911 | 0.038994 | 0.046017 | 0.040002 |
k=10 (J) | 0.047130 | 0.024608 | 0.034797 | 0.038146 | 0.041496 |
k=11 (K) | 0.033491 | 0.049294 | 0.043909 | 0.053962 | 0.037901 |
k=12 (L) | 0.044700 | 0.056019 | 0.046794 | 0.038094 | 0.027488 |
k=13 (M) | 0.045632 | 0.034822 | 0.052229 | 0.021692 | 0.039653 |
k=14 (N) | 0.035123 | 0.050406 | 0.019438 | 0.024067 | 0.056986 |
k=15 (O) | 0.023015 | 0.037482 | 0.046163 | 0.050536 | 0.058191 |
k=16 (P) | 0.031419 | 0.024302 | 0.035848 | 0.034614 | 0.031820 |
k=17 (Q) | 0.034497 | 0.025424 | 0.052284 | 0.049642 | 0.029912 |
k=18 (R) | 0.029572 | 0.031274 | 0.032931 | 0.026295 | 0.042725 |
k=19 (S) | 0.027484 | 0.044015 | 0.031383 | 0.037050 | 0.046068 |
k=20 (T) | 0.051330 | 0.047532 | 0.043297 | 0.040039 | 0.036849 |
k=21 (U) | 0.034691 | 0.045869 | 0.024400 | 0.022020 | 0.029838 |
k=22 (V) | 0.054835 | 0.028627 | 0.031971 | 0.039436 | 0.062661 |
k=23 (W) | 0.033373 | 0.035513 | 0.047827 | 0.030642 | 0.026361 |
k=24 (X) | 0.048700 | 0.022777 | 0.034515 | 0.022410 | 0.026991 |
k=25 (Y) | 0.033561 | 0.023278 | 0.045237 | 0.034797 | 0.027990 |
k=26 (Z) | 0.050657 | 0.023858 | 0.040665 | 0.025854 | 0.028682 |
上面的例子已经给出了网络\(\mathcal{N}_w\)输出\(y\)的描述,与这段音频所对应的标签\(label\),应该是‘C’、‘A’、‘T’这三个字母,将它转换成用字母表中序号表示
\]
CTC计算的推导
在论文中CTCLoss的计算公式为
\]
那上面这个公式表示的含义是什么呢?
符号\(S\)一个训练样本集合,它是总体分布的一个子集。
\(x,z \in S\),\(x\)是训练样本集合\(S\)中原始的数据经过网络\(\mathcal{N}_w\)后的输出,\(z\)是与\(x\)相对应的标签。
\(p(z|x)\)表示以\(\mathcal{N}_w\)的输出\(x\),将\(x\)恢复为标签\(z\)的概率,也就是\(z\)相对于\(x\)的条件概率。
这样将样本集合\(S\)中每一条样本的\(p(z|x)\)相乘,就是样本\(S\)对于\(\mathcal{N}_w\)似然函数:
\[L(S,\mathcal{N}_w)=\prod_{x,z \in S}{p(z|x)}
\]我们通过训练调整网络\(\mathcal{N}_w\)的参数\(w\),使\(ln{(L(S,\mathcal{N}_w))}\)最大,这个过程就叫最大似然估计。
为了方便计算,我们在等式两边取\(ln\),这就是对数似然函数。
\[ln{(L(S,\mathcal{N}_w))}=\sum_{x,z \in S}{ln{(p(z|x))}}
\]因为似然函数是越大表示结果越好,而损失函数是越小则表示结果越好所以需要一个负号
\[O^{ML}(S,\mathcal{N}_w)=-ln{(S,\mathcal{N}_w)}=-\sum_{x,z \in S}ln(p(z|x))
\]
总概率\(p(z|x)\)
CTCLoss中最关键的就是计算每一条样本\({\{x,z\}} \in S\)的条件概率\(p(z|x)\),\(z\)是目标标签与\(x\)是一一对应关系,\(l\)是任意标签只要是符合字母表规则的标签都是可以的,而\(z\)只是符合\(l\)规则中的一条。在训练的时候可以指定\(l=z\),但在公式推导时应该更严谨更泛化一些。因此\(p(z|x)\)可以用作\(p(l|x)\)替代,下面给出\(p(l|x)\)的计算公式
\]
路径的含义
已知网络\(\mathcal{N}_w\)的输出\(x\in\mathbb{R}^{K \times T}\),它有\(T\)个时间点,并在每个时间点中有\(K\)种输出的可能,一共有\(K^ T\)条路径。在上面的例子中\(K=27,T=5\)所以一共就有\(27^5=14348907\)条可能的路径。仅仅\(T=5\)时,总路径条数已经相当的巨大了。
路径概率\(p(\pi|x)\)
表1已经给出于每个时刻所有的字母概率,由每个时刻选出的字母将组成一条路径,那么这条路径的概率就等于各个时刻选择字母的概率的乘积。
p(\pi|x)&=\prod_{t=1}^{T}{y_{k=\pi^t}^t} \\
&=y_{k=\pi^1}^1\times y_{k=\pi^2}^2\times y_{k=\pi^3}^3\times...\times y_{k=\pi^T}^T
\end{aligned}\]
什么是\(\mathcal{B}\)变换
在上面提到的\(27^5\)条路径中\(\mathcal{B}\)变换就是将路径中所有的blank\((-)\),和相邻重复的元素删除,比如
\]
\]
同理符号\(\mathcal{B}^{-1}(l)\)则是\(\mathcal{B}(\pi)\)的逆变换。表示所有满足\(\mathcal{B}(\pi)=l\)的路径
\(p(l|x)\)并不是计算所有路径的概率之和,而是计算所有满足\(\mathcal{B}(\pi)=l\)的路径概率之和。
一步一步手动计算CTCLoss
现在就根据上面提供的例子,一步一步手动计算CTCLoss
找出所有满足\(\mathcal{B}(\pi)=l\),\(l\)=“CAT”的路径
在上面给出的\(27^5\)条路径中给出的符合\(\mathcal{B}(\pi)=l\),\(l\)=“CAT”共有28条,
如表2所示
表2 所有满足条件的路径,共28条
t=1 | t=2 | t=3 | t=4 | t=5 | |
---|---|---|---|---|---|
\(\pi_{1}\) | - | - | C | A | T |
\(\pi_{2}\) | - | C | - | A | T |
\(\pi_{3}\) | - | C | C | A | T |
\(\pi_{4}\) | - | C | A | - | T |
\(\pi_{5}\) | - | C | A | A | T |
\(\pi_{6}\) | - | C | A | T | - |
\(\pi_{7}\) | - | C | A | T | T |
\(\pi_{8}\) | C | - | - | A | T |
\(\pi_{9}\) | C | - | A | - | T |
\(\pi_{10}\) | C | - | A | A | T |
\(\pi_{11}\) | C | - | A | T | - |
\(\pi_{12}\) | C | - | A | T | T |
\(\pi_{13}\) | C | C | - | A | T |
\(\pi_{14}\) | C | C | C | A | T |
\(\pi_{15}\) | C | C | A | - | T |
\(\pi_{16}\) | C | C | A | A | T |
\(\pi_{17}\) | C | C | A | T | - |
\(\pi_{18}\) | C | C | A | T | T |
\(\pi_{19}\) | C | A | - | - | T |
\(\pi_{20}\) | C | A | - | T | - |
\(\pi_{21}\) | C | A | - | T | T |
\(\pi_{22}\) | C | A | A | - | T |
\(\pi_{23}\) | C | A | A | A | T |
\(\pi_{24}\) | C | A | A | T | - |
\(\pi_{25}\) | C | A | A | T | T |
\(\pi_{26}\) | C | A | T | - | - |
\(\pi_{27}\) | C | A | T | T | - |
\(\pi_{28}\) | C | A | T | T | T |
计算每条路径的概率\(p(\pi|x)\)
路径\(\pi_1\)所对应的标签为"- - C A T",这段序列转换为字母表中的索引,
则路径\(\pi_1\)在每个时刻的取值如下
\]
\]
\]
\]
\]
因此路径\(\pi_1的概率\)\(p(\pi_1|x)\)的计算如下
p(\pi_1|x)&=\prod_{t=1}^{T}{y_{k=\pi_1^t}^t} \\
&=y_{k=\pi_1^1}^1\times y_{k=\pi_1^2}^2 \times y_{k=\pi_1^3}^3 \times ... \times y_{k=\pi_1^T}^T \\
&=y_{0}^1 \times y_{0}^2 \times y_{3}^3 \times y_{1}^4\times y_{20}^5 \\
&=0.031953 \times0.044296\times0.031872\times0.027295\times0.036849 \\
&=4.5373e^{-8}
\end{aligned}\]
同理可计算
p(\pi_2|x)=5.6482e^{-8},
p(\pi_3|x)=4.7006e^{-8},
p(\pi_4|x)=6.6003e^{-8}\]
p(\pi_6|x)=5.1401e^{-8},
p(\pi_7|x)=6.8965e^{-8},
p(\pi_8|x)=5.0050e^{-8}\]
p(\pi_{10}|x)=4.1660e^{-8},
p(\pi_{11}|x)=4.5547e^{-8},
p(\pi_{12}|x)=6.1111e^{-8}\]
p(\pi_{14}|x)=4.3151e^{-8},
p(\pi_{15}|x)=6.0590e^{-8},
p(\pi_{16}|x)=4.3158e^{-8}\]
p(\pi_{18}|x)=6.3309e^{-8},
p(\pi_{19}|x)=4.8163e^{-8},
p(\pi_{20}|x)=3.7508e^{-8}\]
p(\pi_{22}|x)=4.0090e^{-8},
p(\pi_{23}|x)=2.8556e^{-8},
p(\pi_{24}|x)=3.1220e^{-8}\]
p(\pi_{26}|x)=4.0583e^{-8},
p(\pi_{27}|x)=4.2404e^{-8},
p(\pi_{28}|x)=5.6894e^{-8}\]
计算总概率\(p(l|x)\)
\(p(l|x)\)是所有满足\(\mathcal{B}(\pi)=l\)的路径概率之和。
p(l|x)&=\sum_{\pi \in \mathcal{B}^{-1}(l)}{p(\pi|x)} \\
&=p(\pi_1|x)+p(\pi_2|x)+p(\pi_1|x)+...+p(\pi_{28}|x) \\
&=4.5374e^{-8} + 5.6482e^{-8}+4.7006e^{-8}+...+5.6894e^{-8} \\
&=1.366e^{-6}
\end{aligned}\]
计算损失函数CTCLoss
由于例子中只给了1样本,所以下面的损失函数CTCLoss也就只有这一个样本的损失。
O^{ML}(S,\mathcal{N}_w)&=-ln{(S,\mathcal{N}_w)} \\
&=-\sum_{x,z \in S}ln(p(z|x)) \\
&=-ln(p(z|x)) \\
&=-ln(1.366e^{-6}) \\
&=\ 13.5036
\end{aligned}\]
CTCLoss库函数的验证
网络\(\mathcal{N}_w\)输出\(y\_out\)的softmax处理
这里有一点需要解释一下,CTCLoss的输入\(ctc\_input\)与网络\(\mathcal{N}_w\)的输出\(y\_out\)之间的关系。
在网络\(\mathcal{N}_w\)输出的最后一级是没有softmax,所以\(y\_out\)在每一个时间点的的概率和都不为1,为了将概率分布归一化需要将\(y\)进行softmax计算。
\]
同时CTCLoss中包含有大量的概率的乘法运算,需要将\(y\_softmax\)进行\(ln\)计算,
这样可以将乘法转换为加法计算,提升计算的速度。
\]
上面的例子,为了让文档更直观,已经默认
\]
下表就是\(y\_out\),显然每一列之和不为1。
\(y\_out_t^k\) | t=1 | t=2 | t=3 | t=4 | t=5 |
---|---|---|---|---|---|
k=0 (-) | 0.347713 | 0.755077 | 0.678652 | 0.585987 | 0.123084 |
k=1 (A) | 0.149997 | 0.377396 | 0.495177 | 0.246735 | 0.205494 |
k=2 (B) | 0.586092 | 0.216019 | 0.189710 | 0.666416 | 0.146515 |
k=3 (C) | 0.262145 | 0.790407 | 0.495006 | 0.083483 | 0.189072 |
k=4 (D) | 0.044454 | 0.949304 | 0.147608 | 0.625960 | 0.042652 |
k=5 (E) | 0.754933 | 0.327565 | 0.054974 | 0.660945 | 0.635198 |
k=6 (F) | 0.242785 | 0.671264 | 0.850713 | 0.729752 | 0.281867 |
k=7 (G) | 0.442402 | 0.438645 | 0.560560 | 0.890752 | 0.538597 |
k=8 (H) | 0.687796 | 0.833501 | 0.929609 | 0.982303 | 0.695163 |
k=9 (I) | 0.359228 | 0.768854 | 0.696667 | 0.769029 | 0.499116 |
k=10 (J) | 0.736340 | 0.167254 | 0.582791 | 0.581446 | 0.535801 |
k=11 (K) | 0.394707 | 0.861980 | 0.815397 | 0.928313 | 0.445183 |
k=12 (L) | 0.683416 | 0.989872 | 0.879014 | 0.580090 | 0.123932 |
k=13 (M) | 0.704047 | 0.514423 | 0.988912 | 0.016983 | 0.490357 |
k=14 (N) | 0.442305 | 0.884281 | 0.000522 | 0.120860 | 0.852998 |
k=15 (O) | 0.019578 | 0.588026 | 0.865439 | 0.862711 | 0.873927 |
k=16 (P) | 0.330858 | 0.154752 | 0.612566 | 0.484297 | 0.270294 |
k=17 (Q) | 0.424309 | 0.199863 | 0.989950 | 0.844856 | 0.208461 |
k=18 (R) | 0.270270 | 0.406955 | 0.527680 | 0.209405 | 0.564980 |
k=19 (S) | 0.197054 | 0.748706 | 0.479523 | 0.552291 | 0.640312 |
k=20 (T) | 0.821721 | 0.825584 | 0.801348 | 0.629883 | 0.417029 |
k=21 (U) | 0.429921 | 0.789963 | 0.227843 | 0.031991 | 0.205976 |
k=22 (V) | 0.887771 | 0.318524 | 0.498094 | 0.614713 | 0.947933 |
k=23 (W) | 0.391183 | 0.534064 | 0.900852 | 0.362411 | 0.082071 |
k=24 (X) | 0.769114 | 0.089951 | 0.574661 | 0.049533 | 0.105709 |
k=25 (Y) | 0.396792 | 0.111706 | 0.845178 | 0.489570 | 0.142041 |
k=26 (Z) | 0.808514 | 0.136293 | 0.738640 | 0.192510 | 0.166460 |
pytorch库函数验证
CTCLoss使用细节可以参考pytorch官方文档
import torch
import torch.nn as nn
import numpy as np
y_softmax = np.array([
[[0.0319533345695271, 0.0262210133693412, 0.0405548727460100, 0.0293328834922530, 0.0235946021815836, 0.0480142162870594, 0.0287704618407728, 0.0351268637054168, 0.0448965052477630, 0.0323234212279283, 0.0471297269219778, 0.0334908192070999, 0.0447002788315031, 0.0456320948241136,
0.0351234600906292, 0.0230148922614546, 0.0314192811142228, 0.0344970346892286, 0.0295721871384341, 0.0274843752526059, 0.0513304969210734, 0.0346911732659917, 0.0548353372646645, 0.0333729892573427, 0.0486999624899632, 0.0335606882517763, 0.0506570275502634]],
[[0.0442961938109001, 0.0303627704208565, 0.0258378526020265, 0.0458891577161975, 0.0537920435977104, 0.0288868677848477, 0.0407349328912650, 0.0322806067098565, 0.0479099042067772, 0.0449106925711146, 0.0246080887866719, 0.0492939884049119, 0.0560191619281624, 0.0348218517081914,
0.0504056201105211, 0.0374815087428365, 0.0243023731122621, 0.0254237678526359, 0.0312736688595233, 0.0440148630768450, 0.0475321094768427, 0.0458687788283468, 0.0286268732637606, 0.0355125367928648, 0.0227774801386588, 0.0232784351056503, 0.0238578714997625]],
[[0.0382974368377362, 0.0318777312135849, 0.0234868589674224, 0.0318722744011979, 0.0225185381516373, 0.0205262552943881, 0.0454877627911883, 0.0340316234294017, 0.0492219436202117, 0.0389936131137926, 0.0347966678592871, 0.0439093761642613, 0.0467935124498177, 0.0522292290638150,
0.0194384495697102, 0.0461625681675025, 0.0358483354617907, 0.0522835019782284, 0.0329308772273817, 0.0313826141807340, 0.0432967801742709, 0.0243997674509821, 0.0319708630090250, 0.0478266566415420, 0.0345149265806327, 0.0452367066323343, 0.0406651295681235]],
[[0.0383195501689954, 0.0272951137973125, 0.0415288927451887, 0.0231838517718695, 0.0398823138441169, 0.0413022813256117, 0.0442442329310963, 0.0519730489462436, 0.0569558497142297, 0.0460166028890008, 0.0381459528257684, 0.0539623316283564, 0.0380942573161036, 0.0216922730554261,
0.0240667868142706, 0.0505358960731075, 0.0346143968556499, 0.0496415831760055, 0.0262949856792733, 0.0370498580320465, 0.0400391034751884, 0.0220202876462848, 0.0394362954874324, 0.0306423990773223, 0.0224099657044701, 0.0347974172676594, 0.0258544717519696]],
[[0.0274643882294982, 0.0298236175649923, 0.0281155092606543, 0.0293378537462717, 0.0253418924737544, 0.0458330002578632, 0.0321905618820226, 0.0416126048467898, 0.0486654573566434, 0.0400017201897758, 0.0414964341812715, 0.0379014590893513, 0.0274877024782956, 0.0396528862221281,
0.0569859416555112, 0.0581911831104043, 0.0318201830875284, 0.0299122412570334, 0.0427250763149338, 0.0460679863549903, 0.0368492548068844, 0.0298379764585031, 0.0626610201269008, 0.0263607892806820, 0.0269913345266294, 0.0279900073565483, 0.0286819178841385]]
]).astype("float32")
labels = np.array([[3, 1, 20]]).astype("int32")
input_lengths = np.array([5]).astype("int64")
label_lengths = np.array([3]).astype("int64")
ctc_input = torch.tensor(y_softmax).log()
labels = torch.tensor(labels)
input_lengths = torch.tensor(input_lengths)
label_lengths = torch.tensor(label_lengths)
ctc_loss = nn.CTCLoss(reduction='none')
loss = ctc_loss(ctc_input, labels, input_lengths, label_lengths)
print('loss is {}'.format(loss))
loss is tensor([13.5036])
paddle库函数的使用
CTCLoss使用细节可以参考
paddle官方文档
由于paddle的CTCLoss库底层已经实现了log_softmax,所以它的输入可以直接为\(y\_out\)
import numpy as np
import paddle
import paddle.nn.functional as F
y_out = np.array([
[[0.347712671277525, 0.149997253831683, 0.586092067231462, 0.262145317727807, 0.0444540922782385, 0.754933267231179, 0.242785357820962, 0.442402313001943, 0.687796085120107, 0.359228210401861, 0.736340074301202, 0.394707475278763, 0.683415866967978, 0.704047430334266,
0.442305413383371, 0.0195776235533187, 0.330857880214071, 0.424309496833137, 0.270270423432065, 0.197053798095456, 0.821721184961310, 0.429921409383266, 0.887770954256354, 0.391182995461163, 0.769114387388296, 0.396791517013617, 0.808514095887345]],
[[0.755077099007084, 0.377395544835103, 0.216018915961394, 0.790407217966913, 0.949303911849797, 0.327565434075205, 0.671264370451740, 0.438644982586956, 0.833500595588975, 0.768854252429615, 0.167253545494722, 0.861980478702072, 0.989872153631504, 0.514423456505704,
0.884281023126955, 0.588026055308498, 0.154752348656045, 0.199862822857452, 0.406954837138907, 0.748705718215691, 0.825583815786156, 0.789963029944531, 0.318524245398992, 0.534064127370726, 0.0899506787705811, 0.111705744193203, 0.136292548938299]],
[[0.678652304800188, 0.495177019089661, 0.189710406017580, 0.495005824990221, 0.147608221976689, 0.0549741469061882, 0.850712674289007, 0.560559527354885, 0.929608866756663, 0.696667200555228, 0.582790965175840, 0.815397211477421, 0.879013904597178, 0.988911616079589,
0.000522375356944771, 0.865438591013025, 0.612566469483999, 0.989950205708831, 0.527680069338442, 0.479523385210219, 0.801347605521952, 0.227842935706042, 0.498094291196390, 0.900852488532005, 0.574661219130188, 0.845178185054037, 0.738640291995402]],
[[0.585987035826476, 0.246734525985975, 0.666416217319468, 0.0834828136026227, 0.625959785171583, 0.660944557947342, 0.729751855317221, 0.890752116325322, 0.982303222883606, 0.769029085335896, 0.581446487875398, 0.928313062314188, 0.580090365758442, 0.0169829383372613,
0.120859571098558, 0.862710718699670, 0.484296511212103, 0.844855674576263, 0.209405084020935, 0.552291341538775, 0.629883385064421, 0.0319910157625669, 0.614713419117141, 0.362411462273053, 0.0495325790420612, 0.489569989177322, 0.192510396062075]],
[[0.123083747545945, 0.205494170907680, 0.146514910614890, 0.189072174472614, 0.0426524109111434, 0.635197916859882, 0.281866855880430, 0.538596678045340, 0.695163039444332, 0.499116013482590, 0.535801055751113, 0.445183165296042, 0.123932277598070, 0.490357293468018,
0.852998155340816, 0.873927405861733, 0.270294332292698, 0.208461358751314, 0.564979570738201, 0.640311825162758, 0.417028951642886, 0.205975515532243, 0.947933121293169, 0.0820712070977259, 0.105709426581721, 0.142041121903998, 0.166460440876421]]
]).astype("float32")
labels = np.array([[3, 1, 20]]).astype("int32")
input_lengths = np.array([5]).astype("int64")
label_lengths = np.array([3]).astype("int64")
y_out=paddle.to_tensor(y_out)
labels = paddle.to_tensor(labels)
input_lengths = paddle.to_tensor(input_lengths)
label_lengths = paddle.to_tensor(label_lengths)
loss = paddle.nn.CTCLoss(blank=0, reduction='none')(y_out, labels,
input_lengths,
label_lengths)
print('loss is {}'.format(loss))
loss is Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[13.50364304])
CTCLoss如何使用的更多相关文章
- 语音识别中的CTC算法的基本原理解释
欢迎大家前往腾讯云+社区,获取更多腾讯海量技术实践干货哦~ 本文作者:罗冬日 目前主流的语音识别都大致分为特征提取,声学模型,语音模型几个部分.目前结合神经网络的端到端的声学模型训练方法主要CTC和基 ...
- 【OCR技术系列之八】端到端不定长文本识别CRNN代码实现
CRNN是OCR领域非常经典且被广泛使用的识别算法,其理论基础可以参考我上一篇文章,本文将着重讲解CRNN代码实现过程以及识别效果. 数据处理 利用图像处理技术我们手工大批量生成文字图像,一共360万 ...
- CTC+pytorch编译配置warp-CTC
CTC CTC可以生成一个损失函数,用于在序列数据上进行监督式学习,不需要对齐输入数据及标签,经常连接在一个RNN网络的末端,训练端到端的语音和文本识别系统.CTC论文地址:http://www.cs ...
- 服务器个人环境下pytorch0.4.1编译warp-ctc遇到的问题及解决方法
一.关于warp-ctc CTC可以生成一个损失函数,用于在序列数据上进行监督式学习,不需要对齐输入数据及标签,经常连接在一个RNN网络的末端,训练端到端的语音或文本识别系统.CTC论文 CTC网络的 ...
- 从零和使用mxnet实现线性回归
1.线性回归从零实现 from mxnet import ndarray as nd import matplotlib.pyplot as plt import numpy as np import ...
- Pytorch的19种损失函数
基本用法 12 criterion = LossCriterion() loss = criterion(x, y) # 调用标准时也有参数 损失函数 L1范数损失:L1Loss 计算 output ...
- [PyTorch 学习笔记] 4.2 损失函数
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/loss_function_1.py https:// ...
- pytorch(16)损失函数(二)
5和6是在数据回归中用的较多的损失函数 5. nn.L1Loss 功能:计算inputs与target之差的绝对值 代码: nn.L1Loss(reduction='mean') 公式: \[l_n ...
- PaddleOCR详解
@ 目录 PaddleOCR简介 环境配置 PaddleOCR2.0的配置环境 Docker 数据集 文本检测 使用自己的数据集 文本识别 使用自己的数据集 字典 自定义字典 添加空格类别 文本角度分 ...
随机推荐
- C++实例2--职工管理系统
职工管理系统 1. 头文件 1.1 workerManager.h 系统类 1 #pragma once // 防止头文件重复包含 2 #include<iostream> // 包含输 ...
- 印尼医疗龙头企业Halodoc的数据平台转型之路:数据平台V1.0
1. 摘要 数据是每项技术业务的支柱,作为一个健康医疗技术平台,Halodoc 更是如此,用户可以通过以下方式与 Halodoc 交互: 送药 与医生交谈 实验室测试 医院预约和药物 所有这些交互都会 ...
- SSH只能用于远程Linux主机?那说明你见识太小了!
开源Linux 长按二维码加关注~ 今天小编为大家分享一篇关于SSH 的介绍和使用方法的文章.本文从SSH是什么出发,讲述了SSH的基本用法,之后在远程登录.端口转发等多种场景下进行独立的讲述,希望能 ...
- 1 Mybatis动态SQL
Mybatis动态SQL 1. 注解开发 我们也可以使用注解的形式来进行开发,用注解来替换掉xml. 使用注解来映射简单语句会使代码显得更加简洁,但对于稍微复杂一点的语句,Java 注解不仅力不从 ...
- 精彩分享 | 欢乐游戏 Istio 云原生服务网格三年实践思考
作者 吴连火,腾讯游戏专家开发工程师,负责欢乐游戏大规模分布式服务器架构.有十余年微服务架构经验,擅长分布式系统领域,有丰富的高性能高可用实践经验,目前正带领团队完成云原生技术栈的全面转型. 导语 欢 ...
- 图解Dijkstra算法+代码实现
简介 Dijkstra(迪杰斯特拉)算法是典型的单源最短路径算法,用于计算一个节点到其他所有节点的最短路径.主要特点是以起始点为中心向外层层扩展,直到扩展到终点为止.Dijkstra算法是很有代表性的 ...
- 一起看 I/O | Flutter 3 更新详解
作者 / Kevin Jamaul Chisholm, Technical Program Manager for Dart and Flutter at Google 又到了 Flutter 稳定版 ...
- 工具分享:清理 Markdown 中没有引用的图片
前言: 之前,我写笔记的工具一直都是 notion,而且没有写博客的习惯.但是一是由于 notion 的服务器在国外,有时候很不稳定:二是由于 notion 的分享很不方便,把笔记分享给别人点开链接之 ...
- 【原创】项目四Tr0ll-1
实战流程 1.nmap枚举 nmap -sP 192.168.186.0/24 nmap -p- 192.168.186.142 nmap 192.168.186.142 -p- -sS -sV -A ...
- java中关于@override注解的使用
@Override是伪代码,表示重写,作用有:1.可以当注释用,方便阅读:2.编译器可以给你验证@Override下面的方法名是否是你父类中所有的,如果没有则报错.例如:如果想重写父类的方法,比如to ...