torchnet package (2)

torchnet
torch7

Dataset Iterators

尽管是用for loop语句很容易处理Dataset,但有时希望以on-the-fly manner或者在线程中读取数据,这时候Dataset Iterator就是个好的选择

注意,iterators是用于特殊情况的,一般情况下还是使用Dataset比较好

Iteartor 的两个主要方法:

* run() 返回一个Lua 迭代器,也可以使用()操作符,因为iterator源码中定义了__call事件

* exec(funcname,...) 在指定的dataset上执行funcname方法,funcname是dataset自己的方法,比如size

  • tnt.DatasetIterator(self,dataset[,perm][,filter][,transform])

    The default dataset iterator

    perm(idx), 实现shuffle功能,即对idx进行变换,更复杂的变换可以使用ShuffleDataset

    filter(sample), 闭包函数,筛选样本是否用于迭代,返回bool值

    transform(sample),闭包函数,实现对样本的变换,更复杂的变换可以结合TransformDataset和transform.compose等实现

  1. ldata = tnt.ListData{list=torch.range(1,10):long(),load = function(x) return {x,x+1} end} 

  2. dIter = tnt.DatasetIterator{dataset = ldata,filter = function(x) if x[1]<2 then return false else return true end end} 

  3. for v in dIter:run() 

  4. print(v) 

  5. end 

  • tnt.ParallelDatasetIterator(self[,init],closure,nthread[,perm][,filter][,transform][,ordered])

    这个才是迭代器的重点,用于以多线程方式迭代数据。

The purpose of this class is to have a zero pre-processing cose. when reading datasets on the fly from disk(not loading thenm fully in memory), or performing complex pre-processing this canbe of interest.

nthreads 指定了线程的个数

init(threadid) 闭包函数,指定了线程threadid的初始化工作,如果啥都不做可以省略

closure(threadid) 每个线程的job,返回的必须时tnt.Dataset的一个实例

perm(idx) 用于shuffle

filter(sample) 闭包函数,指定哪些样本不用于迭代

transform(sample) 对样本进行变换,在filter之前执行

order 线程之间数据的处理是否有序,主要是为了程序的可重现性,当order=true时,多次执行程序,顺序是相同的

  1. tnt=require'torchnet' 

  2. local list=torch.Tensor{{2,2},{2,2},{2,2},{2,2}}:long() 

  3. ldata = tnt.ListDataset{list=list,load=function(x) return torch.Tensor(x[1],x[2]) end} 

  4. local bdata = tnt.BatchDataset{batchsize=2,dataset = tnt.TransformDataset{dataset = ldata,transform=function(x) return 2*x end}} 

  5. Padata = tnt.ParallelDatasetIterator{ 

  6. nthread = 4, 

  7. init = function(tid) 

  8. print ('init thread id: '.. tid) 

  9. tnt=require'torchnet' 

  10. end, 

  11. closure = function(tid) 

  12. print('closure of threadid: '.. tid) 

  13. return bdata 

  14. end 

  15. }  

尤其需要注意的是,closure中的所有upvalues都必须是可序列化的,最好是避免使用upvalues,并保证closure中使用的package都在init中require

tnt.Engine

在网络训练的过程中,都是计算前向误差,误差反传,更新权重这些过程,只是模型,数据和评价函数不同而已,所以Engine给训练过程提供了一个模板,该模板建立了model,DatasetIterator,Criterion和Meter之间的联系

engine=tnt.Engine()包含两个主要方法

* engine:train() 在数据集上训练数据

* engine:test() 评估模型,可选

Engine不仅实现了训练和评估的一般模板,还提供了许多接口,用于控制训练过程

  • tnt.SGDEngine

    SGDEngine 模块在train过程中使用Stochastic Gradient Descent方法训练,模块包含数据采样,前向传递,反向传递,参数更新等,还有一些钩子函数

    hooks = {

    ['onStart'] = function() end, --用于训练开始前的设置和初始化

    ['onStartEpoch'] = function() end, -- 每一个epoch前的操作

    ['onSample'] = function() end, -- 每次采样一个样本之后的操作

    ['onForward'] = function() end, -- 在model:forward()之后的操作

    ['onForwardCriterion'] = function() end, -- 前向计算损失函数之后的操作

    ['onBackwardCriterion'] = function() end, -- 反向计算损失误差之后的操作

    ['onBackward'] = function() end, -- 反向传递误差之后的操作

    ['onUpdate'] = function() end, -- 权重参数更新之后的操作

    ['onEndEpoch'] = function() end, -- 每一个epoch结束时的操作

    ['onEnd'] = function() end, -- 整个训练过程结束后的收拾现场

    }

    可以发现Engine给的hook函数还是很全面的,几乎训练过程的每一个节点都允许用户制定操作,使用hook函数

  1. local engine = SGDEngine() 

  2. local meter = tnt.AverageValueMeter() 

  3. engine.hooks.onStartEpoch = function(state) meter:reset() end 

一般而言,训练过程最少应该知道训练模型,损失函数,数据和学习率,这里学习方法已经知道了SGD,Engine用到的数据是tnt.DatasetIterator类型的。 评估过程只需要数据和模型就可以了

外部可以通过state变量与Engine训练过程交互

state = {

['network'] = network, --设置了model

['criterion'] = criterion, -- 设置损失函数

['iterator'] = iterator, -- 数据迭代器

['lr'] = lr, -- 学习率

['lrcriterion'] = lrcriterion, --

['maxepoch'] = maxepoch, --最大epoch数

['sample'] = {}, -- 当前采集的样本,可以在onSample中通过该阈值查看采样样本

['epoch'] = 0 , -- 当前的epoch

['t'] = 0, -- 已经训练样本的个数

['training'] = true -- 训练过程

}

评估时需要指定:

state = {

['netwrok'] = network

['iterator'] = iterator

['criterion'] = criterion

}

  • tnt.OptimEngine

    这个方法和SGDEngine的最大的区别在于封装了optim中的多种优化方法。在训练开始的时候,engine会通过getParameters获取model的参数

    train需要附加两个量:

    • optimMethod 优化方法,比如optim.sgd

    • config 优化方法对应的参数

      Example:

  1. local engine = tnt.OptimEngine{ 

  2. network = network, 

  3. criterion=criterion, 

  4. iterator = iterator, 

  5. optimMethod = optim.sgd, 

  6. config = { 

  7. learningRate = 0.1, 

  8. momentum = 0.9, 

  9. }, 



tnt.Meter

和Engine配合使用,用于measure the model.

几乎所有的meters都会有3个方法:

* add() 给待统计的meter添加一个观测值,其输入参数一般形式为(output,value),output为model的输出,target为真实值

* value() 获得待统计的meter的当前值

* reset() 重新计数

Meter的使用示例:

  1. local meter = tnt.<Measure>Meter() -- <Measure> 可以选择具体的度量 

  2. for state,event in tnt.<Optimization>Engine:train{ --定义Engine 

  3. network = network, 

  4. criterion=criterion, 

  5. iterator=iterator, 

  6. } do 

  7. if state == 'start-epoch' then  

  8. meter:reset() -- reset meter 

  9. elseif state == 'forward-criterion' then 

  10. meter:add(state.network.output,sample.target) 

  11. elseif state == 'end-epoch' then 

  12. print('value of meter:) .. meter:value()) 

  13. end 

  14. end 

  • tnt.APMeter(self)

    评估每一类的平均正确率

    APMeter的操作对象是一个的Tensor,表示N个样本对应在K类中的值,另外可选的一个的 Tensor表示每个样本的权重

  1. target = torch.Tensor{ 

  2. {0,0,0,1},{0,0,1,0},{0,1,0,0},{1,0,0,0},{1,0,0,0}} 

  3. apm = tnt.APMeter() 

  4. for i=1,5 do 

  5. apm:add{output=torch.rand(1,4),target=target[i]:size(1,4)} -- 注意N*K的Tensor 

  6. end 

  7. print(apm:value()) 

  • tnt.AverageValueMeter(self)

    用于统计任意添加的变量的方差和均值,可以用来测量平均损失等

    add()的输入必须时number类型,另外在add的时候可以有一个可选的参数n,表示对应值的权重

  1. avm = tnt.AverageValueMeter() 

  2. for i=1,10 do  

  3. avm:add(i,10-i) 

  4. end 

  5. print(avm:value()) -- 输出 4 2.4720... 

  • tnt.AUCMeter(self)

    对于二分类问题计算Area Under Curve (AUC).

    AUCMeter操作的变量是1D的tensor

  • tnt.ConfusionMeter(self,k[,nirmalized])

    多类之间的混淆矩阵,注意不是多类多标签问题,多标签是指一个类的实例可能分配多个标签,这类问题参见tnt.MultiLabelConfusionMeter

    初始化的时候,需要指定类别数k,normalized指定是否将confuse matrix 归一化,归一化之后输出的是百分比,否则是数值

    add(output,target) 输入都是的tensor,这里为什么每次都是N个样本一起输入呢?这是因为往往训练模型都是Batch模式处理的,target可以是N的tensor,每个值表示对应类别标号,也可以时NK的tensor表示类别的one-hot vector

    value()返回K
    K的混淆矩阵行表示groundtruth,列表示predicted targets

  • tnt.mAPMeter(self)

    统计所有类别之间的平均正确率,和APMeter参数完全一致,不同的时value()返回的是多个类别总的正确率

  • tnt.MovingAverageValueMeter(self,windowsize)

    该meter和AverageValueMeter非常类似,输入的也是number,不同在于他统计的不是所有的number的均值和方差,而是往前windowsize时间窗内的numbers的均值和方差,windowsize在初始化时需要指定

  • tnt.MultiLabelConfusionMeter(self,k[,normalized])

    多类多标签混淆矩阵,这个没接触过,不知道理解对不对,先放这吧,需要的时候再看

The tnt.MultiLabelConfusionMeter constructs a confusion matrix for multi- label, multi-class classification problems. In constructing the confusion matrix, the number of positive predictions is assumed to be equal to the number of positive labels in the ground-truth. Correct predictions (that is, labels in the prediction set that are also in the ground-truth set) are added to the diagonal of the confusion matrix. Incorrect predictions (that is, labels in the prediction set that are not in the ground-truth set) are equally divided over all non-predicted labels in the ground-truth set.

At initialization time, the k parameter that indicates the number of classes in the classification problem under consideration must be specified. Additionally, an optional parameter normalized (default = false) may be specified that determines whether or not the confusion matrix is normalized (that is, it contains percentages) or not (that is, it contains counts).

The add(output, target) method takes as input an NxK tensor output that contains the output scores obtained from the model for N examples and K classes, and a corresponding NxK-tensor target that provides the targets for the N examples using one-hot vectors (that is, vectors that contain only zeros and a single one at the location of the target value to be encoded).

  • tnt.ClassErrorMeter(self[,topk][,accuracy])

    参数: topk = table

    accuracy = boolean

    该meter用于统计分类误差,topk是一个table指定分别统计前k类预测误差,如ImageNet Competition中的Top5类误差,accuracy表示返回的是正确了还是错误率,accuracy=true,返回的就是1-error

    add(output,target),output是一个的tensor,target可以使一个N的tensor也可以是一个的tensor,参考之前的AUCMeter

    value()返回的时topk误差,value(k)返回的是第topk类误差

  • tnt.TimeMeter(self[,unit])

    这个Meter用于统计events之间的时间,也可以用来统计batch数据的平均处理数据。她很特别!

    unit在初始的时候给定,是一个布尔值,默认false,当设置为true时,返回值将会被incUnit()值平均,计算平均时间消耗。

    tnt.TimeMeter提供的方法有:

    • reset() 重置timer,unit counter

    • stop() stop the timer

    • resume() 唤醒timer

    • incUnit() uint+1

    • value() 返回从reset()到现在的时间消耗

  • tnt.PrecisionAtKMeter(self[,topk][,dim][,online])

待补充
  • tnt.RecallMeter(self[,threshold][,preclass])

    统计threshold下的召回率,threshold是一个table类型,每个元素是一个阈值,默认值为0.5. perclass是一个布尔值,表示是单独统计每一类的召回率还是统计整个召回率,默认值是false

    add(output,target) output是N*K的概率矩阵,行和为1;target是NK的二值矩阵,不一定行和为1,如{0,1,0,1}

    value()返回的是table值,对应的是threshold table中指定阈值下的召回率,如果perclass = true,那么table的每个元素就是一个table

  • tnt.PrecisionMeter(self[,threshold][,perclass])

    参考RecallMeter,这里计算的是正确率

  • tnt.NDCGMeter(self[,K])

    计算normalized discounted cumulative gain,没使用过。。。。

tnt.Log

Log是一个由sting key索引的table,这些keys必须在构造函数中指定,有一个特殊的键 __status__可以在log:status()函数中设置用于记录一些基本的messages

Log中提供的一些closures以及对应attached events

* onSet(log,key,value) 对应着给键赋值 log:set{}

* onGet(log,key) 对应着读取key对应的值 log:get()

* onFlush(log) 对应着清空log log:flush()

* onClose(log) 对应log:close() 关闭log

示例:

  1. tnt = require'torchnet' 

  2. logtext = require 'torchnet.log.view.text' 

  3. logstatus = require 'torchnet.log.view.status' 

  4. log = tnt.log{ 

  5. keys = {'loss','accuracy'} 

  6. onFlush = { 

  7. -- write out all keys in "log" file 

  8. logtext{filename='log.txt', keys={"loss", "accuracy"}, format={"%10.5f", "%3.2f"}}, 

  9. -- write out loss in a standalone file 

  10. logtext{filename='loss.txt', keys={"loss"}}, 

  11. -- print on screen too 

  12. logtext{keys={"loss", "accuracy"}}, 

  13. }, 

  14. onSet = { 

  15. -- add status to log 

  16. logstatus{filename='log.txt'}, 

  17. -- print status to screen 

  18. logstatus{}, 






  19. -- set values 

  20. log:set{ 

  21. loss = 0.1, 

  22. accuracy = 97 




  23. -- write some info 

  24. log:status("hello world") 


  25. -- flush out log 

  26. log:flush() 



后面我们来看一个具体的例子,以VGG16为例实现一个Siamese CNN网络计算patch之间的相似度


torchnet package (2)的更多相关文章

  1. torchnet package (1)

    torchnet package (1) torchnet torchnet torchnet是用于torch的代码复用和模块化编程的框架,主要包含四个类 Dataset 以不同的方式对数据进行预处理 ...

  2. torchnet+VGG16计算patch之间相似度

    torchnet+VGG16计算patch之间相似度 torch VGG16 similarity 本来打算使用VGG实现siamese CNN的,但是没想明白怎么使用torchnet对模型进行微调. ...

  3. NPM (node package manager) 入门 - 基础使用

    什么是npm ? npm 是 nodejs 的包管理和分发工具.它可以让 javascript 开发者能够更加轻松的共享代码和共用代码片段,并且通过 npm 管理你分享的代码也很方便快捷和简单. 截至 ...

  4. npm package.json属性详解

    概述 本文档是自己看官方文档的理解+翻译,内容是package.json配置里边的属性含义.package.json必须是一个严格的json文件,而不仅仅是js里边的一个对象.其中很多属性可以通过np ...

  5. 关于Visual Studio 未能加载各种Package包的解决方案

    问题: 打开Visual Studio 的时候,总提示未能加载相应的Package包,有时候还无法打开项目,各种提示 解决方案: 进入用户目录 C:\Users\用户名\AppData\Local\M ...

  6. SSIS 包部署 Package Store 后,在 IS 中可以执行,AGENT 执行却报错

    可以执行 SSIS Package ,证明用 SSIS Package 的账户是可以执行成功的.SQL Server Agent 默认指定账号是 Network Service. 那么可以尝试一下将 ...

  7. 如何使用yum 下载 一个 package ?如何使用 yum install package 但是保留 rpm 格式的 package ? 或者又 如何通过yum 中已经安装的package 导出它,即yum导出rpm?

    注意 RHEL5 和 RHEL6 的不同 How to use yum to download a package without installing it Solution Verified - ...

  8. [转]安装 SciTE 报错 No package ‘gtk+-2.0′ found

    centos 记事本,有时候感觉不够用,或者 出毛病,打不开文件 然后决定安装个其他的记事本,  找来找去, 感觉 SciTE 还可以,于是下载源码编译安装,结果 No package ‘gtk+-2 ...

  9. ERROR ITMS-90167: "No .app bundles found in the package"错误

    ERROR ITMS-90167: "No .app bundles found in the package" 出现如上错误请查检以下2个方向: 1.macOS Sierra 1 ...

随机推荐

  1. jenkins之升级

    首先查看系统war包放置的位置 rpm -ql jenkins 下载一个war包 下载地址 https://mirrors.tuna.tsinghua.edu.cn/jenkins/war/2.61/ ...

  2. 利用Python进行端口扫描

    利用Python进行端口扫描 - Dahlhin - 博客园 https://www.cnblogs.com/dachenzi/p/8676104.html Python实现对一个网络段扫描及端口扫描 ...

  3. tomcat设置编码格式utf8

    利用request.setCharacterEncoding("UTF-8");来设置Tomcat接收请求的编码格式,只对POST方式提交的数据有效,对GET方式提交的数据无效! ...

  4. IOS #ifdef 的那些事儿

    版权声明:本文为博主原创文章.未经博主同意不得转载. https://blog.csdn.net/u012884714/article/details/25188685 格式有点乱,整了几次都整只是来 ...

  5. C#知识点备忘

    1.结构体不能用判断符号==判断是否为null,结构体是值类型,不论采用new与否,结构体中的值类型都已经赋了初值. 2.整数相除: a=; b=: c=a/b; 结果c= 如果想得到double型需 ...

  6. c primer plus(五版)编程练习-第八章编程练习

    1.设计一个程序,统计从输入到文件结尾为止的字符数. #include<stdio.h> int main(void){ int ch; int i; i=; while((ch = ge ...

  7. PAT 1080 Graduate Admission[排序][难]

    1080 Graduate Admission(30 分) It is said that in 2011, there are about 100 graduate schools ready to ...

  8. mariadb10.1.13GTID实现主从复制

    ---恢复内容开始--- 环境:centos6.5       mariadb:10.1.13-MariaDB GTID:GTID是有服务器的UUID和事务序号组成的唯一事务序号 ---UUID:N ...

  9. 爬取51job职位信息之编码问题

    兴趣来潮,爬了下51job,但是遇到编码问题!以下是简单的一段代码 获取整个页面数据 # -*- coding:utf-8 -*- import requests import sysreload(s ...

  10. The adidas NMD Camo Singapore consists of four colorways

    Next within the popular selection of the adidas NMD Singapore is really a clean all-black form of th ...