记得在tensorflow的入门里,介绍梯度下降算法的有效性时使用的例子求一个二次曲线的最小值。

这里使用pytorch复现如下:

1、手动计算导数,按照梯度下降计算

import torch

#使用梯度下降法求y=x^2+2x+1 最小值 从x=3开始

x=torch.Tensor([3])
for epoch in range(100):
y=x**2+2*x+1 x-=(2*x+2)*0.05 #导数是 2*x+2 print('min y={1:.2}, x={0:.2}'.format(x[0],y[0]))
min y=1.6e+01, x=2.6
min y=1.3e+01, x=2.2
min y=1e+01, x=1.9
...
min y=0.0, x=-1.0
min y=0.0, x=-1.0
min y=0.0, x=-1.0
min y=0.0, x=-1.0

2、使用torch的autograd计算

import torch
from torch.autograd import Variable
#使用梯度下降法求y=x^2+2x+1 最小值 从x=3开始 x=Variable(torch.Tensor([3]),requires_grad=True)
for epoch in range(100):
y=x**2+2*x+1
y.backward()
x.data-=x.grad.data*0.05
x.grad.data.zero_()
print('min y={1:.2}, x={0:.2}'.format(x.data[0],y.data[0])) min y=1.6e+01, x=2.6
min y=1.3e+01, x=2.2
min y=1e+01, x=1.9
...
min y=0.0, x=-1.0
min y=0.0, x=-1.0
min y=0.0, x=-1.0
min y=0.0, x=-1.0

下边来实验下使用梯度下降法求解直线回归问题,也就是最小二乘法的梯度下降求解(实际上回归问题的最优方式解 广义逆矩阵和值的乘积)

#最小二乘法 拟合y=3x+1
n=100
x=torch.rand((n))
y=x*3+1+torch.rand(n)/5 #y=3x+1
k=Variable(torch.Tensor([1]),requires_grad=True)
b=Variable(torch.Tensor([0]),requires_grad=True) for epoch in range(100):
l=torch.sum((k*x+b-y)**2)/100 #MSE 最小二乘法 加上随即噪声
l.backward()
k.data-=k.grad.data*0.3
b.data-=b.grad.data*0.3
print("k={:.2},b={:.2},l={:.2}".format(k.data[0],b.data[0],l.data))
k.grad.data.zero_()
b.grad.data.zero_()
k=1.7,b=1.3,l=4.7
k=1.9,b=1.5,l=0.37
k=2.0,b=1.6,l=0.11
k=2.1,b=1.6,l=0.088
k=2.1,b=1.6,l=0.081
k=2.1,b=1.6,l=0.075
...
k=3.0,b=1.1,l=0.0033
k=3.0,b=1.1,l=0.0033
k=3.0,b=1.1,l=0.0033

同样也可以使用torch里内置的mseloss

#最小二乘法 拟合y=3x+1
n=100
x=torch.rand((n))
y=x*3+1+torch.rand(n)/5 #y=3x+1 加上随机噪声
k=Variable(torch.Tensor([1]),requires_grad=True)
b=Variable(torch.Tensor([0]),requires_grad=True)
loss=torch.nn.MSELoss()
for epoch in range(100):
l=loss(k*x+b,y) #MSE 最小二乘法
l.backward()
k.data-=k.grad.data*0.3
b.data-=b.grad.data*0.3
print("k={:.2},b={:.2},l={:.2}".format(k.data[0],b.data[0],l.data))
k.grad.data.zero_()
b.grad.data.zero_()
k=1.7,b=1.3,l=4.7
k=1.9,b=1.6,l=0.35
k=2.0,b=1.6,l=0.09
...
k=2.9,b=1.1,l=0.0035
k=2.9,b=1.1,l=0.0035
k=2.9,b=1.1,l=0.0035
k=2.9,b=1.1,l=0.0035
k=2.9,b=1.1,l=0.0035
 

备注:新版本的torch里把torch.Variable 废除了,合并到torch.Tensor里了,好消息。数据类型统一了。原文:https://pytorch.org/docs/stable/autograd.html

Variable (deprecated)

The Variable API has been deprecated: Variables are no longer necessary to use autograd with tensors.

梯度下降与pytorch的更多相关文章

  1. 梯度下降优化算法综述与PyTorch实现源码剖析

    现代的机器学习系统均利用大量的数据,利用梯度下降算法或者相关的变体进行训练.传统上,最早出现的优化算法是SGD,之后又陆续出现了AdaGrad.RMSprop.ADAM等变体,那么这些算法之间又有哪些 ...

  2. [深度学习] pytorch学习笔记(2)(梯度、梯度下降、凸函数、鞍点、激活函数、Loss函数、交叉熵、Mnist分类实现、GPU)

    一.梯度 导数是对某个自变量求导,得到一个标量. 偏微分是在多元函数中对某一个自变量求偏导(将其他自变量看成常数). 梯度指对所有自变量分别求偏导,然后组合成一个向量,所以梯度是向量,有方向和大小. ...

  3. 梯度下降(Gradient Descent)小结

    在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent)是最常采用的方法之一,另一种常用的方法是最小二乘法.这里就对梯度下降法做一个完整的总结. 1. 梯度 在微 ...

  4. 从梯度下降到Fista

    前言: FISTA(A fast iterative shrinkage-thresholding algorithm)是一种快速的迭代阈值收缩算法(ISTA).FISTA和ISTA都是基于梯度下降的 ...

  5. 线性回归、梯度下降(Linear Regression、Gradient Descent)

    转载请注明出自BYRans博客:http://www.cnblogs.com/BYRans/ 实例 首先举个例子,假设我们有一个二手房交易记录的数据集,已知房屋面积.卧室数量和房屋的交易价格,如下表: ...

  6. 随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比、实现对比[转]

    梯度下降(GD)是最小化风险函数.损失函数的一种常用方法,随机梯度下降和批量梯度下降是两种迭代求解思路,下面从公式和实现的角度对两者进行分析,如有哪个方面写的不对,希望网友纠正. 下面的h(x)是要拟 ...

  7. 为什么是梯度下降?SGD

    在机器学习算法中,为了优化损失函数loss function ,我们往往采用梯度下降算法来进行优化.举个例子: 线性SVM的得分函数和损失函数分别为:                         ...

  8. Stanford大学机器学习公开课(二):监督学习应用与梯度下降

    本课内容: 1.线性回归 2.梯度下降 3.正规方程组   监督学习:告诉算法每个样本的正确答案,学习后的算法对新的输入也能输入正确的答案   1.线性回归 问题引入:假设有一房屋销售的数据如下: 引 ...

  9. Matlab梯度下降解决评分矩阵分解

    for iter = 1:num_iters %梯度下降 用户向量 for i = 1:m %返回有0有1 是逻辑值 ratedIndex1 = R_training(i,:)~=0 ; %U(i,: ...

随机推荐

  1. MySQL/MariaDB 版本选择

    ALPHA.BETA.Release Candidate(RC).Release.GA等版本号的意义 MySQL数据库会存在很多版本,在这么多的版本中,我们如何进行选择,那么,首先我们要了解各个版本号 ...

  2. python中字典常用的方法

    #定义一个空字典: a={ } 定义一个字典: d={'age':18} #增加一个元素: d['age']=20   d[k]=v d.setdefault('age',18)    d.setde ...

  3. 卷积与反卷积以及步长stride

    1. 卷积与反卷积 如上图演示了卷积核反卷积的过程,定义输入矩阵为 I(4×4),卷积核为 K(3×3),输出矩阵为 O(2×2): 卷积的过程为:Conv(I,W)=O 反卷积的过称为:Deconv ...

  4. Cracking The Coding Interview2.4

    删除前面的linklist,使用node来表示链表 // You have two numbers represented by a linked list, where each node cont ...

  5. Mysql text类型的最大长度

    MySQL 3种text类型的最大长度如下: TEXT 65,535 bytes ~64kb MEDIUMTEXT 16,777,215 bytes ~16Mb LONGTEXT 4,294,967, ...

  6. spring — jdbc 配置文件的设置

    ---参考配置,  链接mysql 数据库 <!-- 1.配置数据源 --><bean id="dataSource" class="org.sprin ...

  7. easyUI datagrid值转义

    数据库表里面字段的值想用另一种命名形式展示,如1是 是,2是 否     解决方法: 用到formatter ,{field:'params', title:'参数', width:100, sort ...

  8. ubuntu apt-get failed

    Err http://mirrors.163.com/ubuntu/ trusty/main libtinfo-dev i386 5.9+20140118-1ubuntu1 Could not res ...

  9. 使用rsync, 向另外一台服务器同步目录和文件的脚本

    #!/bin/bash #亚特兰蒂斯-同步目录#定时任务ini_file="/usr/local/sunlight/conf/rsync-file.ini"target_ip=&q ...

  10. golang图片裁剪和缩略图生成

    直接贴代码了 package main import ( "errors" "fmt" "image" "image/gif&qu ...