Linear Regression with PyTorch

Problem Description

初始化一组数据 \((x,y)\),使其满足这样的线性关系 \(y = w x + b\) 。然后基于反向传播法,用均方误差(mean squared error)

\[MSE = \frac{1}{n} \sum_{n} (y- \hat y)^{2}
\]

去拟合这组数据。

衡量两个分布之间的距离,最直接的方法是用交叉熵。

我们用最简单的一元变量去拟合这组数据,其实一元线性回归的表达式 \(y = wx + b\) 用神经网络的形式可表示成如下图所示

该神经网络有一个输入、一个输出、不使用任何激活函数。这就是一元线性回归的神经网络表示结果。相比较于下图这种神经网络的形式化表示,上图是一种简单的特例。

Key Points

torch.unsqueeze

重塑一个张量的 size,见下面代码

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1, 2, 3, 4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
[ 2],
[ 3],
[ 4]])

torch.linspace

得到一个在 start 和 end 之间等距的一维张量,见下面代码

>>> torch.linspace(1, 6, steps=3)
tensor([ 1.0000, 3.5000, 6.0000])

torch.rand

返回一个满足 size 维度要求的随机数组,随机数服从0-1均匀分布。

torch.nn.Linear(1,1)

self.prediction = torch.nn.Linear(1, 1)

这一行代码,实际是维护了两个变量,其描述了这样的一种关系:

\[prediction_{1\times1} = weight_{1\times1} \times input_{1\times1} + bias_{1\times1}
\]

其中,每个参数都是 \(1\times1\) 维的。

Code

import torch

epoch = 10000
lr = 0.01
w = 10
b = 5 x = torch.unsqueeze(torch.linspace(1, 10, 20), 1)
y = w*x + b + torch.rand(x.size()) class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.prediction = torch.nn.Linear(1, 1) def forward(self, x):
out = self.prediction(x)
return out net = Net()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criticism = torch.nn.MSELoss() for i in range(epoch):
y_pred = net(x)
loss = criticism(y_pred, y) # 先是 y_pred 然后是 y_true 参数顺序不能乱 optimizer.zero_grad()
loss.backward()
optimizer.step() print("%.5f" % loss.data)
print(net.state_dict()['prediction.weight'])
print(net.state_dict()['prediction.bias'])

输出:

0.08882
tensor([[ 9.9713]])
tensor([ 5.6524])

Results Analysis

输出显示:

  1. 均方误差(MSE)为 0.0882
  2. \(weight\) 的拟合结果为 9.9713
  3. \(bias\) 的拟合结果为 5.6524

分析:

  1. 因为我主动引入了误差(服从0-1均匀分布),而且是线性拟合,所以 MSE 几乎不能减小到零;
  2. 9.9713 的拟合值已经非常接近真实值 10 了;5.6524 的拟合值较真实值 5 的距离较大(距离约为自身的 10%)

Linear Regression with PyTorch的更多相关文章

  1. 线性回归、梯度下降(Linear Regression、Gradient Descent)

    转载请注明出自BYRans博客:http://www.cnblogs.com/BYRans/ 实例 首先举个例子,假设我们有一个二手房交易记录的数据集,已知房屋面积.卧室数量和房屋的交易价格,如下表: ...

  2. 局部加权回归、欠拟合、过拟合(Locally Weighted Linear Regression、Underfitting、Overfitting)

    欠拟合.过拟合 如下图中三个拟合模型.第一个是一个线性模型,对训练数据拟合不够好,损失函数取值较大.如图中第二个模型,如果我们在线性模型上加一个新特征项,拟合结果就会好一些.图中第三个是一个包含5阶多 ...

  3. Multivariance Linear Regression练习

    %% 方法一:梯度下降法 x = load('E:\workstation\data\ex3x.dat'); y = load('E:\workstation\data\ex3y.dat'); x = ...

  4. Kernel Methods (3) Kernel Linear Regression

    Linear Regression 线性回归应该算得上是最简单的一种机器学习算法了吧. 它的问题定义为: 给定训练数据集\(D\), 由\(m\)个二元组\(x_i, y_i\)组成, 其中: \(x ...

  5. Linear regression with multiple variables(多特征的线型回归)算法实例_梯度下降解法(Gradient DesentMulti)以及正规方程解法(Normal Equation)

    ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, ,, , ...

  6. Linear regression with one variable算法实例讲解(绘制图像,cost_Function ,Gradient Desent, 拟合曲线, 轮廓图绘制)_矩阵操作

    %测试数据 'ex1data1.txt', 第一列为 population of City in 10,000s, 第二列为 Profit in $10,000s 1 6.1101,17.592 5. ...

  7. Matlab实现线性回归和逻辑回归: Linear Regression & Logistic Regression

    原文:http://blog.csdn.net/abcjennifer/article/details/7732417 本文为Maching Learning 栏目补充内容,为上几章中所提到单参数线性 ...

  8. Stanford机器学习---第二讲. 多变量线性回归 Linear Regression with multiple variable

    原文:http://blog.csdn.net/abcjennifer/article/details/7700772 本栏目(Machine learning)包括单参数的线性回归.多参数的线性回归 ...

  9. Stanford机器学习---第一讲. Linear Regression with one variable

    原文:http://blog.csdn.net/abcjennifer/article/details/7691571 本栏目(Machine learning)包括单参数的线性回归.多参数的线性回归 ...

随机推荐

  1. CentOS7 开放服务端口

    CentOS 7 默认是firewall防火墙 如果你想让一个web服务可以被其它机子访问,就得开放这个服务的端口,不然就会被拦截 1. 开放端口命令 firewall-cmd --add-port= ...

  2. AsssetBunlder打包

    unity3d,资源过多的话.可以压缩成一个资源包.加载出来后.可以解压.找到自己需要的资源 就想.net网站.很多图标都是放一个大图片上.而不是一个图标就是一个图片 因为是在项目编辑时候给资源打包. ...

  3. MyBatis基础入门《十 一》修改数据

    MyBatis基础入门<十 一>修改数据 实体类: 接口类: xml文件: 测试类: 测试结果: 数据库: 如有问题,欢迎纠正!!! 如有转载,请标明源处:https://www.cnbl ...

  4. Ubuntu中使用pip3报错

    使用pip3 出现以下错误: Traceback (most recent call last): File “/usr/bin/pip3”, line 9, in from pip import m ...

  5. Browsersync结合gulp和nodemon实现express全栈自动刷新

    Browsersync能让浏览器实时.快速响应你的文件更改(html.js.css.sass.less等)并自动刷新页面.更重要的是 Browsersync可以同时在PC.平板.手机等设备下进项调试. ...

  6. CSS border-radius边框圆角

    在CSS3中提供了对边框进行圆角设定的支持,可对边框1~4个角进行圆角样式设置. 目录 1. 介绍 2. value值的格式和类型 3. border-radius 1~4个参数说明 4. 在线示例 ...

  7. 如何用vue组件做个机器人?有趣味的代码

      <!DOCTYPE html> <html lang="en"> <div>     <meta charset="UTF- ...

  8. 数据分析之Numpy库入门

    1.列表与数组 在python的基础语言部分,我们并没有介绍数组类型,但是像C.Java等语言都是有数组类型的,那python中的列表和数组有何区别呢? 一维数据:都表示一组数据的有序结构 区别: 列 ...

  9. 在统一软件开发过程中使用UML

    如何在统一软件开发过程中使用UML? 起始阶段常用UML图 在起始阶段,通常有用例图.类图.活动图.顺序图等UML图的参与. 获取用户需求之后首先要将这些需求转化为系统的顶层用例图. 在确定了用例之后 ...

  10. git分支流

    ## 新建一个iss1分支 $ git branch iss1 ## 切换到iss1分支 $ git checkout iss1 Switched to branch 'iss1' ## 查看分支,当 ...