基于深度学习的回声消除系统与Pytorch实现
文章作者:凌逆战
文章代码(pytorch实现):https://github.com/LXP-Never/AEC_DeepModel
文章地址(转载请指明出处):https://www.cnblogs.com/LXP-Never/p/14779360.html
写这篇文章的目的:
- 降低全国想要做基于深度学习的回声消除同学们一个入门门槛。万事开头难呀,肯定有很多小白辛苦研究了一年,连基线系统都搭建不出来的,他们肯定心心念念有谁能帮帮他们,这不,我来了。
- 在基于深度学习的回声消除这一块,网上几乎没人开源,github上能找到的几乎都是基于自适应滤波器的。我个人是很提倡开源精神的,能让更多的人能够参与进来,小到促进这个领域的进步,大到提升国家科学竞争力,哪怕只是一小步,都需要有人做出行动。
- 今天我开源,明天你开源。可能以后你们的开源项目也能帮助到我。
作者独白:
- 写这篇文章的目的在于想做基于深度学习的回声消除小白们一份入门教学,所以别对这篇文章有什么创新点或者性能上的较大期待,我只是随便搭建了一个基线系统,来进行回声消除代码的讲解,带领小白入门。
- 别问我为什么不调试好了再分享出来,时间精力有限,我的研究方向也不是回声消除,我只是感兴趣,也没人给我钱支持我研究,从一个基线模型到最终一个完善的模型,是需要巨大的时间成本的,每往下走一步需要的付出精力越多,这就是科研之路。
- 本文分享出来的系统在哪个点可以改进,可以做创新发论文,我都会在文中说明,不用感谢我
- 本文引用了诸多我原先的文章,遇到不懂的大家可能还需要多翻看原来的文章,知识需要积累,没有一蹴而就的捷径。
- 文中若有不对之处,还请各位看官多多包含,多提意见,我会积极修改的。觉得写得不错的,建议点赞关注一下,这是对我最大的支持,是给我开源精神最大的鼓励,我以后也还会努力分享好文章给大家的。
原理
传统算法
主要参考我的另外一篇文章:声学回声消除(Acoustic Echo Cancellation)原理与实现。
图中$x(n)$为远端语音,$y(n)$为远端回声$y(n)=x(n)*w(n)$,$s(n)$为近端语音,$d(n)$为近端麦克风语音信号。
深度学习算法
回声包含线性回声和非线性回声
- 线性回声:远端语音直接 被近端麦克风接收的回声。
- 非线性回声:远端语音经过多径传播后 被近端麦克风接收的回声
线性回声可以通过 时延估计、端点检测和自适应滤波器技术较好的消除,非线性回声经过多次反射后产生了混响,声学特性复杂,很难消除。基于深度学习的回声消除技术,目前有这几个方向在做:
- 神经网络
- 自适应滤波器+神经网络
神经网络
利用神经网络较强的非线性拟合能力,直接消除线性回声和非线性回声
- 优点:过程简单,一步到位
- 缺点:可能需要更复杂或精炼的模型,才能达到更好的效果。更加考验模型的能力
自适应滤波器+神经网络
先利用简单的传统方法消除线性回声,再利用神经网络消除非线性回声
- 优点:有针对性的进行回声消除,能降低神经网络的负担
- 缺点:能一步到位的事情,就不要把事情复杂化
图片来源于论文:Residual acoustic echo suppression based on efficient multi-task convolutional neural network,图中$e(n)$为自适应滤波器输出的的残差信号,$u(n)$为远端参考信号,然后利用短时间傅里叶变换(STFT)将$e(n)$和$u(n)$转换到频域,串联作为输入特征。同样输出mask。估计的近端振幅为:
$$估计的近端振幅=mask*自适应滤波器输出$$
训练策略
- 频谱映射:输入(近端麦克风语音频谱,远端语音频谱),输出(近端语音频谱)
- 波形映射:输入(近端麦克风语音波形,远端语音波形),输出(近端语音波形)
- 频谱mask:输入(近端麦克风语音频谱,远端语音频谱),输出 (mask),近端语音频谱 = mask*近端麦克风语音频谱
- 时域mask:输入(近端麦克风语音波形),输出(近端语音mask, 远端回声mask),近端语音波形 = 近端语音mask*近端麦克风语音波形(这个点,我是受到语音分离的一篇文章启发,觉得可行,所以也分享在这了,目前还没有这方向的论文,科研工作者可以去尝试)
频谱映射、波形映射、频谱mask我在这篇文章中做了详细的说明,时域mask在这篇文章中做了详细的讲解。
回声消除跟语音增强和语音去混响或者语音分离很像,都是从混合语音或者污染语音中提取干净的语音。因此我们如果想要在回声消除领域找创新点的话,不妨去多看看我刚刚提的三个方向的论文。我主要参考的是语音增强和语音分离。
基线模型
本文重点来了,我搭建的基线系统是使用神经网络直接消除回声, 训练策略为 频谱mask。
数据准备
做回声消除任务主要有两类数据,真实回声数据以及合成回声数据。
- 真实回声数据:在真实环境中采集的回声,目前只有微软举办的 回声消除挑战赛中开源的数据集,我个人认为微软数据集中真实数据集有点问题,详情见博客。
- 合成回声数据:通过RIR合成的回声。可以使用任意的语音数据集,使用RIR-Generator生成房间冲击响应(推荐使用MATLAB方法),再卷积远端语音得到回声。科研界主要使用的TIMIT数据集。AEC-Challenge 数据集也有合成数据集。
我这里就偷个懒,直接使用AEC-Challenge合成好了的数据集。文件结构如下
└─Synthetic
├─TEST
│ ├─echo_signal
│ ├─farend_speech
│ ├─nearend_mic_signal
│ └─nearend_speech
├─TRAIN
│ ├─echo_signal
│ ├─farend_speech
│ ├─nearend_mic_signal
│ └─nearend_speech
└─VAL
├─echo_signal
├─farend_speech
├─nearend_mic_signal
└─nearend_speech
如果你们想用TIMIT数据集的话(毕竟很多论文都用他),可以具体参考这篇论文的数据准备方法。我个人被这篇论文给绕晕了,数据准备看似不简单,但用代码实现起来却非常难。你们可以自己去试试。
但不管用哪个数据集,我还是建议大家都把数据按照上面的文件路径结构放好,方便读取。
我搭建的基线系统实现的是频谱mask的训练策略,模型输入为[远端语音振幅,近端麦克风振幅],模型输出IRM mask。IRM公式可以写成以下几种形式为:
$$\operatorname{IRM}=\sqrt{\frac{近端语音振幅^2}{近端语音振幅^2+远端回声振幅^2}}$$
$$\mathrm{IRM}=\sqrt{\frac{\text { 远端语音振幅 }^{2}}{(\text { 近端语音振幅+远端回声振幅 })^{2}}}$$
$$\operatorname{IRM}=\sqrt{\frac{近端语音振幅^2}{近端麦克风语音振幅^2}}$$
我使用的是Pytorch搭建的模型,Pytorch有一套自己的数据加载方式,我之前写过一篇文章进行了总结:pytorch加载语音类自定义数据集 。如果你已经很熟悉了请继续看,本文的回声消除数据预处理代码如下:
# Author:凌逆战
# -*- coding:utf-8 -*-
"""
作用:数据预处理
"""
import glob
import os
import torch.nn.functional as F
import torch
import torchaudio
from torch.utils.data import Dataset
from torch.utils.data import DataLoader class FileDateset(Dataset):
def __init__(self, dataset_path="./Synthetic/TRAIN", fs=16000, win_length=320, mode="train"):
self.fs = fs
self.win_length = win_length
self.mode = mode farend_speech_path = os.path.join(dataset_path, "farend_speech") # "./Synthetic/TRAIN/farend_speech"
nearend_mic_signal_path = os.path.join(dataset_path, "nearend_mic_signal") # "./Synthetic/TRAIN/nearend_mic_signal"
nearend_speech_path = os.path.join(dataset_path, "nearend_speech") # "./Synthetic/TRAIN/nearend_speech" self.farend_speech_list = sorted(glob.glob(farend_speech_path+"/*.wav")) # 远端语音路径,list
self.nearend_mic_signal_list = sorted(glob.glob(nearend_mic_signal_path+"/*.wav")) # 近端麦克风语音路径,list
self.nearend_speech_list = sorted(glob.glob(nearend_speech_path+"/*.wav")) # 近端语音路径,list def spectrogram(self, wav_path):
"""
:param wav_path: 音频路径
:return: 返回该音频的振幅和相位
"""
wav, _ = torchaudio.load(wav_path)
wav = wav.squeeze() if len(wav) < 160000:
wav = F.pad(wav, (0,160000-len(wav)), mode="constant",value=0) S = torch.stft(wav, n_fft=self.win_length, hop_length=self.win_length//2,
win_length=self.win_length, window=torch.hann_window(window_length=self.win_length),
center=False, return_complex=True) # (*, F,T)
magnitude = torch.abs(S) # 振幅
phase = torch.exp(1j * torch.angle(S)) # 相位
return magnitude, phase def __getitem__(self, item):
"""__getitem__是类的专有方法,使类可以像list一样按照索引来获取元素
:param item: 索引
:return: 按 索引取出来的 元素
"""
# 远端语音 振幅,相位 (F, T),F为频点数,T为帧数
farend_speech_magnitude, farend_speech_phase = self.spectrogram(self.farend_speech_list[item]) # torch.Size([161, 999])
# 近端麦克风 振幅,相位
nearend_mic_magnitude, nearend_mic_phase = self.spectrogram(self.nearend_mic_signal_list[item])
# 近端语音 振幅,相位
nearend_speech_magnitude, nearend_speech_phase = self.spectrogram(self.nearend_speech_list[item]) X = torch.cat((farend_speech_magnitude, nearend_mic_magnitude), dim=0) # 在频点维度上进行拼接(161*2, 999),模型输入 _eps = torch.finfo(torch.float).eps # 防止分母出现0
mask_IRM = torch.sqrt(nearend_speech_magnitude ** 2/(nearend_mic_magnitude ** 2+_eps)) # IRM,模型输出 return X, mask_IRM, nearend_mic_magnitude, nearend_speech_magnitude def __len__(self):
"""__len__是类的专有方法,获取整个数据的长度"""
return len(self.farend_speech_list) if __name__ == "__main__":
train_set = FileDateset()
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True) for x, y, nearend_mic_magnitude,nearend_speech_magnitude in train_loader:
print(x.shape) # torch.Size([64, 322, 999])
print(y.shape) # torch.Size([64, 161, 999])
print(nearend_mic_magnitude.shape)
我几乎每行代码都给了注释了,各位点个赞不过分吧。还有不懂地方的各位可以在评论区指出。
如果想要创新发文章的话,数据处理这里也可以做改动:
- 更改mask方法,或者提出更好用的mask,我这篇文章总结了不少:基于深度学习的单通道语音增强,大家可以轮着试一试,反正我给出了代码。
- 我这里使用的是振幅,你们可以尝试提取一些语音其他的特征,类似 梅尔频谱特征,对数功率谱等等。
- 在强调一遍呀,现在没有基于时域mask的回声消除论文,大家快去攻略占地呀,主要参考语音分离这个领域。
模型搭建
我这里使用的是频谱mask的训练策略,模型输入为 远端语音振幅 和 近端麦克风振幅 的串联,模型输出IRM。由上可知,输入大小为 [64, 322, 999],输出大小为 [64, 161, 999]。那么我们只需要随便搭建一个模型符合这个输入输出就行了。
# Author:凌逆战
# -*- coding:utf-8 -*-
"""
作用:随便搭建的模型,只要符合输入大小是[64, 322, 999],输出大小是[64, 161, 999],就能跑通
"""
import torch.nn as nn
import torch class Base_model(nn.Module):
def __init__(self):
super(Base_model, self).__init__()
# [batch, channel, input_size] (B, F, T)
# [64, 322, 999] ---> [64, 161, 999]
self.model = nn.Sequential(
nn.Conv1d(in_channels=322, out_channels=322, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2),
nn.Conv1d(in_channels=322, out_channels=322, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2),
nn.Conv1d(in_channels=322, out_channels=161, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2),
nn.Conv1d(in_channels=161, out_channels=161, kernel_size=3, stride=1, padding=1),
nn.Sigmoid()
) def forward(self, x):
"""
:param x: 麦克风信号和远端信号的特征串联在一起作为输入特征 (322, 206)
:return: IRM_mask * input = 近端语音对数谱
"""
Estimated_IRM = self.model(x) return Estimated_IRM if __name__ == "__main__":
model = Base_model().cuda()
x = torch.randn(8, 322, 999).to("cuda") # 输入 [8, 322, 999]
y = model(x) # 输出 [8, 161, 999]
print(y.shape)
模型是一个可以创新的点,大家可以改成目前比较流行的模型来发文章。我这里就随便搭建了。
如果想要创新发文章的话,模型搭建这里也可以做改动:
- 使用时序模型来更多的考量语音帧间相关性,如LSTM、TCN,注意力机制等等,反正现在的模型五花八门,看着谁好用借鉴过来用,然后魔改一下,有良好的效果的话,就能写论文了。
训练模块
训练模块其实是最没啥创新的,所有写的正儿八经的代码,训练模型几乎都一样,但是这一块却是卡住所有新人的较大关卡。不懂的人觉得难的要死,懂的人觉得简单地一批。
训练模块的具体流程有以下几部分:
- 命令行解析
- 数据集加载
- 检测模型保存地址是否存在,如果不存在则创建
- 实例化模型
- 实例化优化器(一般使用Adam优化器)
- 准备事件文件,方便Tensorboard可视化
- 如果接着上一次检查点训练,则加载模型
- 循环epochs,开始训练(前向传播,反向传播)
- 验证模型(根据验证集的损失和度量,对模型的超参数进行调整)
import os
import torch
from torch.utils.data import DataLoader
from torch import nn
import argparse
from tensorboardX import SummaryWriter from data_preparation.data_preparation import FileDateset
from model.Baseline import Base_model
from model.ops import pytorch_LSD def parse_args():
parser = argparse.ArgumentParser()
# 重头开始训练 defaule=None, 继续训练defaule设置为'/**.pth'
parser.add_argument("--model_name", type=str, default=None, help="是否加载模型继续训练 '/50.pth' None")
parser.add_argument("--batch-size", type=int, default=16, help="")
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument('--lr', type=float, default=3e-4, help='学习率 (default: 0.01)')
parser.add_argument('--train_data', default="./data_preparation/Synthetic/TRAIN", help='数据集的path')
parser.add_argument('--val_data', default="./data_preparation/Synthetic/VAL", help='验证样本的path')
parser.add_argument('--checkpoints_dir', default="./checkpoints/AEC_baseline", help='模型检查点文件的路径(以继续培训)')
parser.add_argument('--event_dir', default="./event_file/AEC_baseline", help='tensorboard事件文件的地址')
args = parser.parse_args()
return args def main():
args = parse_args()
print("GPU是否可用:", torch.cuda.is_available()) # True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 实例化 Dataset
train_set = FileDateset(dataset_path=args.train_data) # 实例化训练数据集
val_set = FileDateset(dataset_path=args.val_data) # 实例化验证数据集 # 数据加载器
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, drop_last=True)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=True) # ########### 保存检查点的地址(如果检查点不存在,则创建) ############
if not os.path.exists(args.checkpoints_dir):
os.makedirs(args.checkpoints_dir) ################################
# 实例化模型 #
################################
model = Base_model().to(device) # 实例化模型
# summary(model, input_size=(322, 999)) # 模型输出 torch.Size([64, 322, 999])
# ########### 损失函数 ############
criterion = nn.MSELoss(reduce=True, size_average=True, reduction='mean') ###############################
# 创建优化器 Create optimizers #
###############################
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, )
# lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,20], gamma=0.1) # ########### TensorBoard可视化 summary ############
writer = SummaryWriter(args.event_dir) # 创建事件文件 # ########### 加载模型检查点 ############
start_epoch = 0
if args.model_name:
print("加载模型:", args.checkpoints_dir + args.model_name)
checkpoint = torch.load(args.checkpoints_dir + args.model_name)
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint['epoch']
# lr_schedule.load_state_dict(checkpoint['lr_schedule']) # 加载lr_scheduler for epoch in range(start_epoch, args.epochs):
model.train() # 训练模型
for batch_idx, (train_X, train_mask, train_nearend_mic_magnitude, train_nearend_magnitude) in enumerate(
train_loader):
train_X = train_X.to(device) # 远端语音cat麦克风语音 [batch_size, 322, 999] (, F, T)
train_mask = train_mask.to(device) # IRM [batch_size 161, 999]
train_nearend_mic_magnitude = train_nearend_mic_magnitude.to(device)
train_nearend_magnitude = train_nearend_magnitude.to(device) # 前向传播
pred_mask = model(train_X) # [batch_size, 322, 999]--> [batch_size, 161, 999]
train_loss = criterion(pred_mask, train_mask) # 近端语音信号频谱 = mask * 麦克风信号频谱 [batch_size, 161, 999]
pred_near_spectrum = pred_mask * train_nearend_mic_magnitude
train_lsd = pytorch_LSD(train_nearend_magnitude, pred_near_spectrum) # 反向传播
optimizer.zero_grad() # 将梯度清零
train_loss.backward() # 反向传播
optimizer.step() # 更新参数 # ########### 可视化打印 ############
print('Train Epoch: {} Loss: {:.6f} LSD: {:.6f}'.format(epoch + 1, train_loss.item(), train_lsd.item())) # ########### TensorBoard可视化 summary ############
# lr_schedule.step() # 学习率衰减
# writer.add_scalar(tag="lr", scalar_value=model.state_dict()['param_groups'][0]['lr'], global_step=epoch + 1)
writer.add_scalar(tag="train_loss", scalar_value=train_loss.item(), global_step=epoch + 1)
writer.add_scalar(tag="train_lsd", scalar_value=train_lsd.item(), global_step=epoch + 1)
writer.flush() # 神经网络在验证数据集上的表现
model.eval() # 测试模型
# 测试的时候不需要梯度
with torch.no_grad():
for val_batch_idx, (val_X, val_mask, val_nearend_mic_magnitude, val_nearend_magnitude) in enumerate(
val_loader):
val_X = val_X.to(device) # 远端语音cat麦克风语音 [batch_size, 322, 999] (, F, T)
val_mask = val_mask.to(device) # IRM [batch_size 161, 999]
val_nearend_mic_magnitude = val_nearend_mic_magnitude.to(device)
val_nearend_magnitude = val_nearend_magnitude.to(device) # 前向传播
val_pred_mask = model(val_X)
val_loss = criterion(val_pred_mask, val_mask) # 近端语音信号频谱 = mask * 麦克风信号频谱 [batch_size, 161, 999]
val_pred_near_spectrum = val_pred_mask * val_nearend_mic_magnitude
val_lsd = pytorch_LSD(val_nearend_magnitude, val_pred_near_spectrum) # ########### 可视化打印 ############
print(' val Epoch: {} \tLoss: {:.6f}\tlsd: {:.6f}'.format(epoch + 1, val_loss.item(), val_lsd.item()))
######################
# 更新tensorboard #
######################
writer.add_scalar(tag="val_loss", scalar_value=val_loss.item(), global_step=epoch + 1)
writer.add_scalar(tag="val_lsd", scalar_value=val_lsd.item(), global_step=epoch + 1)
writer.flush() # # ########### 保存模型 ############
if (epoch + 1) % 10 == 0:
checkpoint = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"epoch": epoch + 1,
# 'lr_schedule': lr_schedule.state_dict()
}
torch.save(checkpoint, '%s/%d.pth' % (args.checkpoints_dir, epoch + 1)) if __name__ == "__main__":
main()
咳咳咳,这个注释量,你们爱了没有,很详细了,还看不懂说明你的基础太差了,别看这篇文章了,打基础去吧,基础很重要。
如果想要创新发文章的话,损失这里也可以做改动:
- 使用一个更加全面的损失函数引导模型训练,我言尽于此,剩下的靠大家自己领悟了。
推理阶段
将模型预测的近端语音振幅和近端麦克风语音相位相乘得到近端语音的复数表示,经过短时傅里叶逆变换得到近端语音波形。这里需要补一点基础知识:
复数的几种表示形式:
- 实部、虚部(直角坐标系):$a+bj$ ($a$是实部,$b$是虚部)
- 幅值、相位(指数系):$re^{j\theta }$ ($r$是幅值,$\theta$是相角,$e^{j\theta }$是相位)
- 两种形式互换:$e^{j\theta }=cos\theta+isin\theta$,$re^{j\theta }=r(cos\theta+jsin\theta)=rcos\theta+jrsin\theta$
因此,实部$a=rcos\theta$,虚部$b=rsin\theta$,
幅值$r=\sqrt{a^2+b^2}$,相角$\theta=tan^{-1}(\frac{b}{a})$
还有一种是极坐标表示法:$r\angle \theta $
结合上述补充知识,以及复数矩阵D(F, T),我们可以得到一下频谱信息
- 复数的实部: real = np.real(D(F, T))
- 复数的虚部: imag= np.imag(D(F, T))
- 幅值: magnitude = np.abs(D(F, T)) 或 magnitude = np.sqrt(real**2+imag**2)
- 相角: angle = np.angle(D(F, T))
- 相位: phase = np.exp(1j * np.angle(D(F, T)))
librosa提供了专门将复数矩阵D(F, T)分离为幅值$S$和相位$P$的函数,$D=S*P$
librosa.magphase(D, power=1)
参数:
- D:经过stft得到的复数矩阵
- power:幅度谱的指数,例如,1代表能量,2代表功率,等等。
返回:
- D_mag:幅值$D$,
- D_phase:相位$P$, phase = exp(1.j * phi) , phi 是复数矩阵的相位角 np.angle(D)
当然我们也可以通过上面的公式自己求
# Author:凌逆战
# -*- coding:utf-8 -*-
"""
作用:通过模型生成近端语音
"""
import librosa
import matplotlib
import torchaudio
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt
from model.Baseline import Base_model
from matplotlib.ticker import FuncFormatter
import numpy as np plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示符号 def spectrogram(wav_path, win_length=320):
wav, _ = torchaudio.load(wav_path)
wav = wav.squeeze() if len(wav) < 160000:
wav = F.pad(wav, (0, 160000 - len(wav)), mode="constant", value=0)
# if len(wav) != 160000:
# print(wav_path)
# print(len(wav)) S = torch.stft(wav, n_fft=win_length, hop_length=win_length // 2,
win_length=win_length, window=torch.hann_window(window_length=win_length),
center=False, return_complex=True)
magnitude = torch.abs(S)
phase = torch.exp(1j * torch.angle(S))
return magnitude, phase fs = 16000
farend_speech = "./farend_speech/farend_speech_fileid_9992.wav"
nearend_mic_signal = "./nearend_mic_signal/nearend_mic_fileid_9992.wav"
nearend_speech = "./nearend_speech/nearend_speech_fileid_9992.wav"
echo_signal = "./echo_signal/echo_fileid_9992.wav" print("GPU是否可用:", torch.cuda.is_available()) # True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") farend_speech_magnitude, farend_speech_phase = spectrogram(farend_speech) # 远端语音 振幅,相位
nearend_mic_magnitude, nearend_mic_phase = spectrogram(nearend_mic_signal) # 近端麦克风语音 振幅,相位
nearend_speech_magnitude, nearend_speech_phase = spectrogram(nearend_speech) # 近端语音振 幅,相位 farend_speech_magnitude = farend_speech_magnitude.to(device)
nearend_mic_phase = nearend_mic_phase.to(device)
nearend_mic_magnitude = nearend_mic_magnitude.to(device) nearend_speech_magnitude = nearend_speech_magnitude.to(device)
nearend_speech_phase = nearend_speech_phase.to(device) model = Base_model().to(device) # 实例化模型
checkpoint = torch.load("../checkpoints/AEC_baseline/10.pth")
model.load_state_dict(checkpoint["model"]) X = torch.cat((farend_speech_magnitude, nearend_mic_magnitude), dim=0)
X = X.unsqueeze(0)
per_mask = model(X) # [1, 322, 999]-->[1, 161, 999] per_nearend_magnitude = per_mask * nearend_mic_magnitude # 预测的近端语音 振幅 complex_stft = per_nearend_magnitude * nearend_mic_phase # 振幅*相位=语音复数表示
print("complex_stft", complex_stft.shape) # [1, 161, 999] per_nearend = torch.istft(complex_stft, n_fft=320, hop_length=160, win_length=320,
window=torch.hann_window(window_length=320).to("cuda")) torchaudio.save("./predict/nearend_speech_fileid_9992.wav", src=per_nearend.cpu().detach(), sample_rate=fs)
# print("近端语音", per_nearend.shape) # [1, 159680] y, _ = librosa.load(nearend_speech, sr=fs)
time_y = np.arange(0, len(y)) * (1.0 / fs)
recover_wav, _ = librosa.load("./predict/nearend_speech_fileid_9992.wav", sr=16000)
time_recover = np.arange(0, len(recover_wav)) * (1.0 / fs) plt.figure(figsize=(8,6))
ax_1 = plt.subplot(3, 1, 1)
plt.title("近端语音和预测近端波形图", fontsize=14)
plt.plot(time_y, y, label="近端语音")
plt.plot(time_recover, recover_wav, label="深度学习生成的近端语音波形")
plt.xlabel('时间/s', fontsize=14)
plt.ylabel('幅值', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998)
plt.subplots_adjust(hspace=0.809, wspace=0.365) # 调整子图间距
plt.legend() norm = matplotlib.colors.Normalize(vmin=-200, vmax=-40)
ax_2 = plt.subplot(3, 1, 2)
plt.title("近端语音频谱", fontsize=14)
plt.specgram(y, Fs=fs, scale_by_freq=True, sides='default', cmap="jet", norm=norm)
plt.xlabel('时间/s', fontsize=14)
plt.ylabel('频率/kHz', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998)
plt.subplots_adjust(hspace=0.809, wspace=0.365) # 调整子图间距 ax_3 = plt.subplot(3, 1, 3)
plt.title("深度学习生成的近端语音频谱", fontsize=14)
plt.specgram(recover_wav, Fs=fs, scale_by_freq=True, sides='default', cmap="jet", norm=norm)
plt.xlabel('时间/s', fontsize=14)
plt.ylabel('频率/kHz', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998)
plt.subplots_adjust(hspace=0.809, wspace=0.365) # 调整子图间距 def formatnum(x, pos):
return '$%d$' % (x / 1000) formatter = FuncFormatter(formatnum)
ax_2.yaxis.set_major_formatter(formatter)
ax_3.yaxis.set_major_formatter(formatter) plt.show()
为了方便可视化对比,我顺便把波形图可语谱图画了出来
如果这篇文章对你有帮助,点个赞是对我最大的鼓励。
关注我,我将分享更有价值的文章!
基于深度学习的回声消除系统与Pytorch实现的更多相关文章
- 基于深度学习的中文语音识别系统框架(pluse)
目录 声学模型 GRU-CTC DFCNN DFSMN 语言模型 n-gram CBHG 数据集 本文搭建一个完整的中文语音识别系统,包括声学模型和语言模型,能够将输入的音频信号识别为汉字. 声学模型 ...
- 基于深度学习的人脸识别系统(Caffe+OpenCV+Dlib)【一】如何配置caffe属性表
前言 基于深度学习的人脸识别系统,一共用到了5个开源库:OpenCV(计算机视觉库).Caffe(深度学习库).Dlib(机器学习库).libfacedetection(人脸检测库).cudnn(gp ...
- 基于深度学习的人脸识别系统(Caffe+OpenCV+Dlib)【三】VGG网络进行特征提取
前言 基于深度学习的人脸识别系统,一共用到了5个开源库:OpenCV(计算机视觉库).Caffe(深度学习库).Dlib(机器学习库).libfacedetection(人脸检测库).cudnn(gp ...
- 基于深度学习的人脸识别系统(Caffe+OpenCV+Dlib)【二】人脸预处理
前言 基于深度学习的人脸识别系统,一共用到了5个开源库:OpenCV(计算机视觉库).Caffe(深度学习库).Dlib(机器学习库).libfacedetection(人脸检测库).cudnn(gp ...
- 基于深度学习的人脸识别系统系列(Caffe+OpenCV+Dlib)——【四】使用CUBLAS加速计算人脸向量的余弦距离
前言 基于深度学习的人脸识别系统,一共用到了5个开源库:OpenCV(计算机视觉库).Caffe(深度学习库).Dlib(机器学习库).libfacedetection(人脸检测库).cudnn(gp ...
- VulDeePecker:基于深度学习的脆弱性检测系统
最近的两款软件,VUDDY和VulPecker,假阴性率高而假阳性率低,用于检测由代码克隆引发的漏洞.而如果用于非代码克隆引起的漏洞则会出现高误报率. 本文使用深度学习处理程序中的代码片段,不应由专家 ...
- 基于深度学习的人脸识别系统Win10 环境安装与配置(python+opencv+tensorflow)
一.需要下载的软件.环境及文件 (由于之前见识短浅,对Anaconda这个工具不了解,所以需要对安装过程做出改变:就是Python3.7.2的下载安装是可选的,因为Anaconda已经为我们解决Pyt ...
- 基于深度学习的人脸性别识别系统(含UI界面,Python代码)
摘要:人脸性别识别是人脸识别领域的一个热门方向,本文详细介绍基于深度学习的人脸性别识别系统,在介绍算法原理的同时,给出Python的实现代码以及PyQt的UI界面.在界面中可以选择人脸图片.视频进行检 ...
- OCR技术浅探:基于深度学习和语言模型的印刷文字OCR系统
作者: 苏剑林 系列博文: 科学空间 OCR技术浅探:1. 全文简述 OCR技术浅探:2. 背景与假设 OCR技术浅探:3. 特征提取(1) OCR技术浅探:3. 特征提取(2) OCR技术浅探:4. ...
随机推荐
- 攻防世界 reverse 进阶 APK-逆向2
APK-逆向2 Hack-you-2014 (看名以为是安卓逆向呢0.0,搞错了吧) 程序是.net写的,直接祭出神器dnSpy 1 using System; 2 using System.Diag ...
- 在B站刷视频多倍速操作
B站多倍数播放 1. 最初天真版 F12 或者笔记本(Fn+F12) console控制台 输入 document.querySelector('video').playbackRate = 4: - ...
- shiro报错SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".和Exception in thread "main" java.lang.NoClassDefFoundError: org/apache/commons/logging/LogFactory
未能加载类"org.slf4j.impl.StaticLoggerBinder" 解决方案: <dependency> <groupId>org.slf4j ...
- Vue3教程:Vue 3 + Element Plus + Vite 2 的后台管理系统开源啦
之前发布过一篇文章<Vue3教程:开发一个 Vue 3 + element-plus 的后台管理系统>,文中提到会开发并开源一个 Vue 3 + Element Plus 的项目供大家练手 ...
- 经常问到的 BFC 和 IFC 是什么?
什么是BFC?什么作用? Block Formatting Context 块盒子布局发生的区域,浮动元素和其他元素交互的区域 浮动定位和清除浮动的时候只会应用于同一个BFC内的元素.浮动不会影响其他 ...
- mysql 批量操作,已存在则修改,不存在则insert,同时判断空选择性写入字段
注:如果是批量插入需要在 Java 连接数据库的字串中设置 &allowMultiQueries=true 针对单行数据有则修改无则新增 本案例的建表语句是: -- auto-generate ...
- .NET Core3.1 Dotnetty实战系列视频
一.概要 由于在.net的环境当中对dotnetty相关资料相对较少,所以这里主要分享一个dotnetty使用教程希望能帮助到正在使用这套框架的开发者们.虽然这套框架已微软官方已经不在维护,但是这套框 ...
- Go-08-函数与指针
Go语言的函数本身可以作为值进行传递,既支持匿名函数和闭包,又能满足接口. 函数声明 func 函数名 (参数列表)(返回参数列表){ // 函数体 } func funcName(parameter ...
- Data Mining UVA - 1591
Dr. Tuple is working on the new data-mining application for Advanced Commercial Merchandise Inc. O ...
- 熟知Mysql基本操作
本文是学习 Mysql必知必会 后的笔记 学习之前需要创建一个数据库,然后导入下面两个mysql脚本 create database db1 charset utf8; ############### ...