该项目github地址

基于keras的中文语音识别

  • 该项目实现了GRU-CTC中文语音识别,所有代码都在gru_ctc_am.py中,包括:

    • 音频文件特征提取
    • 文本数据处理
    • 数据格式处理
    • 构建模型
    • 模型训练及解码
  • 之外还包括将aishell数据处理为thchs30数据格式,合并数据进行训练。代码及数据放在gen_aishell_data中。

默认数据集为thchs30,参考gen_aishell_data中的数据及代码,也可以使用aishell的数据进行训练。

音频文件特征提取

# -----------------------------------------------------------------------------------------------------
'''
&usage: [audio]对音频文件进行处理,包括生成总的文件列表、特征提取等
'''
# -----------------------------------------------------------------------------------------------------
# 生成音频列表
def genwavlist(wavpath):
wavfiles = {}
fileids = []
for (dirpath, dirnames, filenames) in os.walk(wavpath):
for filename in filenames:
if filename.endswith('.wav'):
filepath = os.sep.join([dirpath, filename])
fileid = filename.strip('.wav')
wavfiles[fileid] = filepath
fileids.append(fileid)
return wavfiles,fileids # 对音频文件提取mfcc特征
def compute_mfcc(file):
fs, audio = wav.read(file)
mfcc_feat = mfcc(audio, samplerate=fs, numcep=26)
mfcc_feat = mfcc_feat[::3]
mfcc_feat = np.transpose(mfcc_feat)
mfcc_feat = pad_sequences(mfcc_feat, maxlen=500, dtype='float', padding='post', truncating='post').T
return mfcc_feat

文本数据处理

# -----------------------------------------------------------------------------------------------------
'''
&usage: [text]对文本标注文件进行处理,包括生成拼音到数字的映射,以及将拼音标注转化为数字的标注转化
'''
# -----------------------------------------------------------------------------------------------------
# 利用训练数据生成词典
def gendict(textfile_path):
dicts = []
textfile = open(textfile_path,'r+')
for content in textfile.readlines():
content = content.strip('\n')
content = content.split(' ',1)[1]
content = content.split(' ')
dicts += (word for word in content)
counter = Counter(dicts)
words = sorted(counter)
wordsize = len(words)
word2num = dict(zip(words, range(wordsize)))
num2word = dict(zip(range(wordsize), words))
return word2num, num2word #1176个音素 # 文本转化为数字
def text2num(textfile_path):
lexcion,num2word = gendict(textfile_path)
word2num = lambda word:lexcion.get(word, 0)
textfile = open(textfile_path, 'r+')
content_dict = {}
for content in textfile.readlines():
content = content.strip('\n')
cont_id = content.split(' ',1)[0]
content = content.split(' ',1)[1]
content = content.split(' ')
content = list(map(word2num,content))
add_num = list(np.zeros(50-len(content)))
content = content + add_num
content_dict[cont_id] = content
return content_dict,lexcion

数据格式处理

# -----------------------------------------------------------------------------------------------------
'''
&usage: [data]数据生成器构造,用于训练的数据生成,包括输入特征及标注的生成,以及将数据转化为特定格式
'''
# -----------------------------------------------------------------------------------------------------
# 将数据格式整理为能够被网络所接受的格式,被data_generator调用
def get_batch(x, y, train=False, max_pred_len=50, input_length=500):
X = np.expand_dims(x, axis=3)
X = x # for model2
# labels = np.ones((y.shape[0], max_pred_len)) * -1 # 3 # , dtype=np.uint8
labels = y
input_length = np.ones([x.shape[0], 1]) * ( input_length - 2 )
# label_length = np.ones([y.shape[0], 1])
label_length = np.sum(labels > 0, axis=1)
label_length = np.expand_dims(label_length,1)
inputs = {'the_input': X,
'the_labels': labels,
'input_length': input_length,
'label_length': label_length,
}
outputs = {'ctc': np.zeros([x.shape[0]])} # dummy data for dummy loss function
return (inputs, outputs) # 数据生成器,默认音频为thchs30\train,默认标注为thchs30\train.syllable,被模型训练方法fit_generator调用
def data_generate(wavpath = 'E:\\Data\\data_thchs30\\train', textfile = 'E:\\Data\\thchs30\\train.syllable.txt', bath_size=4):
wavdict,fileids = genwavlist(wavpath)
#print(wavdict)
content_dict,lexcion = text2num(textfile)
genloop = len(fileids)//bath_size
print("all loop :", genloop)
while True:
feats = []
labels = []
# 随机选择某个音频文件作为训练数据
i = random.randint(0,genloop-1)
for x in range(bath_size):
num = i * bath_size + x
fileid = fileids[num]
# 提取音频文件的特征
mfcc_feat = compute_mfcc(wavdict[fileid])
feats.append(mfcc_feat)
# 提取标注对应的label值
labels.append(content_dict[fileid])
# 将数据格式修改为get_batch可以处理的格式
feats = np.array(feats)
labels = np.array(labels)
# 调用get_batch将数据处理为训练所需的格式
inputs, outputs = get_batch(feats, labels)
yield inputs, outputs

构建模型

# -----------------------------------------------------------------------------------------------------
'''
&usage: [net model]构件网络结构,用于最终的训练和识别
'''
# -----------------------------------------------------------------------------------------------------
# 被creatModel调用,用作ctc损失的计算
def ctc_lambda(args):
labels, y_pred, input_length, label_length = args
y_pred = y_pred[:, :, :]
return K.ctc_batch_cost(labels, y_pred, input_length, label_length) # 构建网络结构,用于模型的训练和识别
def creatModel():
input_data = Input(name='the_input', shape=(500, 26))
layer_h1 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(input_data)
#layer_h1 = Dropout(0.3)(layer_h1)
layer_h2 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h1)
layer_h3_1 = GRU(512, return_sequences=True, kernel_initializer='he_normal', dropout=0.3)(layer_h2)
layer_h3_2 = GRU(512, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', dropout=0.3)(layer_h2)
layer_h3 = add([layer_h3_1, layer_h3_2])
layer_h4 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h3)
#layer_h4 = Dropout(0.3)(layer_h4)
layer_h5 = Dense(1177, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h4)
output = Activation('softmax', name='Activation0')(layer_h5)
model_data = Model(inputs=input_data, outputs=output)
#ctc
labels = Input(name='the_labels', shape=[50], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
loss_out = Lambda(ctc_lambda, output_shape=(1,), name='ctc')([labels, output, input_length, label_length])
model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)
model.summary()
ada_d = Adadelta(lr=0.01, rho=0.95, epsilon=1e-06)
#model=multi_gpu_model(model,gpus=2)
model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=ada_d)
#test_func = K.function([input_data], [output])
print("model compiled successful!")
return model, model_data

模型训练及解码

# -----------------------------------------------------------------------------------------------------
'''
&usage: 模型的解码,用于将数字信息映射为拼音
'''
# -----------------------------------------------------------------------------------------------------
# 对model预测出的softmax的矩阵,使用ctc的准则解码,然后通过字典num2word转为文字
def decode_ctc(num_result, num2word):
result = num_result[:, :, :]
in_len = np.zeros((1), dtype = np.int32)
in_len[0] = 50;
r = K.ctc_decode(result, in_len, greedy = True, beam_width=1, top_paths=1)
r1 = K.get_value(r[0][0])
r1 = r1[0]
text = []
for i in r1:
text.append(num2word[i])
return r1, text # -----------------------------------------------------------------------------------------------------
'''
&usage: 模型的训练
'''
# -----------------------------------------------------------------------------------------------------
# 训练模型
def train():
# 准备训练所需数据
yielddatas = data_generate()
# 导入模型结构,训练模型,保存模型参数
model, model_data = creatModel()
model.fit_generator(yielddatas, steps_per_epoch=2000, epochs=1)
model.save_weights('model.mdl')
model_data.save_weights('model_data.mdl') # -----------------------------------------------------------------------------------------------------
'''
&usage: 模型的测试,看识别结果是否正确
'''
# -----------------------------------------------------------------------------------------------------
# 测试模型
def test():
# 准备测试数据,以及生成字典
word2num, num2word = gendict('E:\\Data\\thchs30\\train.syllable.txt')
yielddatas = data_generate(bath_size=1)
# 载入训练好的模型,并进行识别
model, model_data = creatModel()
model_data.load_weights('model_data.mdl')
result = model_data.predict_generator(yielddatas, steps=1)
# 将数字结果转化为文本结果
result, text = decode_ctc(result, num2word)
print('数字结果: ', result)
print('文本结果:', text)

aishell数据转化

将aishell中的汉字标注转化为拼音标注,利用该数据与thchs30数据训练同样的网络结构。

该模型作为一个练手小项目。

没有使用语言模型,直接简单建模。

我的github: https://github.com/audier

GRU-CTC中文语音识别的更多相关文章

  1. 基于深度学习的中文语音识别系统框架(pluse)

    目录 声学模型 GRU-CTC DFCNN DFSMN 语言模型 n-gram CBHG 数据集 本文搭建一个完整的中文语音识别系统,包括声学模型和语言模型,能够将输入的音频信号识别为汉字. 声学模型 ...

  2. python使用vosk进行中文语音识别

    操作系统:Windows10 Python版本:3.9.2 vosk是一个离线开源语音识别工具,它可以识别16种语言,包括中文. 这里记录下使用vosk进行中文识别的过程,以便后续查阅. vosk地址 ...

  3. pyttsx的中文语音识别问题及探究之路

    最近在学习pyttsx时,发现中文阅读一直都识别错误,从发音来看应该是字符编码问题,但搜索之后并未发现解决方案.自己一路摸索解决,虽说最终的原因非常可笑,大牛们可能也是一眼就能洞穿,但也值得记录一下. ...

  4. Unity中使用百度中文语音识别功能

    下面是API类 Asr.cs using System; using System.Collections; using System.Collections.Generic; using Unity ...

  5. 深度学习实战篇-基于RNN的中文分词探索

    深度学习实战篇-基于RNN的中文分词探索 近年来,深度学习在人工智能的多个领域取得了显著成绩.微软使用的152层深度神经网络在ImageNet的比赛上斩获多项第一,同时在图像识别中超过了人类的识别水平 ...

  6. [DeeplearningAI笔记]序列模型3.9-3.10语音辨识/CTC损失函数/触发字检测

    5.3序列模型与注意力机制 觉得有用的话,欢迎一起讨论相互学习~Follow Me 3.9语音辨识 Speech recognition 问题描述 对于音频片段(audio clip)x ,y生成文本 ...

  7. Python实现各类验证码识别

    项目地址: https://github.com/kerlomz/captcha_trainer 编译版下载地址: https://github.com/kerlomz/captcha_trainer ...

  8. TensorFlow练习13: 制作一个简单的聊天机器人

    现在很多卖货公司都使用聊天机器人充当客服人员,许多科技巨头也纷纷推出各自的聊天助手,如苹果Siri.Google Now.Amazon Alexa.微软小冰等等.前不久有一个视频比较了Google N ...

  9. linux install Openvino

    recommend centos7 github Openvino tooltiks 1. download openvino addational installation for ncs2 ncs ...

随机推荐

  1. Java Bean与Map之间相互转化的实现

    目录树 概述 Apache BeanUtils将Bean转Map Apache BeanUtils将Map转Bean 理解BeanUtils将Bean转Map的实现之手写Bean转Map 概述 Apa ...

  2. CoacoaPods安装使与使用超级详细教程

    对于一个iOS开发的初学者来说,并不知道第三方类库的存在,知道了也不知道如何使用,那么下面便来介绍一下使用方法. iOS开发常用的第三方类库是GitHub:https://github.com/ 在上 ...

  3. 关于MySQL优化问题

    众所周知在数据量庞大的情况下普通的SQL语句已经满足不了我们的需要了,这个时候就需要DBA去进行数据库的优化,而我们作为一名开发人员不能对数据库进行优化这时该怎么办呢?答案是只能在SQL语句上面进行优 ...

  4. 学习 Linux_kernel_exploits 小记

    Linux_kernel_exploits+ 功能:自动生成UAF类型漏洞exp文件的工具,目前缺少文档介绍,可以参考test文件下的使用实例,但是源码中缺少dataflowanalyzer模块+ 相 ...

  5. 06.升级git版本及命令学习

    博客为日常工作学习积累总结: 1.升级git版本: 参考博客:https://blog.csdn.net/yuexiahunone/article/details/78647565由于新的版本可以使用 ...

  6. layui 图片与表单一起提交 + layer.photos图片层预览

    HTML基本结构: <form class="layui-form" action="" id="feedBackForm"> ...

  7. 07JavaScript数据类型

    JavaScript 数据类型 值类型(基本类型):字符串(String).数字(Number).布尔(Boolean).对空(Null).未定义(Undefined).Symbol. 引用数据类型: ...

  8. SQL Server 2012 - SQL查询

    执行计划显示SQL执行的开销 工具→ SQL Server Profiler : SQL Server 分析器,监视系统调用的SQL Server查询 Top查询 -- Top Percent 选择百 ...

  9. jQuery 动画效果 与 动画队列

    基础效果 .hide([duration ] [,easing ] [,complete ]) 用于隐藏元素,没有参数的时候等同于直接设置 display 属性 $('.target').hide() ...

  10. php函数strtotime结合date时间修饰语的使用

    下面简单介绍在项目开发中date时间函数和strtotime所遇到的问题,以及解决办法. 原文地址:小时刻个人技术博客 > http://small.aiweimeng.top/index.ph ...