pytorch加载语音类自定义数据集
pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合
- torch.utils.data.Dataset:所有继承他的子类都应该重写 __len()__ , __getitem()__ 这两个方法
- __len()__ :返回数据集中数据的数量
- __getitem()__ :返回支持下标索引方式获取的一个数据
- torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....
第一步
自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:
- __len()__:读取数据,返回数据和标签
- __getitem()__:返回数据集的长度
from torch.utils.data import Dataset class AudioDataset(Dataset):
def __init__(self, ...):
"""类的初始化"""
pass def __getitem__(self, item):
"""每次怎么读数据,返回数据和标签"""
return data, label def __len__(self):
"""返回整个数据集的长度"""
return total
注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本
案例:
文件目录结构
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:读取p225文件夹中的音频数据
1 class AudioDataset(Dataset):
2 def __init__(self, data_folder, sr=16000, dimension=8192):
3 self.data_folder = data_folder
4 self.sr = sr
5 self.dim = dimension
6
7 # 获取音频名列表
8 self.wav_list = []
9 for root, dirnames, filenames in os.walk(data_folder):
10 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
11 self.wav_list.append(os.path.join(root, filename))
12
13 def __getitem__(self, item):
14 # 读取一个音频文件,返回每个音频数据
15 filename = self.wav_list[item]
16 wb_wav, _ = librosa.load(filename, sr=self.sr)
17
18 # 取 帧
19 if len(wb_wav) >= self.dim:
20 max_audio_start = len(wb_wav) - self.dim
21 audio_start = np.random.randint(0, max_audio_start)
22 wb_wav = wb_wav[audio_start: audio_start + self.dim]
23 else:
24 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
25
26 return wb_wav, filename
27
28 def __len__(self):
29 # 音频文件的总数
30 return len(self.wav_list)
注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
第二步
实例化 Dataset 对象
Dataset= AudioDataset("./p225", sr=16000)
如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作
# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000) for i, data in enumerate(train_set):
wb_wav, filname = data
print(i, wb_wav.shape, filname) if i == 3:
break
# 0 (8192,) ./p225\p225_001.wav
# 1 (8192,) ./p225\p225_002.wav
# 2 (8192,) ./p225\p225_003.wav
# 3 (8192,) ./p225\p225_004.wav
第三步
如果想要通过batch读取数据,需要使用DataLoader进行包装
为何要使用DataLoader?
- 深度学习的输入是mini_batch形式
- 样本加载时候可能需要随机打乱顺序,shuffle操作
- 样本加载需要采用多线程
pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
参数:
- dataset:加载的数据集(Dataset对象)
- batch_size:每个批次要加载多少个样本(默认值:1)
- shuffle:每个epoch是否将数据打乱
- sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
- batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
- num_workers:使用多进程加载的进程数,0代表不使用多线程
- collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
- pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
- drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
返回:数据加载器
案例:
# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True) for (i, data) in enumerate(train_loader):
wav_data, wav_name = data
print(wav_data.shape) # torch.Size([8, 8192])
print(i, wav_name)
# ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
# './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
我们来吃几个栗子消化一下:
栗子1
这个例子就是本文一直举例的,栗子1只是合并了一下而已
文件目录结构
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:读取p225文件夹中的音频数据
1 import fnmatch
2 import os
3 import librosa
4 import numpy as np
5 from torch.utils.data import Dataset
6 from torch.utils.data import DataLoader
7
8
9 class Aduio_DataLoader(Dataset):
10 def __init__(self, data_folder, sr=16000, dimension=8192):
11 self.data_folder = data_folder
12 self.sr = sr
13 self.dim = dimension
14
15 # 获取音频名列表
16 self.wav_list = []
17 for root, dirnames, filenames in os.walk(data_folder):
18 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
19 self.wav_list.append(os.path.join(root, filename))
20
21 def __getitem__(self, item):
22 # 读取一个音频文件,返回每个音频数据
23 filename = self.wav_list[item]
24 print(filename)
25 wb_wav, _ = librosa.load(filename, sr=self.sr)
26
27 # 取 帧
28 if len(wb_wav) >= self.dim:
29 max_audio_start = len(wb_wav) - self.dim
30 audio_start = np.random.randint(0, max_audio_start)
31 wb_wav = wb_wav[audio_start: audio_start + self.dim]
32 else:
33 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
34
35 return wb_wav, filename
36
37 def __len__(self):
38 # 音频文件的总数
39 return len(self.wav_list)
40
41
42 train_set = Aduio_DataLoader("./p225", sr=16000)
43 train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
44
45
46 for (i, data) in enumerate(train_loader):
47 wav_data, wav_name = data
48 print(wav_data.shape) # torch.Size([8, 8192])
49 print(i, wav_name)
50 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
51 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
注意事项:
- 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
- 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意
栗子2
相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。
我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。
具体实现代码:
第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:
第二步:通过Dataset从h5格式文件中读取数据
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py def load_h5(h5_path):
# load training data
with h5py.File(h5_path, 'r') as hf:
print('List of arrays in input file:', hf.keys())
X = np.array(hf.get('data'), dtype=np.float32)
Y = np.array(hf.get('label'), dtype=np.float32)
return X, Y class AudioDataset(Dataset):
"""数据加载器"""
def __init__(self, data_folder):
self.data_folder = data_folder
self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1) def __getitem__(self, item):
# 返回一个音频数据
X = self.X[item]
Y = self.Y[item] return X, Y def __len__(self):
return len(self.X) train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True) for (i, wav_data) in enumerate(train_loader):
X, Y = wav_data
print(i, X.shape)
# 0 torch.Size([64, 8192, 1])
# 1 torch.Size([64, 8192, 1])
# ...
我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了,
参考
pytorch学习(四)—自定义数据集(讲的比较详细)
pytorch加载语音类自定义数据集的更多相关文章
- pytorch 加载mnist数据集报错not gzip file
利用pytorch加载mnist数据集的代码如下 import torchvision import torchvision.transforms as transforms from torch.u ...
- 解决Eclipse中“诡异”的错误:找不到或无法加载主类
记录下来遇到的(问题,解决方法),是更有效的解决问题的方式.(原谅我领悟的太晚与懒,从此用更有意义的方法,做一个更有意义的人) 因为遇到了多次,参考同一个方法,原文连接:https://blog.cs ...
- JVM如何加载一个类的过程,双亲委派模型中有哪些方法
1.类加载过程:加载.验证.准备.解析.初始化 加载 在加载阶段,虚拟机主要完成三件事: 1.通过一个类的全限定名来获取定义此类的二进制字节流. 2.将这个字节流所代表的静态存储结构转化为方法 ...
- 找不到或无法加载主类 ide 正常执行,但是打包jar后报错 maven 引入本地包
错误: 找不到或无法加载主类 com.myali.TTSmy 问题原因: ide中编译能找到相关包,但是,打包成jar时,本地的jar引入失败 maven将系统用到的包从线上maven仓库下载到本地的 ...
- javac 不是内部或外部命令 和 错误 找不到或无法加载主类 的解决方法
使用package语句与import语句. 实验要求:按实验要求使用package语句,并用import语句使用Java平台提供的包中的类以及自定义包中的类.掌握一些重要的操作步骤. 代码: 模板1: ...
- 使用Huggingface在矩池云快速加载预训练模型和数据集
作为NLP领域的著名框架,Huggingface(HF)为社区提供了众多好用的预训练模型和数据集.本文介绍了如何在矩池云使用Huggingface快速加载预训练模型和数据集. 1.环境 HF支持Pyt ...
- eclipse 下找不到或无法加载主类的解决办法
有时候 Eclipse 会发神经,好端端的 project 就这么编译不了了,连 Hello World 都会报“找不到或无法加载主类”的错误,我已经遇到好几次了,以前是懒得深究就直接重建projec ...
- java HelloWorld 提示“错误: 找不到或无法加载主类 HelloWorld“解决方案
在检查环境变量等前提工作准确无误后,注意要配好CLASSPATH,仍然报“错误: 找不到或无法加载主类 HelloWorld“. 本人工程目录:mygs-maven/src/main/java/hel ...
- maven project中,在main方法上右键Run as Java Application时,提示错误:找不到或无法加载主类XXX.XXXX.XXX
新建了一个maven project项目,经过一大堆的修改操作之后,突然发现在main方法上右键运行时,竟然提示:错误:找不到或无法加载主类xxx.xxx.xxx可能原因1.eclipse出问题了,在 ...
随机推荐
- The Python Tutorial 和 documentation和安装库lib步骤
链接: The Python Tutorial : https://docs.python.org/3.6/tutorial/index.html Documentation: https://doc ...
- matlab中bitshift 将位移动指定位数
来源:https://ww2.mathworks.cn/help/matlab/ref/bitshift.html?searchHighlight=bitshift&s_tid=doc_src ...
- 【学习笔记/题解】分层图/[JLOI2011]飞行路线
题目戳我 \(\text{Solution:}\) 关于分层图: 一般用于处理:给你\(k\)次机会对边权进行修改的最短路问题. 算法流程: 建立出\(k\)层图,对应进行\(k\)次操作后的局面. ...
- ThreeJS系列1_CinematicCameraJS插件详解
ThreeJS系列1_CinematicCameraJS插件详解 接着上篇 ThreeJS系列1_CinematicCameraJS插件介绍 看属性的来龙去脉 看方法作用 通过调整属性查看效果 总结 ...
- 浅谈 Java集合
Java 集合 集合是对象的容器,定义了多个对象进行操作的常用方法,可实现数组的功能. Java集合类库所处位置:java.util.*. 与现代的数据结构类库的常见做法一样,Java集合类库也将接口 ...
- Markdown语法及使用方法完整手册
欢迎使用 Markdown在线编辑器 MdEditor Markdown是一种轻量级的「标记语言」 Markdown是一种可以使用普通文本编辑器编写的标记语言,通过简单的标记语法,它可以使普通文本内容 ...
- Presto在滴滴的探索与实践
桔妹导读:Presto在滴滴内部发展三年,已经成为滴滴内部Ad-Hoc和Hive SQL加速的首选引擎.目前服务6K+用户,每天读取2PB ~ 3PB HDFS数据,处理30万亿~35万亿条记录,为 ...
- 2020年9月程序员工资统计,平均14459元!你给程序员拖后腿了吗?https://jq.qq.com/?_wv=1027&k=JMPndqoM
2020年9月全国招收程序员362409人.2020年9月全国程序员平均工资14459元,工资中位数12500元,其中95%的人的工资介于5250元到35000元. 工资与上个月持平,但是岗位有所增加 ...
- 【不知道怎么分类】NOIP2016 蚯蚓
题目大意 洛谷链接 给出\(n\)条蚯蚓,给出\(m\)秒,每一秒都把蚯蚓中最长的蚯蚓分成两段,一段是原来的\(p\)倍,剩下的就是\((1-p)\)倍.每一秒,除了刚刚产生的两条新蚯蚓,其余蚯蚓长度 ...
- 安装JDK及环境变量配置
1.下载JDK: 下载地址:https://www.oracle.com/technetwork/java/javase/overview/index.html 2.解压,运行安装包,下一步,选择安装 ...