torchnet package (2)
torchnet package (2)
Dataset Iterators
尽管是用for loop语句很容易处理Dataset,但有时希望以on-the-fly manner或者在线程中读取数据,这时候Dataset Iterator就是个好的选择
注意,iterators是用于特殊情况的,一般情况下还是使用Dataset比较好
Iteartor 的两个主要方法:
* run() 返回一个Lua 迭代器,也可以使用()操作符,因为iterator源码中定义了__call事件
* exec(funcname,...) 在指定的dataset上执行funcname方法,funcname是dataset自己的方法,比如size
tnt.DatasetIterator(self,dataset[,perm][,filter][,transform])
The default dataset iterator
perm(idx), 实现shuffle功能,即对idx进行变换,更复杂的变换可以使用ShuffleDataset
filter(sample), 闭包函数,筛选样本是否用于迭代,返回bool值
transform(sample),闭包函数,实现对样本的变换,更复杂的变换可以结合TransformDataset和transform.compose等实现
- ldata = tnt.ListData{list=torch.range(1,10):long(),load = function(x) return {x,x+1} end}
- dIter = tnt.DatasetIterator{dataset = ldata,filter = function(x) if x[1]<2 then return false else return true end end}
- for v in dIter:run()
- print(v)
- end
tnt.ParallelDatasetIterator(self[,init],closure,nthread[,perm][,filter][,transform][,ordered])
这个才是迭代器的重点,用于以多线程方式迭代数据。
The purpose of this class is to have a zero pre-processing cose. when reading datasets on the fly from disk(not loading thenm fully in memory), or performing complex pre-processing this canbe of interest.
nthreads 指定了线程的个数
init(threadid) 闭包函数,指定了线程threadid的初始化工作,如果啥都不做可以省略
closure(threadid) 每个线程的job,返回的必须时tnt.Dataset的一个实例
perm(idx) 用于shuffle
filter(sample) 闭包函数,指定哪些样本不用于迭代
transform(sample) 对样本进行变换,在filter之前执行
order 线程之间数据的处理是否有序,主要是为了程序的可重现性,当order=true时,多次执行程序,顺序是相同的
- tnt=require'torchnet'
- local list=torch.Tensor{{2,2},{2,2},{2,2},{2,2}}:long()
- ldata = tnt.ListDataset{list=list,load=function(x) return torch.Tensor(x[1],x[2]) end}
- local bdata = tnt.BatchDataset{batchsize=2,dataset = tnt.TransformDataset{dataset = ldata,transform=function(x) return 2*x end}}
- Padata = tnt.ParallelDatasetIterator{
- nthread = 4,
- init = function(tid)
- print ('init thread id: '.. tid)
- tnt=require'torchnet'
- end,
- closure = function(tid)
- print('closure of threadid: '.. tid)
- return bdata
- end
- }
尤其需要注意的是,closure中的所有upvalues都必须是可序列化的,最好是避免使用upvalues,并保证closure中使用的package都在init中require
tnt.Engine
在网络训练的过程中,都是计算前向误差,误差反传,更新权重这些过程,只是模型,数据和评价函数不同而已,所以Engine给训练过程提供了一个模板,该模板建立了model,DatasetIterator,Criterion和Meter之间的联系
engine=tnt.Engine()包含两个主要方法
* engine:train() 在数据集上训练数据
* engine:test() 评估模型,可选
Engine不仅实现了训练和评估的一般模板,还提供了许多接口,用于控制训练过程
tnt.SGDEngine
SGDEngine 模块在train过程中使用Stochastic Gradient Descent方法训练,模块包含数据采样,前向传递,反向传递,参数更新等,还有一些钩子函数
hooks = {
['onStart'] = function() end, --用于训练开始前的设置和初始化
['onStartEpoch'] = function() end, -- 每一个epoch前的操作
['onSample'] = function() end, -- 每次采样一个样本之后的操作
['onForward'] = function() end, -- 在model:forward()之后的操作
['onForwardCriterion'] = function() end, -- 前向计算损失函数之后的操作
['onBackwardCriterion'] = function() end, -- 反向计算损失误差之后的操作
['onBackward'] = function() end, -- 反向传递误差之后的操作
['onUpdate'] = function() end, -- 权重参数更新之后的操作
['onEndEpoch'] = function() end, -- 每一个epoch结束时的操作
['onEnd'] = function() end, -- 整个训练过程结束后的收拾现场
}
可以发现Engine给的hook函数还是很全面的,几乎训练过程的每一个节点都允许用户制定操作,使用hook函数
- local engine = SGDEngine()
- local meter = tnt.AverageValueMeter()
- engine.hooks.onStartEpoch = function(state) meter:reset() end
一般而言,训练过程最少应该知道训练模型,损失函数,数据和学习率,这里学习方法已经知道了SGD,Engine用到的数据是tnt.DatasetIterator类型的。 评估过程只需要数据和模型就可以了
外部可以通过state变量与Engine训练过程交互
state = {
['network'] = network, --设置了model
['criterion'] = criterion, -- 设置损失函数
['iterator'] = iterator, -- 数据迭代器
['lr'] = lr, -- 学习率
['lrcriterion'] = lrcriterion, --
['maxepoch'] = maxepoch, --最大epoch数
['sample'] = {}, -- 当前采集的样本,可以在onSample中通过该阈值查看采样样本
['epoch'] = 0 , -- 当前的epoch
['t'] = 0, -- 已经训练样本的个数
['training'] = true -- 训练过程
}
评估时需要指定:
state = {
['netwrok'] = network
['iterator'] = iterator
['criterion'] = criterion
}
tnt.OptimEngine
这个方法和SGDEngine的最大的区别在于封装了optim中的多种优化方法。在训练开始的时候,engine会通过getParameters获取model的参数
train需要附加两个量:optimMethod 优化方法,比如optim.sgd
config 优化方法对应的参数
Example:
- local engine = tnt.OptimEngine{
- network = network,
- criterion=criterion,
- iterator = iterator,
- optimMethod = optim.sgd,
- config = {
- learningRate = 0.1,
- momentum = 0.9,
- },
- }
tnt.Meter
和Engine配合使用,用于measure the model.
几乎所有的meters都会有3个方法:
* add() 给待统计的meter添加一个观测值,其输入参数一般形式为(output,value),output为model的输出,target为真实值
* value() 获得待统计的meter的当前值
* reset() 重新计数
Meter的使用示例:
- local meter = tnt.<Measure>Meter() -- <Measure> 可以选择具体的度量
- for state,event in tnt.<Optimization>Engine:train{ --定义Engine
- network = network,
- criterion=criterion,
- iterator=iterator,
- } do
- if state == 'start-epoch' then
- meter:reset() -- reset meter
- elseif state == 'forward-criterion' then
- meter:add(state.network.output,sample.target)
- elseif state == 'end-epoch' then
- print('value of meter:) .. meter:value())
- end
- end
tnt.APMeter(self)
评估每一类的平均正确率
APMeter的操作对象是一个的Tensor,表示N个样本对应在K类中的值,另外可选的一个的 Tensor表示每个样本的权重
- target = torch.Tensor{
- {0,0,0,1},{0,0,1,0},{0,1,0,0},{1,0,0,0},{1,0,0,0}}
- apm = tnt.APMeter()
- for i=1,5 do
- apm:add{output=torch.rand(1,4),target=target[i]:size(1,4)} -- 注意N*K的Tensor
- end
- print(apm:value())
tnt.AverageValueMeter(self)
用于统计任意添加的变量的方差和均值,可以用来测量平均损失等
add()的输入必须时number类型,另外在add的时候可以有一个可选的参数n,表示对应值的权重
- avm = tnt.AverageValueMeter()
- for i=1,10 do
- avm:add(i,10-i)
- end
- print(avm:value()) -- 输出 4 2.4720...
tnt.AUCMeter(self)
对于二分类问题计算Area Under Curve (AUC).
AUCMeter操作的变量是1D的tensortnt.ConfusionMeter(self,k[,nirmalized])
多类之间的混淆矩阵,注意不是多类多标签问题,多标签是指一个类的实例可能分配多个标签,这类问题参见tnt.MultiLabelConfusionMeter
初始化的时候,需要指定类别数k,normalized指定是否将confuse matrix 归一化,归一化之后输出的是百分比,否则是数值
add(output,target) 输入都是的tensor,这里为什么每次都是N个样本一起输入呢?这是因为往往训练模型都是Batch模式处理的,target可以是N的tensor,每个值表示对应类别标号,也可以时NK的tensor表示类别的one-hot vector
value()返回KK的混淆矩阵行表示groundtruth,列表示predicted targetstnt.mAPMeter(self)
统计所有类别之间的平均正确率,和APMeter参数完全一致,不同的时value()返回的是多个类别总的正确率tnt.MovingAverageValueMeter(self,windowsize)
该meter和AverageValueMeter非常类似,输入的也是number,不同在于他统计的不是所有的number的均值和方差,而是往前windowsize时间窗内的numbers的均值和方差,windowsize在初始化时需要指定tnt.MultiLabelConfusionMeter(self,k[,normalized])
多类多标签混淆矩阵,这个没接触过,不知道理解对不对,先放这吧,需要的时候再看
The tnt.MultiLabelConfusionMeter constructs a confusion matrix for multi- label, multi-class classification problems. In constructing the confusion matrix, the number of positive predictions is assumed to be equal to the number of positive labels in the ground-truth. Correct predictions (that is, labels in the prediction set that are also in the ground-truth set) are added to the diagonal of the confusion matrix. Incorrect predictions (that is, labels in the prediction set that are not in the ground-truth set) are equally divided over all non-predicted labels in the ground-truth set.
At initialization time, the k parameter that indicates the number of classes in the classification problem under consideration must be specified. Additionally, an optional parameter normalized (default = false) may be specified that determines whether or not the confusion matrix is normalized (that is, it contains percentages) or not (that is, it contains counts).
The add(output, target) method takes as input an NxK tensor output that contains the output scores obtained from the model for N examples and K classes, and a corresponding NxK-tensor target that provides the targets for the N examples using one-hot vectors (that is, vectors that contain only zeros and a single one at the location of the target value to be encoded).
tnt.ClassErrorMeter(self[,topk][,accuracy])
参数: topk = table
accuracy = boolean
该meter用于统计分类误差,topk是一个table指定分别统计前k类预测误差,如ImageNet Competition中的Top5类误差,accuracy表示返回的是正确了还是错误率,accuracy=true,返回的就是1-error
add(output,target),output是一个的tensor,target可以使一个N的tensor也可以是一个的tensor,参考之前的AUCMeter
value()返回的时topk误差,value(k)返回的是第topk类误差tnt.TimeMeter(self[,unit])
这个Meter用于统计events之间的时间,也可以用来统计batch数据的平均处理数据。她很特别!
unit在初始的时候给定,是一个布尔值,默认false,当设置为true时,返回值将会被incUnit()值平均,计算平均时间消耗。
tnt.TimeMeter提供的方法有:reset() 重置timer,unit counter
stop() stop the timer
resume() 唤醒timer
incUnit() uint+1
value() 返回从reset()到现在的时间消耗
tnt.PrecisionAtKMeter(self[,topk][,dim][,online])
待补充
tnt.RecallMeter(self[,threshold][,preclass])
统计threshold下的召回率,threshold是一个table类型,每个元素是一个阈值,默认值为0.5. perclass是一个布尔值,表示是单独统计每一类的召回率还是统计整个召回率,默认值是false
add(output,target) output是N*K的概率矩阵,行和为1;target是NK的二值矩阵,不一定行和为1,如{0,1,0,1}
value()返回的是table值,对应的是threshold table中指定阈值下的召回率,如果perclass = true,那么table的每个元素就是一个tabletnt.PrecisionMeter(self[,threshold][,perclass])
参考RecallMeter,这里计算的是正确率tnt.NDCGMeter(self[,K])
计算normalized discounted cumulative gain,没使用过。。。。
tnt.Log
Log是一个由sting key索引的table,这些keys必须在构造函数中指定,有一个特殊的键 __status__可以在log:status()函数中设置用于记录一些基本的messages
Log中提供的一些closures以及对应attached events
* onSet(log,key,value) 对应着给键赋值 log:set{}
* onGet(log,key) 对应着读取key对应的值 log:get()
* onFlush(log) 对应着清空log log:flush()
* onClose(log) 对应log:close() 关闭log
示例:
- tnt = require'torchnet'
- logtext = require 'torchnet.log.view.text'
- logstatus = require 'torchnet.log.view.status'
- log = tnt.log{
- keys = {'loss','accuracy'}
- onFlush = {
- -- write out all keys in "log" file
- logtext{filename='log.txt', keys={"loss", "accuracy"}, format={"%10.5f", "%3.2f"}},
- -- write out loss in a standalone file
- logtext{filename='loss.txt', keys={"loss"}},
- -- print on screen too
- logtext{keys={"loss", "accuracy"}},
- },
- onSet = {
- -- add status to log
- logstatus{filename='log.txt'},
- -- print status to screen
- logstatus{},
- }
- }
- -- set values
- log:set{
- loss = 0.1,
- accuracy = 97
- }
- -- write some info
- log:status("hello world")
- -- flush out log
- log:flush()
后面我们来看一个具体的例子,以VGG16为例实现一个Siamese CNN网络计算patch之间的相似度
torchnet package (2)的更多相关文章
- torchnet package (1)
torchnet package (1) torchnet torchnet torchnet是用于torch的代码复用和模块化编程的框架,主要包含四个类 Dataset 以不同的方式对数据进行预处理 ...
- torchnet+VGG16计算patch之间相似度
torchnet+VGG16计算patch之间相似度 torch VGG16 similarity 本来打算使用VGG实现siamese CNN的,但是没想明白怎么使用torchnet对模型进行微调. ...
- NPM (node package manager) 入门 - 基础使用
什么是npm ? npm 是 nodejs 的包管理和分发工具.它可以让 javascript 开发者能够更加轻松的共享代码和共用代码片段,并且通过 npm 管理你分享的代码也很方便快捷和简单. 截至 ...
- npm package.json属性详解
概述 本文档是自己看官方文档的理解+翻译,内容是package.json配置里边的属性含义.package.json必须是一个严格的json文件,而不仅仅是js里边的一个对象.其中很多属性可以通过np ...
- 关于Visual Studio 未能加载各种Package包的解决方案
问题: 打开Visual Studio 的时候,总提示未能加载相应的Package包,有时候还无法打开项目,各种提示 解决方案: 进入用户目录 C:\Users\用户名\AppData\Local\M ...
- SSIS 包部署 Package Store 后,在 IS 中可以执行,AGENT 执行却报错
可以执行 SSIS Package ,证明用 SSIS Package 的账户是可以执行成功的.SQL Server Agent 默认指定账号是 Network Service. 那么可以尝试一下将 ...
- 如何使用yum 下载 一个 package ?如何使用 yum install package 但是保留 rpm 格式的 package ? 或者又 如何通过yum 中已经安装的package 导出它,即yum导出rpm?
注意 RHEL5 和 RHEL6 的不同 How to use yum to download a package without installing it Solution Verified - ...
- [转]安装 SciTE 报错 No package ‘gtk+-2.0′ found
centos 记事本,有时候感觉不够用,或者 出毛病,打不开文件 然后决定安装个其他的记事本, 找来找去, 感觉 SciTE 还可以,于是下载源码编译安装,结果 No package ‘gtk+-2 ...
- ERROR ITMS-90167: "No .app bundles found in the package"错误
ERROR ITMS-90167: "No .app bundles found in the package" 出现如上错误请查检以下2个方向: 1.macOS Sierra 1 ...
随机推荐
- pta习题集5-16 朋友圈
某学校有N个学生,形成M个俱乐部.每个俱乐部里的学生有着一定相似的兴趣爱好,形成一个朋友圈.一个学生可以同时属于若干个不同的俱乐部.根据"我的朋友的朋友也是我的朋友"这个推论可以得 ...
- python重定向sys.stdin、sys.stdout和sys.stderr
转自:https://www.cnblogs.com/guyuyuan/p/6885448.html 标准输入.标准输出和错误输出. 标准输入:一般是键盘.stdin对象为解释器提供输入字符流,一般使 ...
- 记一次踩坑:使用ksoap-android时造成的okhttp依赖冲突问题
项目中需要调用webservice接口,android SDK中并没有直接访问webservice接口的方法,于是我引入了ksoap-android的jar包,来实现访问webservice接口.刚开 ...
- Scrapy框架(3)
一.如何提升scrapy框架的爬取效率 增加并发: 默认scrapy开启的并发线程为32个,可以适当进行增加.在settings配置文件中修改CONCURRENT_REQUESTS = 100,并发设 ...
- 大话存储4——RAID磁盘阵列
RAID是英文Redundant Array of Independent Disks(独立磁盘冗余阵列),简称磁盘阵列.下面将各个级别的RAID介绍如下. RAID0 条带化(Stripe)存储.理 ...
- git-【七】bug分支
在开发中,会经常碰到bug问题,那么有了bug就需要修复,在Git中,分支是很强大的,每个bug都可以通过一个临时分支来修复,修复完成后,合并分支,然后将临时的分支删除掉. 比如我在开发中接到一个40 ...
- HBase1.2.0增删改查Scala代码实现
增删改查工具类 class HbaseUtils { /** * 获取管理员对象 * * @param conf 对hbase client配置一些参数 * @return 返回hbase的HBase ...
- 3.5 Templates -- Binding Element Attributes(绑定元素属性)
一.概述 除了正常的文本,你可能还需要你的模板中包含的HTML元素的属性绑定到controller. 1. 例如,设想controller有一个属性包含一个图片的URL: <div id=&qu ...
- (16)Cocos2d-x 多分辨率适配完全解析
Overview 从Cocos2d-x 2.0.4开始,Cocos2d-x提出了自己的多分辨率支持方案,废弃了之前的retina相关设置接口,提出了design resolution概念. 3.0中有 ...
- HDU - 2844 Coins(多重背包+完全背包)
题意 给n个币的价值和其数量,问能组合成\(1-m\)中多少个不同的值. 分析 对\(c[i]*a[i]>=m\)的币,相当于完全背包:\(c[i]*a[i]<m\)的币则是多重背包,考虑 ...