使用随机梯度下降训练神经网络

StochasticGradient是一个比较高层次的类,它接受两个参数,module和criterion,前者是模型结构,后者是损失函数的类型。这个类本身有一些参数:

LearningRate: 这是学习率,不用多说

LearningRateDecay: 学习率衰减,current_learning_rate =learningRate / (1 + iteration * learningRateDecay)

maxIteration: 最大迭代次数

shuffleIndices 是否洗数据

hookExample 这个比较神奇,是一个钩子函数,具体功能不详。

hookIteration: 同样的。

如何使用StochasticGradient来训练神经网络?

只有两步

  1. 准备好你的数据
  2. 设计好神经网络结构和loss function

同样的用官方文档的一个例子:

准备数据集:

dataset={};

function dataset:size() return 100 end -- 100 examples

for i=1,dataset:size() do

local input = torch.randn(2);     -- normally distributed example in 2d

local output = torch.Tensor(1);

if input[1]*input[2]>0 then     -- calculate label for XOR function

output[1] = -1;

else

output[1] = 1

end

dataset[i] = {input, output}

end

定义神经网络:

require "nn"

mlp = nn.Sequential();  -- make a multi-layer perceptron

inputs = 2; outputs = 1; HUs = 20; -- parameters

mlp:add(nn.Linear(inputs, HUs))

mlp:add(nn.Tanh())

mlp:add(nn.Linear(HUs, outputs))

训练网络:

criterion = nn.MSECriterion()

trainer = nn.StochasticGradient(mlp, criterion)

trainer.learningRate = 0.01

trainer:train(dataset)

同样的,如果不使用stochasticGradient类,手动训练神经网络也是可以的。

这里举得例子是训练XOR问题。

带有一层隐藏层的神经网络:

require "nn"

mlp = nn.Sequential();  -- make a multi-layer perceptron

inputs = 2; outputs = 1; HUs = 20; -- parameters

mlp:add(nn.Linear(inputs, HUs))

mlp:add(nn.Tanh())

mlp:add(nn.Linear(HUs, outputs))

Loss function

Criterion = nn.MSECriterion()

Training:

for i = 1,2500 do

-- random sample(生成数据集)

local input= torch.randn(2);     -- normally distributed example in 2d

local output= torch.Tensor(1);

if input[1]*input[2] > 0 then  -- calculate label for XOR function

output[1] = -1

else

output[1] = 1

end

-- 这里需要注意的是criterion的forward和nn的forward的调用顺序

-- feed it to the neural network and the criterion

criterion:forward(mlp:forward(input), output)

-- train over this example in 3 steps

-- (1) zero the accumulation of the gradients

mlp:zeroGradParameters()

-- (2) accumulate gradients

mlp:backward(input, criterion:backward(mlp.output, output))

-- (3) update parameters with a 0.01 learning rate

mlp:updateParameters(0.01)

end

Torch7学习笔记(四)StochasticGradient的更多相关文章

  1. C#可扩展编程之MEF学习笔记(四):见证奇迹的时刻

    前面三篇讲了MEF的基础和基本到导入导出方法,下面就是见证MEF真正魅力所在的时刻.如果没有看过前面的文章,请到我的博客首页查看. 前面我们都是在一个项目中写了一个类来测试的,但实际开发中,我们往往要 ...

  2. IOS学习笔记(四)之UITextField和UITextView控件学习

    IOS学习笔记(四)之UITextField和UITextView控件学习(博客地址:http://blog.csdn.net/developer_jiangqq) Author:hmjiangqq ...

  3. java之jvm学习笔记四(安全管理器)

    java之jvm学习笔记四(安全管理器) 前面已经简述了java的安全模型的两个组成部分(类装载器,class文件校验器),接下来学习的是java安全模型的另外一个重要组成部分安全管理器. 安全管理器 ...

  4. Learning ROS for Robotics Programming Second Edition学习笔记(四) indigo devices

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

  5. Typescript 学习笔记四:回忆ES5 中的类

    中文网:https://www.tslang.cn/ 官网:http://www.typescriptlang.org/ 目录: Typescript 学习笔记一:介绍.安装.编译 Typescrip ...

  6. ES6学习笔记<四> default、rest、Multi-line Strings

    default 参数默认值 在实际开发 有时需要给一些参数默认值. 在ES6之前一般都这么处理参数默认值 function add(val_1,val_2){ val_1 = val_1 || 10; ...

  7. muduo网络库学习笔记(四) 通过eventfd实现的事件通知机制

    目录 muduo网络库学习笔记(四) 通过eventfd实现的事件通知机制 eventfd的使用 eventfd系统函数 使用示例 EventLoop对eventfd的封装 工作时序 runInLoo ...

  8. python3.4学习笔记(四) 3.x和2.x的区别,持续更新

    python3.4学习笔记(四) 3.x和2.x的区别 在2.x中:print html,3.x中必须改成:print(html) import urllib2ImportError: No modu ...

  9. Go语言学习笔记四: 运算符

    Go语言学习笔记四: 运算符 这章知识好无聊呀,本来想跨过去,但没准有初学者要学,还是写写吧. 运算符种类 与你预期的一样,Go的特点就是啥都有,爱用哪个用哪个,所以市面上的运算符基本都有. 算术运算 ...

  10. 零拷贝详解 Java NIO学习笔记四(零拷贝详解)

    转 https://blog.csdn.net/u013096088/article/details/79122671 Java NIO学习笔记四(零拷贝详解) 2018年01月21日 20:20:5 ...

随机推荐

  1. 在桌面程序上和Metro/Modern/Windows store app的交互(相互打开,配置读取)

    这个标题真是取得我都觉得蛋疼..微软改名狂魔搞得我都不知道要叫哪个好.. 这边记录一下自己的桌面程序跟windows store app交互的过程. 由于某些原因,微软的商店应用的安全沙箱导致很多事情 ...

  2. Reverse Core 第二部分 - 13章 - PE文件格式

    @date: 2016/11/24 @author: dlive ​ PE (portable executable) ,它是微软在Unix平台的COFF(Common Object File For ...

  3. EF操作多数据库

    1.Account3_Register_DB_Model作为(空)模板库,根据此模板生成的其他数据除了数据库名称不一样,其他表,视图,字段等等都一致 2.Account3_Platform_Maste ...

  4. Flex中的initialize,creationComplete和applicationComp

    转自:http://blog.csdn.net/sjz168/article/details/7244374 1.Application标签中有三个事件initialize,creationCompl ...

  5. 3ds max 渲染模型

    有的模型因为法线方向问题,渲染的时候有的面缺失,只需要强制双面,如下图,就能把所有的面都渲染出来.

  6. OOP的四个魔术方法

    1 __autoload()自动包含类文件 通常会把类的定义单独写到一个文件里,要在另外的文件调用时需要引用require,但类的定义文件会很多就会造成一下问题 //1 如果包含多个类文件,需要一一引 ...

  7. ViewPager中Fragment切换过程不被销毁的方法

    背景:最近在写一个音乐播放器,然后一个ViewPager里面加载了四个Fragment,但是在切换过程中发现,Fragment总是被销毁,在网上查了一下,发现有两种办法可以保证Fragment不被销毁 ...

  8. 通过dll或def文件提取lib导入库文件

    很多时候第三方库或其他项目提供的库多数会以动态库的形式提供dll以及相应的lib导入库.头文件,不过也有的只是提供dll和头文件,或者也提供了def模块定义(用于导出函数)文件,此时若使用将不得不调用 ...

  9. NFC读写实例

    package com.sy.nfc.test; import java.io.IOException; import android.nfc.NdefMessage; import android. ...

  10. 【leetcode】Isomorphic Strings

    题目简述: Given two strings s and t, determine if they are isomorphic. Two strings are isomorphic if the ...