前言

这篇博客以PTB数据集为例,详细讲解了如何将txt格式的数据集文件,转换为pytorch框架可以直接处理的tensor变量,并附上相应代码

@


1. PTB 数据集

PTB数据集含有三个txt文件,分别作为训练集(train),验证集(valid)和测试集(test);这三个txt文件分别包含42000,3000和3000句英文;

我们要将其转化为pytorch可处理的tensor类型数据集,需要以下几步:

  • 依次读取每一行的训练集文件(train.txt),为每一个读到的单词分配序号,构建词汇表

    • 出现频率低于min_ooc(通常默认为3)次的词汇,单词一率变为未知单词 < unk > ,分配序号1
    • < sos >为每句话的起始信号,分配序号2
    • < eos >为每句话的结束信号,分配序号3
    • 由于每句话长度不一样,而pytorch批处理数据,需要统一句子长度,因此长度较短的句子用 < pad > 填充,分配序号0
  • 统一句子长度为max_sentence_length(默认50)
    • 高于50个单词的句子,只保留前50个单词;
    • 低于50个单词的句子,用 < pad > 信号填充到50
  • 根据训练集构建的词汇表,将训练集,验证集和测试集都变成数字序号表示的句子,如 a cat not is dog变成 2 25 54 12 0 0
  • 构建三个数据集加载转换后的用数字序号表示的句子,并将其错位句子作为该句子的标签(target),例如, a cat not is dog变成 2 25 54 12 0 0, 那它对应的target就是 25 54 12 3 0 0 了 (2 3 分别为起始,末尾信号)
  • 将其转换为批处理的tensor变量

这样我们就能得到pytorch可以直接加载处理的tensor类型数据集了

2. 构建词汇表

我们先定义一个字典类型的变量,该字典类型变量,会将输入的句子里的新单词添加到字典中,并记录该单词的出现次数

引入库文件:

import json
import torch
import numpy as np
from nltk.tokenize import TweetTokenizer
from collections import Counter, OrderedDict,defaultdict
import io
import os
from torch.utils.data import DataLoader
class OrderedCounter(Counter, OrderedDict): #这样定义的字典类型变量,会将输入的句子里的新单词添加到字典中,并记录该单词的出现次数
"""Counter that remembers the order elements are first encountered"""
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, OrderedDict(self)) def __reduce__(self):
return self.__class__, (OrderedDict(self),)

依次读入ptb.train.txt的每一句话,并对其进行分词,不区分大小写;

  • 分词:由于默认可以通过空格来分开每个单词,但专业的分词函数更好些
def create_vocab(split):
tokenizer = TweetTokenizer(preserve_case=False) #分词,不区分大小写
w2c = OrderedCounter()
w2i = dict()
i2w = dict()
special_tokens = ['<pad>', '<unk>', '<sos>', '<eos>']
for st in special_tokens:
i2w[len(w2i)] = st
w2i[st] = len(w2i)
with open(split, 'r') as file:
for i, line in enumerate(file):
words = tokenizer.tokenize(line)
w2c.update(words) #这段程序将文件中出现过的所有单词加载到字典类型变量w2c中,并存储了他们出现的次数
for w, c in w2c.items():
if c > 3 and w not in special_tokens: #依次为出现次数大于3,且不是那4种特殊信号的单词分配序号
i2w[len(w2i)] = w
w2i[w] = len(w2i) #w2i格式为'cat':50这种,i2w为50:'cat'这种
return w2i,i2w

实例化一下试试:

3. 将训练集,验证集和测试集根据词汇表转换为数字序号,并转换为tensor

def create_data(split,w2i): #split为待转换的txt文件地址
tokenizer = TweetTokenizer(preserve_case=False) #分词,不区分大小写
data = defaultdict(dict)
with open(split, 'r') as file: #读取该文件的每一行
for i, line in enumerate(file):
words = tokenizer.tokenize(line) #分词
input = ['<sos>'] + words #输入的开头增加<sos>信号
input = input[:50] #只保留前50个(起始信号<sos> + 文本的前49个单词)
target = words[:50-1] #输入对应的target,也只保留50个(取文本的前49个单词+ 结束信号<eos>)
target = target + ['<eos>']
length = len(input)
input.extend(['<pad>'] * (50-length)) #输入和target,不足50个的,用<pad>补足50个
target.extend(['<pad>'] * (50-length))
input = [w2i.get(w, w2i['<unk>']) for w in input]
target = [w2i.get(w, w2i['<unk>']) for w in target]
id = len(data) #id表示该数据集的第id句话
inpu_t = torch.from_numpy(np.asarray(input)) #转换为tensor形式
targe_t = torch.from_numpy(np.asarray(target))
data[id]['input'] = inpu_t
data[id]['target'] = targe_t
data[id]['length'] = length
return data

实例化一下试试:

3. 转换为批处理的tensor变量

data_loader = DataLoader(
dataset= data,
batch_size= 64,#批处理大小
shuffle=True #是否打乱排序
)

实例化试试:

总结

这篇博客以PTB数据集为例,介绍了如何将txt形式的数据集转换为pytorch框架中可以使用的,批处理的tensor形式

参考项目:github上以PTB数据集训练的一个语言模型的项目

Pytorch加载txt格式的数据集文件(以PTB数据集为例)的更多相关文章

  1. Away3D 学习笔记(一): 加载3DS格式的模型文件

    加载外部的3DS文件分为两种: 1: 模型与贴图独立于程序的,也就是从外部的文件夹中读取 private function load3DSFile():Loader3D { loader = new ...

  2. pytorch 加载mnist数据集报错not gzip file

    利用pytorch加载mnist数据集的代码如下 import torchvision import torchvision.transforms as transforms from torch.u ...

  3. 神坑 Resources.Load 不能实时加载TXT文件

    Resources.Load(fileName) as TextAsset; 这句话并不能实时加载文本文件,对文本文件进行修改之后,若是没有刷新的话,加载的还是之前的文件: 要实时读取文本文件还是要以 ...

  4. hive 压缩全解读(hive表存储格式以及外部表直接加载压缩格式数据);HADOOP存储数据压缩方案对比(LZO,gz,ORC)

    数据做压缩和解压缩会增加CPU的开销,但可以最大程度的减少文件所需的磁盘空间和网络I/O的开销,所以最好对那些I/O密集型的作业使用数据压缩,cpu密集型,使用压缩反而会降低性能. 而hive中间结果 ...

  5. 为不同分辨率单独做样式文件,在页面头部用js判断分辨率后动态加载定义好的样式文件

    为不同分辨率单独做样式文件,在页面头部用js判断分辨率后动态加载定义好的样式文件.样式文件命名格式如:forms[_屏幕宽度].css,样式文件中只需重新定义文本框和下拉框的宽度即可. 在包含的头文件 ...

  6. 使用getJSON()方法异步加载JSON格式数据

    使用getJSON()方法异步加载JSON格式数据 使用getJSON()方法可以通过Ajax异步请求的方式,获取服务器中的数组,并对获取的数据进行解析,显示在页面中,它的调用格式为: jQuery. ...

  7. cesium模型加载-加载fbx格式模型

    整体思路: fbx格式→dae格式→gltf格式→cesium加载gltf格式模型 具体方法: 1. fbx格式→dae格式 工具:3dsMax, 3dsMax插件:OpenCOLLADA, 下载地址 ...

  8. Lab_1:练习4——分析bootloader加载ELF格式的OS的过程

    一.实验内容 通过阅读bootmain.c,了解bootloader如何加载ELF文件.通过分析源代码和通过qemu来运行并调试bootloader&OS, bootloader如何读取硬盘扇 ...

  9. Lab1:练习四——分析bootloader加载ELF格式的OS的过程

    练习四:分析bootloader加载ELF格式的OS的过程. 1.题目要求 通过阅读bootmain.c,了解bootloader如何加载ELF文件.通过分析源代码和通过qemu来运行并调试bootl ...

  10. 如何实现通过Leaflet加载dwg格式的CAD图

    前言 ​ 在前面介绍了通过openlayers加载dwg格式的CAD图并与互联网地图叠加,openlayers功能很全面,但同时也很庞大,入门比较难,适合于大中型项目中.而在中小型项目中,一般用开源的 ...

随机推荐

  1. 学习ASP.NET Core Blazor编程系列二十五——登录(4)

    学习ASP.NET Core Blazor编程系列文章之目录 学习ASP.NET Core Blazor编程系列一--综述 学习ASP.NET Core Blazor编程系列二--第一个Blazor应 ...

  2. 重学SpringBoot. step6 SpringBoot高级技巧

    SpringBoot高级技术 博客地址: step6 SpringBoot高级技巧 异步线程池 书上讲的是什么像异步操作那样,然后不需要等待. 问题是,不需要等待,但数据在生成的时候的时间并不能省. ...

  3. 真正“搞”懂HTTPS协议15之安全的定义

    前面我们花了很大的篇幅来讲HTTP在性能上的改进,从1.0到1.1,再到2.0.3.0,HTTP通过替换底层协议,解决了一直阻塞性能提升的队头阻塞问题,在性能上达到了极致. 那么,接下来,我们来聊一聊 ...

  4. .NET WebAPI 跨域问题(has been blocked by CORS policy:No Access-Control-Allow-Ogigin header is present on the requested resource)

    一.什么是跨域 1. 跨域解释 跨域指的是浏览器不能执行其他网站的脚本.它是由浏览器的同源策略造成的,是浏览器施加的安全限制. 同源指的是:域名,协议,端口均相同. 2. 什么情况下会导致跨域 2.1 ...

  5. concurrent-map 和 sync.Map,我该选择哪个?

    concurrent-map 和 sync.Map,我该选择哪个? 官方的map并不是线程安全的,如果我们在多线程中并发对一个map进行读写操作,是会引发panic的.解决方案除了使用锁来对map进行 ...

  6. vulnhub靶场之DRIFTINGBLUES: 9 (FINAL)

    准备: 攻击机:虚拟机kali.本机win10. 靶机:DriftingBlues: 9 (final),下载地址:https://download.vulnhub.com/driftingblues ...

  7. NOIP 模拟赛 简单题

    \(\text{Solution}\) 发现题目就是求 \(\sum[\prod_{i=1}^k x_i \le n]\) \(k \le 10^9\) 太可怕了 然而发现如果限定 \(x_i > ...

  8. JZOJ 2020.07.16【NOIP提高组】模拟

    总结 这套题相比昨天,简单了不止一点 然而有的人拿了 \(300\) 多 而我只有 \(198\) 预估应该有 \(268\) 的,假了 \(70\) 分 出现了很多奇怪的 \(mistakes\) ...

  9. vue data为什么是函数

    vue data是函数的原因: 1.防止data复用: 2.data独立性: 3.作用域: 4.js的特性. 总结来说,如果data是一个函数的话,这样每复用一次组件,就会返回一份新的data(类似于 ...

  10. 基于C++的OpenGL 11 之投光物

    1. 引言 本文基于C++语言,描述OpenGL的投光物 前置知识可参考: 基于C++的OpenGL 10 之光照贴图 - 当时明月在曾照彩云归 - 博客园 (cnblogs.com) 笔者这里不过多 ...