torchnet package (1)

torchnet

torchnet

torchnet是用于torch的代码复用和模块化编程的框架,主要包含四个类

  • Dataset 以不同的方式对数据进行预处理

  • Engine 训练/测试机器学习方法

  • Meter 评估方法性能

  • Log 日志

Documentation

torchnet的调用

local tnt = require 'torchnet'

tnt.Dataset()

torchnet提供了多种即插即用的数据容器(data container),例如 concat,split,batch,resample,etc ... 操作。

tnt.Dataset()实例包含两种主要方法

  • dataset:size() 返回数据集的大小

  • dataset:get(idx) 其中idx是1到size中的数字,返回数据集的第idx个样本

尽管可以简单的通过for loop循环实现数据集的迭代,为了用户能够以on-the-fly manner找出某些样本或者并行的数据读取,torchnet还提供了一些DatasetIterator类型的迭代器

在torchnet中,dataset:get()返回的可以是一个Lua table。table中的阈值可以是任意的,即使大多数的数据集都是tensor类型。

需要注意的是,并不能直接的使用tnt.Dataset()创建该类型,该类型类似于一个抽象类,其下面的具体类包括batchdataset,splitdataset等

-tnt.ListDataset(self,list,load,[,path])

参数:self = tnt.ListDataset

list = tds.Hash

load = function

[path = string]

其中list可以是tds.Hash,table或者torch.LongTensor类型,当访问第i个样本时返回的是load(list[i]),这里load() 是由user提供的闭包函数

当path参数非空的时候,list对应的应该是string队列,这样传递给load()函数的参数自动加上path前缀,比如访问文件夹'd:/data/mot2015/'下的数据时,不同的子数据集存放在不同的文件里'1.txt','2.txt','3.txt',...这时候 list={'1.txt','2.txt',...},path='d:/data/mot2015/',那么load(x)内的x=path .. x

  1. a={{1,2,3},{2,3,4},{2,2,2}} 

  2. b=torch.Tensor(a) 

  3. f=tnt.ListDataset({list=b:long(),function(x) return x:sum() end}) 

  4. print(f:size()) -- 3 

  5. print(f:get(1)) -- 6 

注意list只能是hash,table或者longtensor这里容易出现错误的是习惯使用:long()将tensor类型转换,但是对于元素含小数部分的tensor直接类型转换会出现错误!

  • tnt.ListDataset(self,filename,load[,maxload][,path])

    参数: self = tnt.ListDataset

    filename = string 这里filename指定的文件的每一行都是list的一个元素,类似于io.lines(filename)

    load = function 闭包函数

    [maxload = number] 最大加载条目数

    [path = string] 同之前

  • tnt.TableDataset(self,data)

    参数: data = table 针对于小型数据集,data必须Hash索引,对data数据浅层拷贝

  1. a= tnt.TableDataset{data={1,2,3}} 

  2. print(a:get(1)) 

tnt.TableDataset假定table中key从1连续

  • tnt.TransformDataset(self,dataset,transform[,key])

    参数: self = tnt.TransformDataset

    dataset = tnt.Dataset

    transform = function -- 变换函数

    [key = string] -- 需要变换的key值,如果没有则对dataset中所有数据操作

    当使用tnt.Dataset:get()查询数据集中的数据时,tnt.TransformDataset()以on-the-fly 方式执行闭包函数transform并返回值。

    on-the-fly我的理解是不需要中断过程去执行闭包函数,不涉及从内存中读取数据,而是直接通过cache形式执行,速度很快

  1. a=torch.Tensor{{1,2,3,4},{2,3,4,4}}:long() 

  2. ldata=tnt.ListDataset({list=a,load=function(x) return x end}) 

  3. tdata=tnt.TransformDataset({dataset=ldata,transform=function(x) return x-10 end}) 

  4. print(tdata:get(1)) 

-- tnt.TransformDataset(self,dataset,transforms)

注意这个方法是transforms 是一个table,table中的键值对应着dataset[list[i]]的域,如果我们使用tnt.TableDataset{{a=1,b=2,c=3},{a=0,b=3,c=5}}创建Dataset,如下

  1. Tdata = tnt.TableDataset{data={{a=1,b=2,c=3},{a=0,b=3,c=5}}} 

  2. f=tnt.TransformDataset({dataset = Tdata,transforms={a=function(x) return 2*x end,b=function(x) return x-20 end}}) 

  3. f:get(1) -- 这时候输出{a:2 b:-18 c:3},即Tdata[i]的域a执行了transforms.a函数,域b执行transforms.b函数 

-- tnt.BatchDataset(self,dataset,batchsize[,perm][,merge][,policy][,filter])

参数:self = tnt.BatchDataset

dataset = tnt.Dataset

batchsize = number

[perm = function]

[merge = function]

[policy = string]

[filter = function]

功能:将dataset中的batchsize个样本组成一个样本,方便batch处理

merge函数主要是将batchsize个样本的不同域组合起来,比如数据集的第i个样本写作

{input = <input_i>,target = <target_i>}

那么merge()使数据组合为

  1. {<input_i_1>,<ingput_i_2>,... <input_i_n>} 和 {<target_i_1>,<target_i_2>,... <target_i_n>} 

  1. ldata = tnt.ListDataset({ 

  2. list = torch.range(1,40):long(), 

  3. load = function(x) return {input={torch.randn(2,2),torch.randn(3,3)},target =x,target_t = -x } end 

  4. }) 

  5. bdata = tnt.BatchDataset{ 

  6. dataset = ldata, 

  7. batchsize=10 



  8. print(bdata:size()) --输出 4 

  9. print(bdata:get(1)) -- 输出第一个batch,包含3个field:target,input,target_t 

  10. print(bdata:get(1).input[1]) -- 输出一个input元素  

batch方式操作时,shuffle很重要,所以perm(idx,size)是一个闭包函数,该函数返回shuffle之后idx位置索引的样本,size是dataset的大小。

dataset的size可能不能被batchsize整除,于是 policy指定了截取方式

* include-last 不能整除的最后一个batch大小非必要等于batchsize

* skip-last 最后余出的部分样本舍掉,这并不意味着那些样本就不用了,因为shuffle后的样本排序不定

* divisible-only 不能整除则报错

  • tnt.CoroutineBatchDataset(self,dataset,batchsize[,perm][,merge][,policy][,filter])

    该方法和BatchDataset方法参数完全一致,实现的功能也几乎一致,唯一不同的地方是该方法可以用于协同程序,用到的时候再看吧。。。

  • tnt.ConcatDataset(self,datasets)

    参数: self= tnt.ConcatDataset

    dataset = table

    功能:将table中的数据集concate

  • tnt.ResampleDataset(self,dataset[,sampler][,size])

    给定一个数据集dataset,然后通过sampler(dataset,idx)闭包函数重采样获得新的数据集,size可以指定resample数据集的大小,若没指定则与原来的dataset大小相同,通过源码我们可以看到sampler这个函数其实是用来实现idx的改变

  1. ldata = tnt.ListDataset{list = torch.range(1,40):long(), 

  2. load = function(x) return {input={torch.randn(2,2),torch.randn(3,3)},target =x,target_t = -x } end} 

  3. iidx = tnt.transform.randperm(ldata:size()) 

  4. rdata = tnt.ResampleDataset{dataset = ldata,sampler = function(dataset,idx) return iidx(idx) end} --这其实实现了shuffle功能 

  5. print(rdata:get(1)) 

  • tnt.ShuffleDataset(self,dataset[,size][,replacement])

    实现dataset的shuffle,如果replacement=true,那么指定的size可以大于dataset:size(),大于的部分通过redraw获得

    tnt.ShuffleDataset.resample(self)

    通过该函数在构建ShuffleDataset时就创建fixed的permutation,能够保证多次index同一个值得到的结果相同

  • tnt.SplitDataset(self,dataset,partitions[,initialpartition])

    partitions = table

    [initialpartition = string]

    partitions是一个lua table,table中的元素<key,value>,key是对应partition的名,value是一个0-1的数表示取dataset:size()的比例,或者直接是个number表示对应partitions的大小,initialpartition指定了初始化时加载的partition

    注意 ,该方法在交叉验证时,用起来很爽

    tnt.SplitDataset.select(self,partition) 改变当前选择的partition

  1. sdata = tnt.SplitDataset{data=ldata,partitions={train=0.5,ver=0.25,test=0.25}} 

  2. sdata:select('train') --因为没有指定initialpartition所以需要指定当前的partition才能访问,当指定initialpartition后,该行可以不要,如 sdata = tnt.SplitDataset{data=ldata,partitions={train=0.5,ver=0.25,test=0.25},initialpartitial='train'} 

  3. print(sdata:get(1)) 

  4. print(sdata:size()) 

tnt.utils

torchnet提供了许多工具函数

  • tnt.utils.table.clone(table) 实现table的深度拷贝

  • tnt.utils.table.merge(dst,src) 将src合并到dst中,实现的是浅层拷贝,如果src中的key在dst中已经存在,则覆盖dst中的key值

  1. src={{1,2,3},{4,5,6}} 

  2. dst1={} 

  3. dst2={{1,2,3}} 

  4. dst1=tnt.utils.table.clone(src) 

  5. dst1[1][2]=10 

  6. print(dst1) -- 此时dst1[1][2]=10 

  7. print(src) -- src[1][2]=2 

  8. tnt.utils.table.merge(dst2,src) 

  9. dst2[1][2]=10 

  10. print(dst2) -- dst2[1][2]=10 

  11. print(src) -- src[1][2]=10 

  12. src={a={1,2,3},b={2,3,4}} 

  13. dst1={c={2,2,2}} 

  14. dst2={a={2,3}} 

  15. tnt.utils.table.merge(dst1,src) 

  16. tnt.utils.table.merge(dst2,src) 

  17. print(dst1) -- dst1包含三个元素a,b,c 

  18. print(dst2) -- dst2仅包含2个元素a,b,其中dst2中原来的a被src中的a覆盖 

  • tnt.utils.table.foreach(tbl,closure[,recursive])

    参数: tbl 是一个lua table; closure 闭包函数; [recursive = boolean] 默认值为false

    功能: 对tbl中的每一个元素执行closure函数,如果recursive=true那么tbl将被递归的采用closure函数

    示例:

  1. a={{1,2,3},{2,3,4},{{2,2,2},{1,1,1}}} 

  2. fun = function(v)print('------');print(v)end 

  3. tnt.utils.table.foreach(a,fun) 

  4. tnt.utils.table.foreach(a,fun,true) 

输出:

  1. ------ 



  2. 1 : 1 

  3. 2 : 2 

  4. 3 : 3 



  5. ------ 



  6. 1 : 2 

  7. 2 : 3 

  8. 3 : 4 



  9. ------ 



  10. 1 : 



  11. 1 : 2 

  12. 2 : 2 

  13. 3 : 2 



  14. 2 : 



  15. 1 : 1 

  16. 2 : 1 

  17. 3 : 1 





  1. ------ 



  2. ------ 



  3. ------ 



  4. ------ 



  5. ------ 



  6. ------ 



  7. ------ 



  8. ------ 



  9. ------ 



  10. ------ 



  11. ------ 



  12. ------ 



可以发现,recursive = true下递归调用表中元素,直至最里层的单个元素,而在false下,最外层table中每个元素作为输入参数输入到closure函数中

  • tnt.utils.table.canmergetensor(tbl)

    tbl是否能够merge成一个tensor,table中元素是相同规模的tensor则可以mergetensor

  • tnt.utils.table.mergetensor(tbl)

    将tbl中的元素合并成tensor

  1. a={torch.Tensor(3,2),torch.Tensor(3,3)} 

  2. b={torch.Tensor(3):float(),torch.Tensor(3):double()} 

  3. c={torch.Tensor(4),torch.Tensor(4)} 

  4. var={a,b,c} 

  5. for i=1,3 do 

  6. if tnt.utils.table.canmergetensor(var[i]) then 

  7. print(i) 

  8. tnt.utils.table.mergetensor(var[i]) 

  9. end 

  10. end 

此时显示b,c可以mergetensor,说明只要tensor的规模相同就可以,与其type是否一致无关

tnt.transform

该package提供了数据的基本变换,这些变换有的直接作用在数据上,有的作用在数据结构上,使得操作tnt.Dataset非常方便

这些变换虽然都很简单,但是这些边还可以通过compose或者merge方式实现复杂的变换,compose就是将变换串起来,merge是将变换同时执行,返回每个变换的结果

  • transform.identity(...)

    该变换返回输入本身,这个暂时没想到使用的地方

  • transform.compose(transforms)

    其中参数transforms是一个函数列表,每个函数可以实现一种变换。注意该函数认为transforms中的函数是从1开始连续索引的,如果碰到不连续的了,那么只执行前面连续索引的变换

  1. transform=tnt.transform 

  2. f=transform.compose({ 

  3. function(x) return 2*x end, 

  4. function(x) return x+10 end, 

  5. foo = function(x) return x/2 end 

  6. }) 

  7. a={2,3,4} 

  8. _ =tnt.utils.table.foreach(a,function(x) print(f(x)) end) 输出 14,16,18,即只执行了f中前两个变换 

注意这里函数列表写成{[1]=function(x) return 2*x end,function(x) return x+10 end,foo = function(x) return x/2 end}则只执行第一个变换,因为key:[2]不存在

  • transform.merge(transforms)

    transforms是一个变换函数列表,对于一个输入,该函数使该输入经过所有变换函数得到的结果merge成table输出

  1. f = transform.merge{ 

  2. [1] = function(x) return torch.Tensor{2*x} end, 

  3. [2] = function(x) return torch.Tensor{x + 10} end, 

  4. [3] = function(x) return torch.Tensor{x / 2} end, 

  5. [4] = function(x) return torch.Tensor{x} end 



  6. f(3) 

注意这个例子输出的是一个tensor,并不是doc中说的输出一个table,我觉得这个函数除了bug,transform.lua的第144行应该直接return newz就可以了,源代码中使用utils.table.mergetensor(newz)反而会导致合并出错,要想让源代码能执行就必须像上面我给的例子似的,函数返回的是同等规模的tensor,且函数列表中的index必须是从1开始连续索引,源代码要是不改这个函数还是得特别注意

  • transform.tableapply(transform)

    这里的参数transform是一个变换函数,该变换作用于table变量

  1. a={1,2,3,4} 

  2. f=transform.tableapply(function(x) return x*2 end) 

  3. f(a) 

  • transform.tablemergekeys()

    得到的变换方法的输入必须是一个table的table

  1. x={{input=1,target='a'},{input=2,target='b',flag='hard'}} 

  2. transform.tablemergekeys(x) 

注意这个源码也有问题,源码transform.lua中的243行中ipairs应该修改为pairs,否则给的例子运行不了,因为ipairs从1开始index到第一个非整数key就结束了

  • transform.randperm(size)

    randperm()函数,注意该函数返回的是一个函数句柄,想要获得第i个值,应该用f=transform.randperm(10);f(i)

  • trandform.normalize([threshold])

    输入必须是一个tensor,该函数能够实现标准化,即中心化+归一化,参数threshold是一个number,只有标准差大于threshold时,tensor才会normalize

  1. a=torch.rand(2,3)*10 

  2. print('the std of a is ' .. a:std()) 

  3. f=transform.normalize() 

  4. print('the std of normalized a is ' .. f(a):std() .. ' and the mean is ' .. f(a):sum()) 

torchnet package (1)的更多相关文章

  1. torchnet package (2)

    torchnet package (2) torchnet torch7 Dataset Iterators 尽管是用for loop语句很容易处理Dataset,但有时希望以on-the-fly m ...

  2. torchnet+VGG16计算patch之间相似度

    torchnet+VGG16计算patch之间相似度 torch VGG16 similarity 本来打算使用VGG实现siamese CNN的,但是没想明白怎么使用torchnet对模型进行微调. ...

  3. NPM (node package manager) 入门 - 基础使用

    什么是npm ? npm 是 nodejs 的包管理和分发工具.它可以让 javascript 开发者能够更加轻松的共享代码和共用代码片段,并且通过 npm 管理你分享的代码也很方便快捷和简单. 截至 ...

  4. npm package.json属性详解

    概述 本文档是自己看官方文档的理解+翻译,内容是package.json配置里边的属性含义.package.json必须是一个严格的json文件,而不仅仅是js里边的一个对象.其中很多属性可以通过np ...

  5. 关于Visual Studio 未能加载各种Package包的解决方案

    问题: 打开Visual Studio 的时候,总提示未能加载相应的Package包,有时候还无法打开项目,各种提示 解决方案: 进入用户目录 C:\Users\用户名\AppData\Local\M ...

  6. SSIS 包部署 Package Store 后,在 IS 中可以执行,AGENT 执行却报错

    可以执行 SSIS Package ,证明用 SSIS Package 的账户是可以执行成功的.SQL Server Agent 默认指定账号是 Network Service. 那么可以尝试一下将 ...

  7. 如何使用yum 下载 一个 package ?如何使用 yum install package 但是保留 rpm 格式的 package ? 或者又 如何通过yum 中已经安装的package 导出它,即yum导出rpm?

    注意 RHEL5 和 RHEL6 的不同 How to use yum to download a package without installing it Solution Verified - ...

  8. [转]安装 SciTE 报错 No package ‘gtk+-2.0′ found

    centos 记事本,有时候感觉不够用,或者 出毛病,打不开文件 然后决定安装个其他的记事本,  找来找去, 感觉 SciTE 还可以,于是下载源码编译安装,结果 No package ‘gtk+-2 ...

  9. ERROR ITMS-90167: "No .app bundles found in the package"错误

    ERROR ITMS-90167: "No .app bundles found in the package" 出现如上错误请查检以下2个方向: 1.macOS Sierra 1 ...

随机推荐

  1. centos7上搭建ftp服务器(亲测可用)

    1.安装vsftpd 首先要查看你是否安装vsftp [root@localhost /]# rpm -q vsftpd vsftpd-3.0.2-10.el7.x86_64 (显示以上相关信息也就安 ...

  2. Google Now 'not available in your country'

    Google Now 'not available in your country' Don't know how to cope with this problem.

  3. I/O排查命令

    I/O可以说是问题大户,线上的问题经常都是它引起的,很多人却不知道怎么定位这种问题.今天简单介绍一下,在此抛砖引玉. 此类问题我们一般分三步定位:按系统级I/O.进程级I/O.业务级I/O定位即可,一 ...

  4. Django - Ajax - 参数

    一.Jquery实现Ajax url   type   data   success   error  complete  statusCode {% load staticfiles %} < ...

  5. Supermarket---poj456(贪心并查集优化)

    题目链接:http://poj.org/problem?id=1456 题意是现有n个物品,每个物品有一个保质期和一个利润,现在每天只能卖一个商品,问最大的利润是多少,商品如果过期了就不能卖了: 暴力 ...

  6. cocos2d首印象

    一. 创建工程 从 2.1.4 版本开始,官方就不再为 VS 提供模板了,逐步在各平台采用统一的 Python 脚本创建跨平台工程. 要创建工程,我们需要先从命令行进入 tools/project-c ...

  7. android(十)smali

    Dalvik是google专门为Android操作系统设计的一个虚拟机,经过深度的优化.虽然Android上的程序是使用java来开发的,但是Dalvik和标准的java虚拟机JVM还是两回事. Da ...

  8. Openstack(十五)快速添加新计算节点

    当后期添加新物理服务器作为计算节点,如果按照上面的过程安装配置的话会非常的慢,但是可以通过复制配置文件的方式快速添加. 15.1计算节点服务安装 #提前将yum仓库.防火墙.selinux.主机名.时 ...

  9. 找回 linux root密码的几种方法

    第1种方法: 1.在系统进入单用户状态,直接用passwd root去更改  2.用安装光盘引导系统,进行linux rescue状态,将原来/分区挂接上来,作法如下: Java代码  #> c ...

  10. linux更改文件或目录的属主和属组

    chown  1.效用  更改一个或者多个文件或者目录的属主以及属组,使用职权范围是超等用户  2.格局  chown [选项] 用户或者组 文件  3.首要参量  --dereference:受影响 ...