继续学习http://www.cnblogs.com/tornadomeet/archive/2013/03/15/2962116.html,上一节课学习速率是固定的,而这里我们的目的是找到一个比较好的学习速率。我们主要是观察 不同的学习速率对应的不同的损失值与迭代次数之间的函数曲线是怎么样的,找到那条最快达到收敛的函数曲线,其对应的学习速率就是我们要找的比较好的学习速率。在这里我们分别取速率值为:0.001,0.01,0.1,1,2,当我们选择完学习速率后,其余的都跟上一节课一样了。本文要解决的问题是给出了47个训练样本,训练样本的y值为房子的价格,x属性有2个,一个是房子的大小,另一个是房子卧室的个数。需要通过这些训练数据来学习系统的函数,从而预测房子大小为1650,且卧室有3个的房子的价格。

代码如下:

x = load('ex3x.dat');
y = load('ex3y.dat'); x = [ones(size(x,),) x];%每一行是一个样本,在这里每个样本增加一维1,原因在前面课说了(讲wx+b变成w'x齐次的)
meanx = mean(x);%求均值 接下来四行是让样本的每一维度(除第一维1外)的值标准化。
sigmax = std(x);%求标准偏差 但是前面不是说线性的不用进行feature scale吗(第一课讲的)
x(:,) = (x(:,)-meanx())./sigmax();
x(:,) = (x(:,)-meanx())./sigmax(); figure
itera_num = ; %尝试的迭代次数
sample_num = size(x,); %训练样本的个数
alpha = [0.01, 0.03, 0.1, 0.3, , 1.3];%因为差不多是选取每个3倍的学习率来测试,所以直接枚举出来
plotstyle = {'b', 'r', 'g', 'k', 'b--', 'r--'};%建了一个包,每一个值代表画出的曲线样式不同,b是blue蓝色,
%r是red ,g是green..b--是blue颜色--代表的是虚线,而前面那些不加的是实现。 theta_grad_descent = zeros(size(x(,:)));
for alpha_i = :length(alpha) %alpha_i是1,,...,表示的是学习速率向量和曲线格式向量的坐标:alpha(alpha_i),plotstyle(alpha_i)
theta = zeros(size(x,),); %theta是cost function的参数,初始值赋值为0向量(*1的向量,x有几维theta就是几维的参数向量)
Jtheta = zeros(itera_num, );%Jthete是个100*1的向量,第n个元素代表第n次迭代cost function的值(预测与真实y的总均方误差)
for i = :itera_num %计算出某个学习速率alpha下迭代itera_num次数后的参数
Jtheta(i) = (/(*sample_num)).*(x*theta-y)'*(x*theta-y);%Jtheta是个100*1的列向量。(x*theta-y)'*(x*theta-y)代表的就是
%cost function 公式的那个平方,因为在向量水平上平方没有直接平方,所以就是这种转置后内积的形式。并且得到的是
%一个标量,所以再与前面的系数相乘可以直接用*,而不用.* 还有一点是前面的系数 我还是不明白为什么
%是(/(*sample_num))
grad = (/sample_num).*x'*(x*theta-y);
theta = theta - alpha(alpha_i).*grad;
end
plot(:, Jtheta(:),char(plotstyle(alpha_i)),'LineWidth', )%此处一定要通过char函数来转换因为包用()索引后得到的还是包cell,
%所以才要用char函数转换,也可以用{}索引,这样就不用转换了。
%一个学习速率对应的图像画出来以后再画出下一个学习速率对应的图像。
hold on
if( == alpha(alpha_i)) %通过实验发现alpha为1时效果最好,则此时的迭代后的theta值为所求的值
theta_grad_descent = theta
end
end
legend('0.01','0.03','0.1','0.3','','1.3');
xlabel('Number of iterations')
ylabel('Cost function') %下面是预测公式
price_grad_descend = theta_grad_descent'*[1 (1650-meanx(2))/sigmax(2) (3-meanx(3)/sigmax(3))]'

实验结果:

deep learning 学习笔记(三) 线性回归学习速率优化寻找的更多相关文章

  1. 【Deep Learning读书笔记】深度学习中的概率论

    本文首发自公众号:RAIS,期待你的关注. 前言 本系列文章为 <Deep Learning> 读书笔记,可以参看原书一起阅读,效果更佳. 概率论 机器学习中,往往需要大量处理不确定量,或 ...

  2. Deep Learning论文笔记之(一)K-means特征学习

    Deep Learning论文笔记之(一)K-means特征学习 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文,但老感 ...

  3. Learning ROS for Robotics Programming Second Edition学习笔记(三) 补充 hector_slam

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

  4. Learning ROS for Robotics Programming Second Edition学习笔记(三) indigo rplidar rviz slam

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

  5. Deep Learning论文笔记之(三)单层非监督学习网络分析

    Deep Learning论文笔记之(三)单层非监督学习网络分析 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文,但老感 ...

  6. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1

    3.Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1 http://blog.csdn.net/sunbow0 ...

  7. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.2

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.2 http://blog.csdn.net/sunbow0 ...

  8. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.3

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.3 http://blog.csdn.net/sunbow0 ...

  9. Oracle学习笔记三 SQL命令

    SQL简介 SQL 支持下列类别的命令: 1.数据定义语言(DDL) 2.数据操纵语言(DML) 3.事务控制语言(TCL) 4.数据控制语言(DCL)  

随机推荐

  1. Linux中的grep和cut

    提取行: grep --color  着色 -v         不包含 提取列: cut -f      列号 提取第几列 -d     分隔符 以什么为分隔符,默认是制表键 局限性:如果分隔符不那 ...

  2. nodejs中全栈开发框架meteor的文档

    http://wiki.jikexueyuan.com/project/discover-meteor/routing.html,   这本书的源码地址: https://github.com/Dis ...

  3. 解读mysql主从配置及其原理分析(Master-Slave)

    在windows下配置的,后面会在Linux下配置进行测试,需要配置mysql数据库同步的朋友可以参考下. 1.在主数据库服务器为从服务器添加一个拥有权限访问主库的用户:GRANT REPLICATI ...

  4. corethink功能模块探索开发(三)让这个模块可见

    感觉corethink把thinkphp的思想复用到淋漓尽致. 1.把opencmf.php文件配置好了后台该模块的菜单就能在安装后自动读取(分析好父子关系,否则页面死循环,apache资源占用率10 ...

  5. python——异常

    一.什么是异常 1.错误 从软件方面来说,错误是语法或是逻辑上的.错误是语法或是逻辑上的. 语法错误指示软件的结构上有错误,导致不能被解释器解释或编译器无法编译.这些些错误必须在程序执行前纠正. 当程 ...

  6. 序列化+protobuff+redis

    背景: 当redis里面需要存储 “key-字符串,value-对象” 时,是不能直接存对象,而是需要将序列化后的对象存进redis. redis没有实现内部序列化对象的功能,所以需要自己提前序列化对 ...

  7. Linux服务器iops性能测试-fio

    FIO是测试IOPS的非常好的工具,用来对硬件进行压力测试和验证,支持13种不同的I/O引擎,包括:sync,mmap, libaio, posixaio, SG v3, splice, null, ...

  8. eclipse修改项目默认编码为UTF-8

    1.windows->Preferences...打开"首选项"对话框,左侧导航树,导航到general->Workspace,右侧 Text file encodin ...

  9. Qt如何重写虚函数

    eg:QWidget的有个虚函数,KeyPressEvent,当它的子类获得焦点的时候,如果有任何按键按下,就会触发这个虚函数. 1.在mainwindow.h中声明此虚函数 protected:vo ...

  10. java基础之bit、byte、char、String

    bit 位,二进制数据0或1 byte 字节,一个字节等于8位二进制数 char 字符, String 字符串,一串字符 常见转换 1 字母  = 1byte = 8 bit 1 汉字  = 2byt ...