Sample Classification Code of CIFAR-10 in Torch
from: http://torch.ch/blog/2015/07/30/cifar.html
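(Assuming the models/ directory from the original repository and the CIFAR-10 batches under cifar-10-batches-t7/ are present, the script would be launched with something like: th train.lua --model vgg_bn_drop -s logs/vgg)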
require 'xlua'
require 'optim'
require 'nn'
require 'image'
local c = require 'trepl.colorize'

opt = lapp[[
-s,--save (default "logs") subdirectory to save logs
-b,--batchSize (default 128) batch size
-r,--learningRate (default 1) learning rate
--learningRateDecay (default 1e-7) learning rate decay
--weightDecay (default 0.0005) weightDecay
-m,--momentum (default 0.9) momentum
--epoch_step (default 25) epoch step
--model (default vgg_bn_drop) model name
--max_epoch (default 300) maximum number of iterations
--backend (default nn) backend
--type (default cuda) cuda/float/cl
]]

print(opt)

do -- data augmentation module
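-- BatchFlip is the first stage of the model: during training it
-- horizontally flips a randomly chosen half of the images in each
-- mini-batch in place (cheap data augmentation); in evaluate mode it
-- passes the batch through untouched.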
local BatchFlip,parent = torch.class('nn.BatchFlip', 'nn.Module')

function BatchFlip:__init()
parent.__init(self)
self.train = true
end

function BatchFlip:updateOutput(input)
if self.train then
local bs = input:size(1)
local flip_mask = torch.randperm(bs):le(bs/2)
for i=1,input:size(1) do
if flip_mask[i] == 1 then image.hflip(input[i], input[i]) end
end
end
self.output:set(input)
return self.output
end
end

local function cast(t)
if opt.type == 'cuda' then
require 'cunn'
return t:cuda()
elseif opt.type == 'float' then
return t:float()
elseif opt.type == 'cl' then
require 'clnn'
return t:cl()
else
error('Unknown type '..opt.type)
end
end

print(c.blue '==>' ..' configuring model')
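-- The model is assembled as three stages: BatchFlip (on CPU floats),
-- a Copy module that converts to the tensor type selected by --type,
-- and the network itself loaded from models/<name>.lua. The Copy
-- module's updateGradInput is stubbed out below so no gradient is
-- propagated back into the augmentation stage.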
local model = nn.Sequential()
model:add(nn.BatchFlip():float())
model:add(cast(nn.Copy('torch.FloatTensor', torch.type(cast(torch.Tensor())))))
model:add(cast(dofile('models/'..opt.model..'.lua')))
model:get(2).updateGradInput = function(input) return end

if opt.backend == 'cudnn' then
require 'cudnn'
cudnn.benchmark=true
cudnn.convert(model:get(3), cudnn)
end

print(model)

print(c.blue '==>' ..' loading data')

-------------------------------------------------------------------------------------------
---------------------------- Load the Train and Test data -------------------------------
-------------------------------------------------------------------------------------------

local trsize = 50000
local tesize = 10000
-- load dataset
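-- Each cifar-10-batches-t7 file stores 'data' with one image per column
-- (hence the :t() transpose below) and 0-based 'labels' (hence the +1
-- shift to Torch's 1-based class indices).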
trainData = {
data = torch.Tensor(50000, 3072),
labels = torch.Tensor(50000),
size = function() return trsize end
}
local trainData = trainData
for i = 0,4 do
local subset = torch.load('cifar-10-batches-t7/data_batch_' .. (i+1) .. '.t7', 'ascii')
trainData.data[{ {i*10000+1, (i+1)*10000} }] = subset.data:t()
trainData.labels[{ {i*10000+1, (i+1)*10000} }] = subset.labels
end
trainData.labels = trainData.labels + 1

local subset = torch.load('cifar-10-batches-t7/test_batch.t7', 'ascii')
testData = {
data = subset.data:t():double(),
labels = subset.labels[1]:double(),
size = function() return tesize end
}
local testData = testData
testData.labels = testData.labels + 1

-- resize dataset (if using small version)
trainData.data = trainData.data[{ {1,trsize} }]
trainData.labels = trainData.labels[{ {1,trsize} }]

testData.data = testData.data[{ {1,tesize} }]
testData.labels = testData.labels[{ {1,tesize} }]

-- reshape data
trainData.data = trainData.data:reshape(trsize,3,32,32)
testData.data = testData.data:reshape(tesize,3,32,32)

----------------------------------------------------------------------------------
----------------------------------------------------------------------------------
-- preprocessing data (color space + normalization)
----------------------------------------------------------------------------------
----------------------------------------------------------------------------------
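-- Each image is converted from RGB to YUV; the Y (luminance) channel is
-- normalized locally with spatial contrastive normalization, while U and
-- V are standardized globally with the training-set mean and std (the
-- same statistics are reused for the test set).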
print '<trainer> preprocessing data (color space + normalization)'
collectgarbage()

-- preprocess trainSet
local normalization = nn.SpatialContrastiveNormalization(1, image.gaussian1D(7))
for i = 1,trainData:size() do
xlua.progress(i, trainData:size())
-- rgb -> yuv
local rgb = trainData.data[i]
local yuv = image.rgb2yuv(rgb)
-- normalize y locally:
yuv[1] = normalization(yuv[{{1}}])
trainData.data[i] = yuv
end
-- normalize u globally:
local mean_u = trainData.data:select(2,2):mean()
local std_u = trainData.data:select(2,2):std()
trainData.data:select(2,2):add(-mean_u)
trainData.data:select(2,2):div(std_u)
-- normalize v globally:
local mean_v = trainData.data:select(2,3):mean()
local std_v = trainData.data:select(2,3):std()
trainData.data:select(2,3):add(-mean_v)
trainData.data:select(2,3):div(std_v)

trainData.mean_u = mean_u
trainData.std_u = std_u
trainData.mean_v = mean_v
trainData.std_v = std_v

-- preprocess testSet
for i = 1,testData:size() do
xlua.progress(i, testData:size())
-- rgb -> yuv
local rgb = testData.data[i]
local yuv = image.rgb2yuv(rgb)
-- normalize y locally:
yuv[{1}] = normalization(yuv[{{1}}])
testData.data[i] = yuv
end
-- normalize u globally:
testData.data:select(2,2):add(-mean_u)
testData.data:select(2,2):div(std_u)
-- normalize v globally:
testData.data:select(2,3):add(-mean_v)
testData.data:select(2,3):div(std_v)

----------------------------------------------------------------------------------
----------------------------- END ---------------------------------------------

trainData.data = trainData.data:float()
testData.data = testData.data:float()

confusion = optim.ConfusionMatrix(10)

print('Will save at '..opt.save)
paths.mkdir(opt.save)
testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
testLogger:setNames{'% mean class accuracy (train set)', '% mean class accuracy (test set)'}
testLogger.showPlot = false

parameters,gradParameters = model:getParameters()

print(c.blue'==>' ..' setting criterion')
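-- CrossEntropyCriterion combines LogSoftMax with ClassNLLCriterion,
-- so the network's last layer outputs raw class scores.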
criterion = cast(nn.CrossEntropyCriterion())

print(c.blue'==>' ..' configuring optimizer')
optimState = {
learningRate = opt.learningRate,
weightDecay = opt.weightDecay,
momentum = opt.momentum,
learningRateDecay = opt.learningRateDecay,
}

function train()
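-- One epoch of plain SGD: the training set is shuffled, split into
-- equal-size mini-batches, and each batch is pushed through
-- forward/backward inside feval, which optim.sgd calls with the
-- current flattened parameters.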
model:training()
epoch = epoch or 1

-- drop learning rate every "epoch_step" epochs
if epoch % opt.epoch_step == 0 then optimState.learningRate = optimState.learningRate/2 end

print(c.blue '==>'.." online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ']')

local targets = cast(torch.FloatTensor(opt.batchSize))
local indices = torch.randperm(trainData.data:size(1)):long():split(opt.batchSize)
-- remove last element so that all the batches have equal size
indices[#indices] = nil

local tic = torch.tic()
for t,v in ipairs(indices) do
xlua.progress(t, #indices)

local inputs = trainData.data:index(1,v)
targets:copy(trainData.labels:index(1,v))

local feval = function(x)
if x ~= parameters then parameters:copy(x) end
gradParameters:zero()

local outputs = model:forward(inputs)
local f = criterion:forward(outputs, targets)
local df_do = criterion:backward(outputs, targets)
model:backward(inputs, df_do)

confusion:batchAdd(outputs, targets)

return f,gradParameters
end
optim.sgd(feval, parameters, optimState)
end

confusion:updateValids()
print(('Train accuracy: '..c.cyan'%.2f'..' %%\t time: %.2f s'):format(
confusion.totalValid * 100, torch.toc(tic)))

train_acc = confusion.totalValid * 100

confusion:zero()
epoch = epoch + 1
end

function test()
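-- Evaluation: runs the model in evaluate mode over the test set in
-- fixed-size chunks, logs train/test accuracy, and renders the logger's
-- plot plus the confusion matrix into an HTML report under opt.save.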
-- disable flips, dropouts and batch normalization
model:evaluate()
print(c.blue '==>'.." testing")
local bs = 125
for i=1,testData.data:size(1),bs do
local outputs = model:forward(testData.data:narrow(1,i,bs))
confusion:batchAdd(outputs, testData.labels:narrow(1,i,bs))
end

confusion:updateValids()
print('Test accuracy:', confusion.totalValid * 100)

if testLogger then
paths.mkdir(opt.save)
testLogger:add{train_acc, confusion.totalValid * }
testLogger:style{'-','-'}
testLogger:plot()

if paths.filep(opt.save..'/test.log.eps') then
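-- The EPS plot written by the logger is converted to PNG with
-- ImageMagick's `convert`, base64-encoded via openssl, and embedded
-- inline in the HTML report (both tools must be on the PATH).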
local base64im
do
os.execute(('convert -density 200 %s/test.log.eps %s/test.png'):format(opt.save,opt.save))
os.execute(('openssl base64 -in %s/test.png -out %s/test.base64'):format(opt.save,opt.save))
local f = io.open(opt.save..'/test.base64')
if f then base64im = f:read'*all' end
end

local file = io.open(opt.save..'/report.html','w')
file:write(([[
<!DOCTYPE html>
<html>
<body>
<title>%s - %s</title>
<img src="data:image/png;base64,%s">
<h4>optimState:</h4>
<table>
]]):format(opt.save,epoch,base64im))
for k,v in pairs(optimState) do
if torch.type(v) == 'number' then
file:write('<tr><td>'..k..'</td><td>'..v..'</td></tr>\n')
end
end
file:write'</table><pre>\n'
file:write(tostring(confusion)..'\n')
file:write(tostring(model)..'\n')
file:write'</pre></body></html>'
file:close()
end
end

-- save model every 50 epochs
if epoch % 50 == 0 then
local filename = paths.concat(opt.save, 'model.net')
print('==> saving model to '..filename)
torch.save(filename, model:get(3):clearState())
end

confusion:zero()
end

for i=1,opt.max_epoch do
train()
test()
end
That is the original version of the code. Why was it written like this? It cannot run ...