torchnet package (1)

torchnet

torchnet

torchnet是用于torch的代码复用和模块化编程的框架,主要包含四个类

  • Dataset 以不同的方式对数据进行预处理

  • Engine 训练/测试机器学习方法

  • Meter 评估方法性能

  • Log 日志

Documentation

torchnet的调用

local tnt = require 'torchnet'

tnt.Dataset()

torchnet提供了多种即插即用的数据容器(data container),例如 concat,split,batch,resample,etc ... 操作。

tnt.Dataset()实例包含两种主要方法

  • dataset:size() 返回数据集的大小

  • dataset:get(idx) 其中idx是1到size中的数字,返回数据集的第idx个样本

尽管可以简单的通过for loop循环实现数据集的迭代,为了用户能够以on-the-fly manner找出某些样本或者并行的数据读取,torchnet还提供了一些DatasetIterator类型的迭代器

在torchnet中,dataset:get()返回的可以是一个Lua table。table中的阈值可以是任意的,即使大多数的数据集都是tensor类型。

需要注意的是,并不能直接的使用tnt.Dataset()创建该类型,该类型类似于一个抽象类,其下面的具体类包括batchdataset,splitdataset等

-tnt.ListDataset(self,list,load,[,path])

参数:self = tnt.ListDataset

list = tds.Hash

load = function

[path = string]

其中list可以是tds.Hash,table或者torch.LongTensor类型,当访问第i个样本时返回的是load(list[i]),这里load() 是由user提供的闭包函数

当path参数非空的时候,list对应的应该是string队列,这样传递给load()函数的参数自动加上path前缀,比如访问文件夹'd:/data/mot2015/'下的数据时,不同的子数据集存放在不同的文件里'1.txt','2.txt','3.txt',...这时候 list={'1.txt','2.txt',...},path='d:/data/mot2015/',那么load(x)内的x=path .. x

  1. a={{1,2,3},{2,3,4},{2,2,2}} 

  2. b=torch.Tensor(a) 

  3. f=tnt.ListDataset({list=b:long(),function(x) return x:sum() end}) 

  4. print(f:size()) -- 3 

  5. print(f:get(1)) -- 6 

注意list只能是hash,table或者longtensor这里容易出现错误的是习惯使用:long()将tensor类型转换,但是对于元素含小数部分的tensor直接类型转换会出现错误!

  • tnt.ListDataset(self,filename,load[,maxload][,path])

    参数: self = tnt.ListDataset

    filename = string 这里filename指定的文件的每一行都是list的一个元素,类似于io.lines(filename)

    load = function 闭包函数

    [maxload = number] 最大加载条目数

    [path = string] 同之前

  • tnt.TableDataset(self,data)

    参数: data = table 针对于小型数据集,data必须Hash索引,对data数据浅层拷贝

  1. a= tnt.TableDataset{data={1,2,3}} 

  2. print(a:get(1)) 

tnt.TableDataset假定table中key从1连续

  • tnt.TransformDataset(self,dataset,transform[,key])

    参数: self = tnt.TransformDataset

    dataset = tnt.Dataset

    transform = function -- 变换函数

    [key = string] -- 需要变换的key值,如果没有则对dataset中所有数据操作

    当使用tnt.Dataset:get()查询数据集中的数据时,tnt.TransformDataset()以on-the-fly 方式执行闭包函数transform并返回值。

    on-the-fly我的理解是不需要中断过程去执行闭包函数,不涉及从内存中读取数据,而是直接通过cache形式执行,速度很快

  1. a=torch.Tensor{{1,2,3,4},{2,3,4,4}}:long() 

  2. ldata=tnt.ListDataset({list=a,load=function(x) return x end}) 

  3. tdata=tnt.TransformDataset({dataset=ldata,transform=function(x) return x-10 end}) 

  4. print(tdata:get(1)) 

-- tnt.TransformDataset(self,dataset,transforms)

注意这个方法是transforms 是一个table,table中的键值对应着dataset[list[i]]的域,如果我们使用tnt.TableDataset{{a=1,b=2,c=3},{a=0,b=3,c=5}}创建Dataset,如下

  1. Tdata = tnt.TableDataset{data={{a=1,b=2,c=3},{a=0,b=3,c=5}}} 

  2. f=tnt.TransformDataset({dataset = Tdata,transforms={a=function(x) return 2*x end,b=function(x) return x-20 end}}) 

  3. f:get(1) -- 这时候输出{a:2 b:-18 c:3},即Tdata[i]的域a执行了transforms.a函数,域b执行transforms.b函数 

-- tnt.BatchDataset(self,dataset,batchsize[,perm][,merge][,policy][,filter])

参数:self = tnt.BatchDataset

dataset = tnt.Dataset

batchsize = number

[perm = function]

[merge = function]

[policy = string]

[filter = function]

功能:将dataset中的batchsize个样本组成一个样本,方便batch处理

merge函数主要是将batchsize个样本的不同域组合起来,比如数据集的第i个样本写作

{input = <input_i>,target = <target_i>}

那么merge()使数据组合为

  1. {<input_i_1>,<ingput_i_2>,... <input_i_n>} 和 {<target_i_1>,<target_i_2>,... <target_i_n>} 

  1. ldata = tnt.ListDataset({ 

  2. list = torch.range(1,40):long(), 

  3. load = function(x) return {input={torch.randn(2,2),torch.randn(3,3)},target =x,target_t = -x } end 

  4. }) 

  5. bdata = tnt.BatchDataset{ 

  6. dataset = ldata, 

  7. batchsize=10 



  8. print(bdata:size()) --输出 4 

  9. print(bdata:get(1)) -- 输出第一个batch,包含3个field:target,input,target_t 

  10. print(bdata:get(1).input[1]) -- 输出一个input元素  

batch方式操作时,shuffle很重要,所以perm(idx,size)是一个闭包函数,该函数返回shuffle之后idx位置索引的样本,size是dataset的大小。

dataset的size可能不能被batchsize整除,于是 policy指定了截取方式

* include-last 不能整除的最后一个batch大小非必要等于batchsize

* skip-last 最后余出的部分样本舍掉,这并不意味着那些样本就不用了,因为shuffle后的样本排序不定

* divisible-only 不能整除则报错

  • tnt.CoroutineBatchDataset(self,dataset,batchsize[,perm][,merge][,policy][,filter])

    该方法和BatchDataset方法参数完全一致,实现的功能也几乎一致,唯一不同的地方是该方法可以用于协同程序,用到的时候再看吧。。。

  • tnt.ConcatDataset(self,datasets)

    参数: self= tnt.ConcatDataset

    dataset = table

    功能:将table中的数据集concate

  • tnt.ResampleDataset(self,dataset[,sampler][,size])

    给定一个数据集dataset,然后通过sampler(dataset,idx)闭包函数重采样获得新的数据集,size可以指定resample数据集的大小,若没指定则与原来的dataset大小相同,通过源码我们可以看到sampler这个函数其实是用来实现idx的改变

  1. ldata = tnt.ListDataset{list = torch.range(1,40):long(), 

  2. load = function(x) return {input={torch.randn(2,2),torch.randn(3,3)},target =x,target_t = -x } end} 

  3. iidx = tnt.transform.randperm(ldata:size()) 

  4. rdata = tnt.ResampleDataset{dataset = ldata,sampler = function(dataset,idx) return iidx(idx) end} --这其实实现了shuffle功能 

  5. print(rdata:get(1)) 

  • tnt.ShuffleDataset(self,dataset[,size][,replacement])

    实现dataset的shuffle,如果replacement=true,那么指定的size可以大于dataset:size(),大于的部分通过redraw获得

    tnt.ShuffleDataset.resample(self)

    通过该函数在构建ShuffleDataset时就创建fixed的permutation,能够保证多次index同一个值得到的结果相同

  • tnt.SplitDataset(self,dataset,partitions[,initialpartition])

    partitions = table

    [initialpartition = string]

    partitions是一个lua table,table中的元素<key,value>,key是对应partition的名,value是一个0-1的数表示取dataset:size()的比例,或者直接是个number表示对应partitions的大小,initialpartition指定了初始化时加载的partition

    注意 ,该方法在交叉验证时,用起来很爽

    tnt.SplitDataset.select(self,partition) 改变当前选择的partition

  1. sdata = tnt.SplitDataset{data=ldata,partitions={train=0.5,ver=0.25,test=0.25}} 

  2. sdata:select('train') --因为没有指定initialpartition所以需要指定当前的partition才能访问,当指定initialpartition后,该行可以不要,如 sdata = tnt.SplitDataset{data=ldata,partitions={train=0.5,ver=0.25,test=0.25},initialpartitial='train'} 

  3. print(sdata:get(1)) 

  4. print(sdata:size()) 

tnt.utils

torchnet提供了许多工具函数

  • tnt.utils.table.clone(table) 实现table的深度拷贝

  • tnt.utils.table.merge(dst,src) 将src合并到dst中,实现的是浅层拷贝,如果src中的key在dst中已经存在,则覆盖dst中的key值

  1. src={{1,2,3},{4,5,6}} 

  2. dst1={} 

  3. dst2={{1,2,3}} 

  4. dst1=tnt.utils.table.clone(src) 

  5. dst1[1][2]=10 

  6. print(dst1) -- 此时dst1[1][2]=10 

  7. print(src) -- src[1][2]=2 

  8. tnt.utils.table.merge(dst2,src) 

  9. dst2[1][2]=10 

  10. print(dst2) -- dst2[1][2]=10 

  11. print(src) -- src[1][2]=10 

  12. src={a={1,2,3},b={2,3,4}} 

  13. dst1={c={2,2,2}} 

  14. dst2={a={2,3}} 

  15. tnt.utils.table.merge(dst1,src) 

  16. tnt.utils.table.merge(dst2,src) 

  17. print(dst1) -- dst1包含三个元素a,b,c 

  18. print(dst2) -- dst2仅包含2个元素a,b,其中dst2中原来的a被src中的a覆盖 

  • tnt.utils.table.foreach(tbl,closure[,recursive])

    参数: tbl 是一个lua table; closure 闭包函数; [recursive = boolean] 默认值为false

    功能: 对tbl中的每一个元素执行closure函数,如果recursive=true那么tbl将被递归的采用closure函数

    示例:

  1. a={{1,2,3},{2,3,4},{{2,2,2},{1,1,1}}} 

  2. fun = function(v)print('------');print(v)end 

  3. tnt.utils.table.foreach(a,fun) 

  4. tnt.utils.table.foreach(a,fun,true) 

输出:

  1. ------ 



  2. 1 : 1 

  3. 2 : 2 

  4. 3 : 3 



  5. ------ 



  6. 1 : 2 

  7. 2 : 3 

  8. 3 : 4 



  9. ------ 



  10. 1 : 



  11. 1 : 2 

  12. 2 : 2 

  13. 3 : 2 



  14. 2 : 



  15. 1 : 1 

  16. 2 : 1 

  17. 3 : 1 





  1. ------ 



  2. ------ 



  3. ------ 



  4. ------ 



  5. ------ 



  6. ------ 



  7. ------ 



  8. ------ 



  9. ------ 



  10. ------ 



  11. ------ 



  12. ------ 



可以发现,recursive = true下递归调用表中元素,直至最里层的单个元素,而在false下,最外层table中每个元素作为输入参数输入到closure函数中

  • tnt.utils.table.canmergetensor(tbl)

    tbl是否能够merge成一个tensor,table中元素是相同规模的tensor则可以mergetensor

  • tnt.utils.table.mergetensor(tbl)

    将tbl中的元素合并成tensor

  1. a={torch.Tensor(3,2),torch.Tensor(3,3)} 

  2. b={torch.Tensor(3):float(),torch.Tensor(3):double()} 

  3. c={torch.Tensor(4),torch.Tensor(4)} 

  4. var={a,b,c} 

  5. for i=1,3 do 

  6. if tnt.utils.table.canmergetensor(var[i]) then 

  7. print(i) 

  8. tnt.utils.table.mergetensor(var[i]) 

  9. end 

  10. end 

此时显示b,c可以mergetensor,说明只要tensor的规模相同就可以,与其type是否一致无关

tnt.transform

该package提供了数据的基本变换,这些变换有的直接作用在数据上,有的作用在数据结构上,使得操作tnt.Dataset非常方便

这些变换虽然都很简单,但是这些边还可以通过compose或者merge方式实现复杂的变换,compose就是将变换串起来,merge是将变换同时执行,返回每个变换的结果

  • transform.identity(...)

    该变换返回输入本身,这个暂时没想到使用的地方

  • transform.compose(transforms)

    其中参数transforms是一个函数列表,每个函数可以实现一种变换。注意该函数认为transforms中的函数是从1开始连续索引的,如果碰到不连续的了,那么只执行前面连续索引的变换

  1. transform=tnt.transform 

  2. f=transform.compose({ 

  3. function(x) return 2*x end, 

  4. function(x) return x+10 end, 

  5. foo = function(x) return x/2 end 

  6. }) 

  7. a={2,3,4} 

  8. _ =tnt.utils.table.foreach(a,function(x) print(f(x)) end) 输出 14,16,18,即只执行了f中前两个变换 

注意这里函数列表写成{[1]=function(x) return 2*x end,function(x) return x+10 end,foo = function(x) return x/2 end}则只执行第一个变换,因为key:[2]不存在

  • transform.merge(transforms)

    transforms是一个变换函数列表,对于一个输入,该函数使该输入经过所有变换函数得到的结果merge成table输出

  1. f = transform.merge{ 

  2. [1] = function(x) return torch.Tensor{2*x} end, 

  3. [2] = function(x) return torch.Tensor{x + 10} end, 

  4. [3] = function(x) return torch.Tensor{x / 2} end, 

  5. [4] = function(x) return torch.Tensor{x} end 



  6. f(3) 

注意这个例子输出的是一个tensor,并不是doc中说的输出一个table,我觉得这个函数除了bug,transform.lua的第144行应该直接return newz就可以了,源代码中使用utils.table.mergetensor(newz)反而会导致合并出错,要想让源代码能执行就必须像上面我给的例子似的,函数返回的是同等规模的tensor,且函数列表中的index必须是从1开始连续索引,源代码要是不改这个函数还是得特别注意

  • transform.tableapply(transform)

    这里的参数transform是一个变换函数,该变换作用于table变量

  1. a={1,2,3,4} 

  2. f=transform.tableapply(function(x) return x*2 end) 

  3. f(a) 

  • transform.tablemergekeys()

    得到的变换方法的输入必须是一个table的table

  1. x={{input=1,target='a'},{input=2,target='b',flag='hard'}} 

  2. transform.tablemergekeys(x) 

注意这个源码也有问题,源码transform.lua中的243行中ipairs应该修改为pairs,否则给的例子运行不了,因为ipairs从1开始index到第一个非整数key就结束了

  • transform.randperm(size)

    randperm()函数,注意该函数返回的是一个函数句柄,想要获得第i个值,应该用f=transform.randperm(10);f(i)

  • trandform.normalize([threshold])

    输入必须是一个tensor,该函数能够实现标准化,即中心化+归一化,参数threshold是一个number,只有标准差大于threshold时,tensor才会normalize

  1. a=torch.rand(2,3)*10 

  2. print('the std of a is ' .. a:std()) 

  3. f=transform.normalize() 

  4. print('the std of normalized a is ' .. f(a):std() .. ' and the mean is ' .. f(a):sum()) 

torchnet package (1)的更多相关文章

  1. torchnet package (2)

    torchnet package (2) torchnet torch7 Dataset Iterators 尽管是用for loop语句很容易处理Dataset,但有时希望以on-the-fly m ...

  2. torchnet+VGG16计算patch之间相似度

    torchnet+VGG16计算patch之间相似度 torch VGG16 similarity 本来打算使用VGG实现siamese CNN的,但是没想明白怎么使用torchnet对模型进行微调. ...

  3. NPM (node package manager) 入门 - 基础使用

    什么是npm ? npm 是 nodejs 的包管理和分发工具.它可以让 javascript 开发者能够更加轻松的共享代码和共用代码片段,并且通过 npm 管理你分享的代码也很方便快捷和简单. 截至 ...

  4. npm package.json属性详解

    概述 本文档是自己看官方文档的理解+翻译,内容是package.json配置里边的属性含义.package.json必须是一个严格的json文件,而不仅仅是js里边的一个对象.其中很多属性可以通过np ...

  5. 关于Visual Studio 未能加载各种Package包的解决方案

    问题: 打开Visual Studio 的时候,总提示未能加载相应的Package包,有时候还无法打开项目,各种提示 解决方案: 进入用户目录 C:\Users\用户名\AppData\Local\M ...

  6. SSIS 包部署 Package Store 后,在 IS 中可以执行,AGENT 执行却报错

    可以执行 SSIS Package ,证明用 SSIS Package 的账户是可以执行成功的.SQL Server Agent 默认指定账号是 Network Service. 那么可以尝试一下将 ...

  7. 如何使用yum 下载 一个 package ?如何使用 yum install package 但是保留 rpm 格式的 package ? 或者又 如何通过yum 中已经安装的package 导出它,即yum导出rpm?

    注意 RHEL5 和 RHEL6 的不同 How to use yum to download a package without installing it Solution Verified - ...

  8. [转]安装 SciTE 报错 No package ‘gtk+-2.0′ found

    centos 记事本,有时候感觉不够用,或者 出毛病,打不开文件 然后决定安装个其他的记事本,  找来找去, 感觉 SciTE 还可以,于是下载源码编译安装,结果 No package ‘gtk+-2 ...

  9. ERROR ITMS-90167: "No .app bundles found in the package"错误

    ERROR ITMS-90167: "No .app bundles found in the package" 出现如上错误请查检以下2个方向: 1.macOS Sierra 1 ...

随机推荐

  1. 全角半角符号引发的Entity Framework奇遇记

    SQL Server的SQL查询不区分大小写,而LINQ查询区分大小写,所以在写LINQ代码时需要注意的是——如果这段LINQ代码将会被Entity Framework解析为SQL语句(LINQ to ...

  2. Augmented reality in natural scenes

    Augmented reality in natural scenes (Iryna Gordon and David Lowe)2006年关于AR的研究成果 项目主页 http://www.cs.u ...

  3. HDU Today---hdu2112(最短路-_-坑在是无向图)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2112 spfa或者迪杰斯特拉都可以 注意公交车是有来回的--- #include <iostre ...

  4. the source attachment does not contain the source for the file xxx.class无法关联到某个类

    问题描述: 按下列操作添加相应路径(这里是错误操作) 该问题仍旧无法解决: 注意:这里spring-webmvc-4.1.7.RELEASE.JAR中确实包含AnntationMethodHandle ...

  5. 【转发】Python使用openpyxl读写excel文件

    Python使用openpyxl读写excel文件 这是一个第三方库,可以处理xlsx格式的Excel文件.pip install openpyxl安装.如果使用Aanconda,应该自带了. 读取E ...

  6. 第1章 1.6计算机网络概述--OSI参考模型

    ISO七层模式:国际标准组织对互联网通信规则进行的定义. 7.应用层:所有能产生网络流量的程序,如:QQ. 6.表示层:传输前对数据进行进行处理,是一种数据处理的规则,如:加密.压缩.传输二进制(图片 ...

  7. 很靠谱linux常用命令

    vim是打开vim编辑器,别的编辑器还有vi(功能没有vim 强大),nano,emacs等等,感觉还是vim最强大,其次是vi,别的就要差一些了. 我听我们老师说,用图形界面本身已经会被高手笑了,如 ...

  8. java-基础-【二】内部类与静态内部类

    一.说明 java允许我们在一个类里面定义静态类.比如内部类(nested class).把nested class封闭起来的类叫外部类.在java中,我们不能用static修饰顶级类(top lev ...

  9. POJ2186:Popular Cows(tarjan+缩点)

    题目解析: 这题题意没什么好说的,解法也挺简单的,只要会tarjan算法+只有一个出度为0的强连通分量题目有解这题就迎刃而解了. #include <iostream> #include ...

  10. tar 压缩解压命令详解

    tar -c: 建立压缩档案-x:解压-t:查看内容-r:向压缩归档文件末尾追加文件-u:更新原压缩包中的文件 这五个是独立的命令,压缩解压都要用到其中一个,可以和别的命令连用但只能用其中一个.下面的 ...