GRU-CTC中文语音识别
基于keras的中文语音识别
- 该项目实现了GRU-CTC中文语音识别,所有代码都在
gru_ctc_am.py
中,包括:- 音频文件特征提取
- 文本数据处理
- 数据格式处理
- 构建模型
- 模型训练及解码
- 之外还包括将aishell数据处理为thchs30数据格式,合并数据进行训练。代码及数据放在
gen_aishell_data
中。
默认数据集为thchs30,参考gen_aishell_data中的数据及代码,也可以使用aishell的数据进行训练。
音频文件特征提取
# -----------------------------------------------------------------------------------------------------
'''
&usage: [audio]对音频文件进行处理,包括生成总的文件列表、特征提取等
'''
# -----------------------------------------------------------------------------------------------------
# 生成音频列表
def genwavlist(wavpath):
wavfiles = {}
fileids = []
for (dirpath, dirnames, filenames) in os.walk(wavpath):
for filename in filenames:
if filename.endswith('.wav'):
filepath = os.sep.join([dirpath, filename])
fileid = filename.strip('.wav')
wavfiles[fileid] = filepath
fileids.append(fileid)
return wavfiles,fileids
# 对音频文件提取mfcc特征
def compute_mfcc(file):
fs, audio = wav.read(file)
mfcc_feat = mfcc(audio, samplerate=fs, numcep=26)
mfcc_feat = mfcc_feat[::3]
mfcc_feat = np.transpose(mfcc_feat)
mfcc_feat = pad_sequences(mfcc_feat, maxlen=500, dtype='float', padding='post', truncating='post').T
return mfcc_feat
文本数据处理
# -----------------------------------------------------------------------------------------------------
'''
&usage: [text]对文本标注文件进行处理,包括生成拼音到数字的映射,以及将拼音标注转化为数字的标注转化
'''
# -----------------------------------------------------------------------------------------------------
# 利用训练数据生成词典
def gendict(textfile_path):
dicts = []
textfile = open(textfile_path,'r+')
for content in textfile.readlines():
content = content.strip('\n')
content = content.split(' ',1)[1]
content = content.split(' ')
dicts += (word for word in content)
counter = Counter(dicts)
words = sorted(counter)
wordsize = len(words)
word2num = dict(zip(words, range(wordsize)))
num2word = dict(zip(range(wordsize), words))
return word2num, num2word #1176个音素
# 文本转化为数字
def text2num(textfile_path):
lexcion,num2word = gendict(textfile_path)
word2num = lambda word:lexcion.get(word, 0)
textfile = open(textfile_path, 'r+')
content_dict = {}
for content in textfile.readlines():
content = content.strip('\n')
cont_id = content.split(' ',1)[0]
content = content.split(' ',1)[1]
content = content.split(' ')
content = list(map(word2num,content))
add_num = list(np.zeros(50-len(content)))
content = content + add_num
content_dict[cont_id] = content
return content_dict,lexcion
数据格式处理
# -----------------------------------------------------------------------------------------------------
'''
&usage: [data]数据生成器构造,用于训练的数据生成,包括输入特征及标注的生成,以及将数据转化为特定格式
'''
# -----------------------------------------------------------------------------------------------------
# 将数据格式整理为能够被网络所接受的格式,被data_generator调用
def get_batch(x, y, train=False, max_pred_len=50, input_length=500):
X = np.expand_dims(x, axis=3)
X = x # for model2
# labels = np.ones((y.shape[0], max_pred_len)) * -1 # 3 # , dtype=np.uint8
labels = y
input_length = np.ones([x.shape[0], 1]) * ( input_length - 2 )
# label_length = np.ones([y.shape[0], 1])
label_length = np.sum(labels > 0, axis=1)
label_length = np.expand_dims(label_length,1)
inputs = {'the_input': X,
'the_labels': labels,
'input_length': input_length,
'label_length': label_length,
}
outputs = {'ctc': np.zeros([x.shape[0]])} # dummy data for dummy loss function
return (inputs, outputs)
# 数据生成器,默认音频为thchs30\train,默认标注为thchs30\train.syllable,被模型训练方法fit_generator调用
def data_generate(wavpath = 'E:\\Data\\data_thchs30\\train', textfile = 'E:\\Data\\thchs30\\train.syllable.txt', bath_size=4):
wavdict,fileids = genwavlist(wavpath)
#print(wavdict)
content_dict,lexcion = text2num(textfile)
genloop = len(fileids)//bath_size
print("all loop :", genloop)
while True:
feats = []
labels = []
# 随机选择某个音频文件作为训练数据
i = random.randint(0,genloop-1)
for x in range(bath_size):
num = i * bath_size + x
fileid = fileids[num]
# 提取音频文件的特征
mfcc_feat = compute_mfcc(wavdict[fileid])
feats.append(mfcc_feat)
# 提取标注对应的label值
labels.append(content_dict[fileid])
# 将数据格式修改为get_batch可以处理的格式
feats = np.array(feats)
labels = np.array(labels)
# 调用get_batch将数据处理为训练所需的格式
inputs, outputs = get_batch(feats, labels)
yield inputs, outputs
构建模型
# -----------------------------------------------------------------------------------------------------
'''
&usage: [net model]构件网络结构,用于最终的训练和识别
'''
# -----------------------------------------------------------------------------------------------------
# 被creatModel调用,用作ctc损失的计算
def ctc_lambda(args):
labels, y_pred, input_length, label_length = args
y_pred = y_pred[:, :, :]
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
# 构建网络结构,用于模型的训练和识别
def creatModel():
input_data = Input(name='the_input', shape=(500, 26))
layer_h1 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(input_data)
#layer_h1 = Dropout(0.3)(layer_h1)
layer_h2 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h1)
layer_h3_1 = GRU(512, return_sequences=True, kernel_initializer='he_normal', dropout=0.3)(layer_h2)
layer_h3_2 = GRU(512, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', dropout=0.3)(layer_h2)
layer_h3 = add([layer_h3_1, layer_h3_2])
layer_h4 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h3)
#layer_h4 = Dropout(0.3)(layer_h4)
layer_h5 = Dense(1177, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h4)
output = Activation('softmax', name='Activation0')(layer_h5)
model_data = Model(inputs=input_data, outputs=output)
#ctc
labels = Input(name='the_labels', shape=[50], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
loss_out = Lambda(ctc_lambda, output_shape=(1,), name='ctc')([labels, output, input_length, label_length])
model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)
model.summary()
ada_d = Adadelta(lr=0.01, rho=0.95, epsilon=1e-06)
#model=multi_gpu_model(model,gpus=2)
model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=ada_d)
#test_func = K.function([input_data], [output])
print("model compiled successful!")
return model, model_data
模型训练及解码
# -----------------------------------------------------------------------------------------------------
'''
&usage: 模型的解码,用于将数字信息映射为拼音
'''
# -----------------------------------------------------------------------------------------------------
# 对model预测出的softmax的矩阵,使用ctc的准则解码,然后通过字典num2word转为文字
def decode_ctc(num_result, num2word):
result = num_result[:, :, :]
in_len = np.zeros((1), dtype = np.int32)
in_len[0] = 50;
r = K.ctc_decode(result, in_len, greedy = True, beam_width=1, top_paths=1)
r1 = K.get_value(r[0][0])
r1 = r1[0]
text = []
for i in r1:
text.append(num2word[i])
return r1, text
# -----------------------------------------------------------------------------------------------------
'''
&usage: 模型的训练
'''
# -----------------------------------------------------------------------------------------------------
# 训练模型
def train():
# 准备训练所需数据
yielddatas = data_generate()
# 导入模型结构,训练模型,保存模型参数
model, model_data = creatModel()
model.fit_generator(yielddatas, steps_per_epoch=2000, epochs=1)
model.save_weights('model.mdl')
model_data.save_weights('model_data.mdl')
# -----------------------------------------------------------------------------------------------------
'''
&usage: 模型的测试,看识别结果是否正确
'''
# -----------------------------------------------------------------------------------------------------
# 测试模型
def test():
# 准备测试数据,以及生成字典
word2num, num2word = gendict('E:\\Data\\thchs30\\train.syllable.txt')
yielddatas = data_generate(bath_size=1)
# 载入训练好的模型,并进行识别
model, model_data = creatModel()
model_data.load_weights('model_data.mdl')
result = model_data.predict_generator(yielddatas, steps=1)
# 将数字结果转化为文本结果
result, text = decode_ctc(result, num2word)
print('数字结果: ', result)
print('文本结果:', text)
aishell数据转化
将aishell中的汉字标注转化为拼音标注,利用该数据与thchs30数据训练同样的网络结构。
该模型作为一个练手小项目。
没有使用语言模型,直接简单建模。
我的github: https://github.com/audier
GRU-CTC中文语音识别的更多相关文章
- 基于深度学习的中文语音识别系统框架(pluse)
目录 声学模型 GRU-CTC DFCNN DFSMN 语言模型 n-gram CBHG 数据集 本文搭建一个完整的中文语音识别系统,包括声学模型和语言模型,能够将输入的音频信号识别为汉字. 声学模型 ...
- python使用vosk进行中文语音识别
操作系统:Windows10 Python版本:3.9.2 vosk是一个离线开源语音识别工具,它可以识别16种语言,包括中文. 这里记录下使用vosk进行中文识别的过程,以便后续查阅. vosk地址 ...
- pyttsx的中文语音识别问题及探究之路
最近在学习pyttsx时,发现中文阅读一直都识别错误,从发音来看应该是字符编码问题,但搜索之后并未发现解决方案.自己一路摸索解决,虽说最终的原因非常可笑,大牛们可能也是一眼就能洞穿,但也值得记录一下. ...
- Unity中使用百度中文语音识别功能
下面是API类 Asr.cs using System; using System.Collections; using System.Collections.Generic; using Unity ...
- 深度学习实战篇-基于RNN的中文分词探索
深度学习实战篇-基于RNN的中文分词探索 近年来,深度学习在人工智能的多个领域取得了显著成绩.微软使用的152层深度神经网络在ImageNet的比赛上斩获多项第一,同时在图像识别中超过了人类的识别水平 ...
- [DeeplearningAI笔记]序列模型3.9-3.10语音辨识/CTC损失函数/触发字检测
5.3序列模型与注意力机制 觉得有用的话,欢迎一起讨论相互学习~Follow Me 3.9语音辨识 Speech recognition 问题描述 对于音频片段(audio clip)x ,y生成文本 ...
- Python实现各类验证码识别
项目地址: https://github.com/kerlomz/captcha_trainer 编译版下载地址: https://github.com/kerlomz/captcha_trainer ...
- TensorFlow练习13: 制作一个简单的聊天机器人
现在很多卖货公司都使用聊天机器人充当客服人员,许多科技巨头也纷纷推出各自的聊天助手,如苹果Siri.Google Now.Amazon Alexa.微软小冰等等.前不久有一个视频比较了Google N ...
- linux install Openvino
recommend centos7 github Openvino tooltiks 1. download openvino addational installation for ncs2 ncs ...
随机推荐
- Java Bean与Map之间相互转化的实现
目录树 概述 Apache BeanUtils将Bean转Map Apache BeanUtils将Map转Bean 理解BeanUtils将Bean转Map的实现之手写Bean转Map 概述 Apa ...
- CoacoaPods安装使与使用超级详细教程
对于一个iOS开发的初学者来说,并不知道第三方类库的存在,知道了也不知道如何使用,那么下面便来介绍一下使用方法. iOS开发常用的第三方类库是GitHub:https://github.com/ 在上 ...
- 关于MySQL优化问题
众所周知在数据量庞大的情况下普通的SQL语句已经满足不了我们的需要了,这个时候就需要DBA去进行数据库的优化,而我们作为一名开发人员不能对数据库进行优化这时该怎么办呢?答案是只能在SQL语句上面进行优 ...
- 学习 Linux_kernel_exploits 小记
Linux_kernel_exploits+ 功能:自动生成UAF类型漏洞exp文件的工具,目前缺少文档介绍,可以参考test文件下的使用实例,但是源码中缺少dataflowanalyzer模块+ 相 ...
- 06.升级git版本及命令学习
博客为日常工作学习积累总结: 1.升级git版本: 参考博客:https://blog.csdn.net/yuexiahunone/article/details/78647565由于新的版本可以使用 ...
- layui 图片与表单一起提交 + layer.photos图片层预览
HTML基本结构: <form class="layui-form" action="" id="feedBackForm"> ...
- 07JavaScript数据类型
JavaScript 数据类型 值类型(基本类型):字符串(String).数字(Number).布尔(Boolean).对空(Null).未定义(Undefined).Symbol. 引用数据类型: ...
- SQL Server 2012 - SQL查询
执行计划显示SQL执行的开销 工具→ SQL Server Profiler : SQL Server 分析器,监视系统调用的SQL Server查询 Top查询 -- Top Percent 选择百 ...
- jQuery 动画效果 与 动画队列
基础效果 .hide([duration ] [,easing ] [,complete ]) 用于隐藏元素,没有参数的时候等同于直接设置 display 属性 $('.target').hide() ...
- php函数strtotime结合date时间修饰语的使用
下面简单介绍在项目开发中date时间函数和strtotime所遇到的问题,以及解决办法. 原文地址:小时刻个人技术博客 > http://small.aiweimeng.top/index.ph ...