Sample Classification Code of CIFAR-10 in Torch
Sample Classification Code of CIFAR-10 in Torch
from: http://torch.ch/blog/2015/07/30/cifar.html
require 'xlua'
require 'optim'
require 'nn'
require 'image'
local c = require 'trepl.colorize' opt = lapp[[
-s,--save (default "logs") subdirectory to save logs
-b,--batchSize (default 128) batch size
-r,--learningRate (default 1) learning rate
--learningRateDecay (default 1e-7) learning rate decay
--weightDecay (default 0.0005) weightDecay
-m,--momentum (default 0.9) momentum
--epoch_step (default 25) epoch step
--model (default vgg_bn_drop) model name
--max_epoch (default 300) maximum number of iterations
--backend (default nn) backend
--type (default cuda) cuda/float/cl
]] print(opt) do -- data augmentation module
local BatchFlip,parent = torch.class('nn.BatchFlip', 'nn.Module') function BatchFlip:__init()
parent.__init(self)
self.train = true
end function BatchFlip:updateOutput(input)
if self.train then
local bs = input:size()
local flip_mask = torch.randperm(bs):le(bs/)
for i=,input:size() do
if flip_mask[i] == then image.hflip(input[i], input[i]) end
end
end
self.output:set(input)
return self.output
end
end local function cast(t)
if opt.type == 'cuda' then
require 'cunn'
return t:cuda()
elseif opt.type == 'float' then
return t:float()
elseif opt.type == 'cl' then
require 'clnn'
return t:cl()
else
error('Unknown type '..opt.type)
end
end print(c.blue '==>' ..' configuring model')
local model = nn.Sequential()
model:add(nn.BatchFlip():float())
model:add(cast(nn.Copy('torch.FloatTensor', torch.type(cast(torch.Tensor())))))
model:add(cast(dofile('models/'..opt.model..'.lua')))
model:get().updateGradInput = function(input) return end if opt.backend == 'cudnn' then
require 'cudnn'
cudnn.benchmark=true
cudnn.convert(model:get(), cudnn)
end print(model) print(c.blue '==>' ..' loading data') -------------------------------------------------------------------------------------------
---------------------------- Load the Train and Test data -------------------------------
------------------------------------------------------------------------------------------- local trsize =
local tesize =
-- load dataset
trainData = {
data = torch.Tensor(, ),
labels = torch.Tensor(),
size = function() return trsize end
}
local trainData = trainData
for i = , do
local subset = torch.load('cifar-10-batches-t7/data_batch_' .. (i+) .. '.t7', 'ascii')
trainData.data[{ {i*+, (i+)*} }] = subset.data:t()
trainData.labels[{ {i*+, (i+)*} }] = subset.labels
end
trainData.labels = trainData.labels + local subset = torch.load('cifar-10-batches-t7/test_batch.t7', 'ascii')
testData = {
data = subset.data:t():double(),
labels = subset.labels[]:double(),
size = function() return tesize end
}
local testData = testData
testData.labels = testData.labels + -- resize dataset (if using small version)
trainData.data = trainData.data[{ {,trsize} }]
trainData.labels = trainData.labels[{ {,trsize} }] testData.data = testData.data[{ {,tesize} }]
testData.labels = testData.labels[{ {,tesize} }] -- reshape data
trainData.data = trainData.data:reshape(trsize,,,)
testData.data = testData.data:reshape(tesize,,,) ----------------------------------------------------------------------------------
----------------------------------------------------------------------------------
-- preprocessing data (color space + normalization)
----------------------------------------------------------------------------------
----------------------------------------------------------------------------------
print '<trainer> preprocessing data (color space + normalization)'
collectgarbage() -- preprocess trainSet
local normalization = nn.SpatialContrastiveNormalization(, image.gaussian1D())
for i = ,trainData:size() do
xlua.progress(i, trainData:size())
-- rgb -> yuv
local rgb = trainData.data[i]
local yuv = image.rgb2yuv(rgb)
-- normalize y locally:
yuv[] = normalization(yuv[{{}}])
trainData.data[i] = yuv
end
-- normalize u globally:
local mean_u = trainData.data:select(,):mean()
local std_u = trainData.data:select(,):std()
trainData.data:select(,):add(-mean_u)
trainData.data:select(,):div(std_u)
-- normalize v globally:
local mean_v = trainData.data:select(,):mean()
local std_v = trainData.data:select(,):std()
trainData.data:select(,):add(-mean_v)
trainData.data:select(,):div(std_v) trainData.mean_u = mean_u
trainData.std_u = std_u
trainData.mean_v = mean_v
trainData.std_v = std_v -- preprocess testSet
for i = ,testData:size() do
xlua.progress(i, testData:size())
-- rgb -> yuv
local rgb = testData.data[i]
local yuv = image.rgb2yuv(rgb)
-- normalize y locally:
yuv[{}] = normalization(yuv[{{}}])
testData.data[i] = yuv
end
-- normalize u globally:
testData.data:select(,):add(-mean_u)
testData.data:select(,):div(std_u)
-- normalize v globally:
testData.data:select(,):add(-mean_v)
testData.data:select(,):div(std_v) ----------------------------------------------------------------------------------
----------------------------- END --------------------------------------------- trainData.data = trainData.data:float()
testData.data = testData.data:float() confusion = optim.ConfusionMatrix() print('Will save at '..opt.save)
paths.mkdir(opt.save)
testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
testLogger:setNames{'% mean class accuracy (train set)', '% mean class accuracy (test set)'}
testLogger.showPlot = false parameters,gradParameters = model:getParameters() print(c.blue'==>' ..' setting criterion')
criterion = cast(nn.CrossEntropyCriterion()) print(c.blue'==>' ..' configuring optimizer')
optimState = {
learningRate = opt.learningRate,
weightDecay = opt.weightDecay,
momentum = opt.momentum,
learningRateDecay = opt.learningRateDecay,
} function train()
model:training()
epoch = epoch or -- drop learning rate every "epoch_step" epochs
if epoch % opt.epoch_step == then optimState.learningRate = optimState.learningRate/ end print(c.blue '==>'.." online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ']') local targets = cast(torch.FloatTensor(opt.batchSize))
local indices = torch.randperm(trainData.data:size()):long():split(opt.batchSize)
-- remove last element so that all the batches have equal size
indices[#indices] = nil local tic = torch.tic()
for t,v in ipairs(indices) do
xlua.progress(t, #indices) local inputs = trainData.data:index(,v)
targets:copy(trainData.labels:index(,v)) local feval = function(x)
if x ~= parameters then parameters:copy(x) end
gradParameters:zero() local outputs = model:forward(inputs)
local f = criterion:forward(outputs, targets)
local df_do = criterion:backward(outputs, targets)
model:backward(inputs, df_do) confusion:batchAdd(outputs, targets) return f,gradParameters
end
optim.sgd(feval, parameters, optimState)
end confusion:updateValids()
print(('Train accuracy: '..c.cyan'%.2f'..' %%\t time: %.2f s'):format(
confusion.totalValid * , torch.toc(tic))) train_acc = confusion.totalValid * confusion:zero()
epoch = epoch +
end function test()
-- disable flips, dropouts and batch normalization
model:evaluate()
print(c.blue '==>'.." testing")
local bs =
for i=,testData.data:size(),bs do
local outputs = model:forward(testData.data:narrow(,i,bs))
confusion:batchAdd(outputs, testData.labels:narrow(,i,bs))
end confusion:updateValids()
print('Test accuracy:', confusion.totalValid * ) if testLogger then
paths.mkdir(opt.save)
testLogger:add{train_acc, confusion.totalValid * }
testLogger:style{'-','-'}
testLogger:plot() if paths.filep(opt.save..'/test.log.eps') then
local base64im
do
os.execute(('convert -density 200 %s/test.log.eps %s/test.png'):format(opt.save,opt.save))
os.execute(('openssl base64 -in %s/test.png -out %s/test.base64'):format(opt.save,opt.save))
local f = io.open(opt.save..'/test.base64')
if f then base64im = f:read'*all' end
end local file = io.open(opt.save..'/report.html','w')
file:write(([[
<!DOCTYPE html>
<html>
<body>
<title>%s - %s</title>
<img src="data:image/png;base64,%s">
<h4>optimState:</h4>
<table>
]]):format(opt.save,epoch,base64im))
for k,v in pairs(optimState) do
if torch.type(v) == 'number' then
file:write('<tr><td>'..k..'</td><td>'..v..'</td></tr>\n')
end
end
file:write'</table><pre>\n'
file:write(tostring(confusion)..'\n')
file:write(tostring(model)..'\n')
file:write'</pre></body></html>'
file:close()
end
end -- save model every 50 epochs
if epoch % == then
local filename = paths.concat(opt.save, 'model.net')
print('==> saving model to '..filename)
torch.save(filename, model:get():clearState())
end confusion:zero()
end for i=,opt.max_epoch do
train()
test()
end
the original version code:
why they written like this ?
It can not run ...
Sample Classification Code of CIFAR-10 in Torch的更多相关文章
- 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow
原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...
- code::blocks(版本10.05) 配置opencv2.4.3
(1)首先下载opencv2.4.3, 解压缩到D:下: (2)配置code::blocks, 具体操作如下: 第一步, 配置compiler, 操作步骤为Settings -> Compil ...
- code::blocks(版本号10.05) 配置opencv2.4.3
(1)首先下载opencv2.4.3, 解压缩到D:下: (2)配置code::blocks, 详细操作例如以下: 第一步, 配置compiler, 操作步骤为Settings -> Comp ...
- DL Practice:Cifar 10分类
Step 1:数据加载和处理 一般使用深度学习框架会经过下面几个流程: 模型定义(包括损失函数的选择)——>数据处理和加载——>训练(可能包括训练过程可视化)——>测试 所以自己写代 ...
- 【神经网络与深度学习】基于Windows+Caffe的Minst和CIFAR—10训练过程说明
Minst训练 我的路径:G:\Caffe\Caffe For Windows\examples\mnist 对于新手来说,初步完成环境的配置后,一脸茫然.不知如何跑Demo,有么有!那么接下来的教 ...
- Oracle Applications Multiple Organizations Access Control for Custom Code
档 ID 420787.1 White Paper Oracle Applications Multiple Organizations Access Control for Custom Code ...
- UWP深入学习六:Build better apps: Windows 10 by 10 development series
Promotion in the Windows Store In this article, I walk through how to Give your Store listing a mak ...
- Removing Columns 分类: 贪心 CF 2015-08-08 16:10 10人阅读 评论(0) 收藏
Removing Columns time limit per test 2 seconds memory limit per test 256 megabytes input standard in ...
- CV code references
转:http://www.sigvc.org/bbs/thread-72-1-1.html 一.特征提取Feature Extraction: SIFT [1] [Demo program][SI ...
随机推荐
- Sql日期时间格式转换[zhuan]
sql server2000中使用convert来取得datetime数据类型样式(全) 日期数据格式的处理,两个示例: CONVERT(varchar(16), 时间一, 20) 结果:2007-0 ...
- 20155228 2016-2017-2 《Java程序设计》第6周学习总结
20155228 2016-2017-2 <Java程序设计>第6周学习总结 教材学习内容总结 输入与输出 在Java中输入串流代表对象为java.io.InputStream实例,输出串 ...
- Spring 知识点提炼-转
https://www.cnblogs.com/baizhanshi/p/7717563.html 1. Spring框架的作用 轻量:Spring是轻量级的,基本的版本大小为2MB 控制反转:Spr ...
- linux监控性能和网络的命令
vmstat查看机器实时的综合情况:load,内存,swap,cpu使用率等方面 procs: r:运行队列中进程数量 b:等待IO的进程数量 memory(内存): swpd:使用虚拟内存大小 fr ...
- RocketMQ 问题汇总
1. rocketMQ安装: 编译完成以后准备启动项目,注意:bin的位置是编译后target目录下,启动命令在这里. linux命令目录:你的目录/rocketmq-all-4.2.0/distri ...
- java初学者必看的学习路线
不管在编程语言的排行榜中,还是在大多数企业应用的广泛程度来看,Java一直都是当之无愧的榜首.Java语言有着独特的魅力吸引着广大的年轻人去学习,每个人学习的方式方法不一样. 第一步:首先要做好学习前 ...
- 一起学习在 Ubuntu 上授予和移除 sudo 权限
如你所知,用户可以在 Ubuntu 系统上使用 sudo 权限执行任何管理任务.在 Linux 机器上创建新用户时,他们无法执行任何管理任务,直到你将其加入 sudo 组的成员.在这个简短的教程中,我 ...
- 自写Jquery插件 Combobox
原创文章,转载请注明出处,https://www.cnblogs.com/GaoAnLee/p/9092421.html 上效果 html <span id='combobox' class=' ...
- 原生tab选项卡制作
html部分 <div class="tab"> <div class="nav"> <ul> <li class=& ...
- Java程序员秋招面经大合集(BAT美团网易小米华为中兴等)
Cvte提前批 阿里内推 便利蜂内推 小米内推 金山wps内推 多益网络 拼多多学霸批 搜狗校招 涂鸦移动 中国电信it研发中心 中兴 华为 苏宁内推 美团内推 百度 腾讯 招商银行信用卡 招银网络科 ...