用Pytorch训练线性回归模型
假定我们要拟合的线性方程是:\(y=2x+1\)
\(x\):[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
\(y\):[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
'''生成输入输出'''
x_values = [i for i in range(15)]
x_train = np.array(x_values, dtype=np.float32)
x_train = x_train.reshape(-1,1)
y_values = [2*i+1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1,1)
'''定义模型'''
class LinearRegressionModel(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearRegressionModel,self).__init__() #用nn.Module的init方法
self.linear = nn.Linear(input_dim, output_dim) #因为我们假设的函数是线性函数
def forward(self, x):
out = self.linear(x)
return out
''''''
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim, output_dim)
criterion = nn.MSELoss() #损失函数为均方差
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
'''训练网络'''
epochs = 30
for epoch in range(epochs):
epoch += 1
inputs = Variable(torch.from_numpy(x_train))
labels = Variable(torch.from_numpy(y_train))
#清空梯度参数
optimizer.zero_grad()
#获得输出
outputs = model(inputs)
#计算损失
loss = criterion(outputs, labels)
#反向传播
loss.backward()
#更新参数
optimizer.step()
print('epoch {}, loss {}'.format(epoch, loss.data[0]))
输出如下
epoch 1, loss 290.4517517089844
epoch 2, loss 39.308494567871094
epoch 3, loss 5.320824146270752
epoch 4, loss 0.721196711063385
epoch 5, loss 0.09870971739292145
epoch 6, loss 0.01445594523102045
epoch 7, loss 0.003041634801775217
epoch 8, loss 0.0014851536834612489
epoch 9, loss 0.0012628223048523068
epoch 10, loss 0.0012211636640131474
epoch 11, loss 0.0012040861183777452
epoch 12, loss 0.0011904657585546374
epoch 13, loss 0.001177445170469582
epoch 14, loss 0.0011646103812381625
epoch 15, loss 0.0011519324034452438
epoch 16, loss 0.0011393941240385175
epoch 17, loss 0.0011269855313003063
epoch 18, loss 0.0011147174518555403
epoch 19, loss 0.001102585345506668
epoch 20, loss 0.001090570935048163
epoch 21, loss 0.0010787042556330562
epoch 22, loss 0.0010669684270396829
epoch 23, loss 0.0010553498286753893
epoch 24, loss 0.001043855445459485
epoch 25, loss 0.0010324924951419234
epoch 26, loss 0.0010212488705292344
epoch 27, loss 0.0010101287625730038
epoch 28, loss 0.000999127165414393
epoch 29, loss 0.0009882354643195868
epoch 30, loss 0.0009774940554052591
#可以看出loss逐步缩小
画图观察
predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
plt.clf()
plt.plot(x_train, y_train, 'go', label="True Value", alpha=0.5)
plt.plot(x_train, predicted, '--', label='Predictions',alpha=0.5)
plt.legend(loc='best')
plt.show()
图如下:
用Pytorch训练线性回归模型的更多相关文章
- tensorflow训练线性回归模型
tensorflow安装 tensorflow安装过程不是很顺利,在这里记录一下 环境:Ubuntu 安装 sudo pip install tensorflow 如果出现错误 Could not f ...
- 1.1Tensorflow训练线性回归模型入门程序
tensorflow #-*- coding: utf-8 -*- # @Time : 2017/12/19 14:36 # @Author : Z # @Email : S # @File : 1. ...
- TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化
线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...
- 从头学pytorch(三) 线性回归
关于什么是线性回归,不多做介绍了.可以参考我以前的博客https://www.cnblogs.com/sdu20112013/p/10186516.html 实现线性回归 分为以下几个部分: 生成数据 ...
- 03_利用pytorch解决线性回归问题
03_利用pytorch解决线性回归问题 目录 一.引言 二.利用torch解决线性回归问题 2.1 定义x和y 2.2 自定制线性回归模型类 2.3 指定gpu或者cpu 2.4 设置参数 2.5 ...
- 【scikit-learn】scikit-learn的线性回归模型
内容概要 怎样使用pandas读入数据 怎样使用seaborn进行数据的可视化 scikit-learn的线性回归模型和用法 线性回归模型的评估測度 特征选择的方法 作为有监督学习,分类问题是预 ...
- R语言解读多元线性回归模型
转载:http://blog.fens.me/r-multi-linear-regression/ 前言 本文接上一篇R语言解读一元线性回归模型.在许多生活和工作的实际问题中,影响因变量的因素可能不止 ...
- PocketSphinx语音识别系统语言模型的训练和声学模型的改进
PocketSphinx语音识别系统语言模型的训练和声学模型的改进 zouxy09@qq.com http://blog.csdn.net/zouxy09 关于语音识别的基础知识和sphinx的知识, ...
- 深度学习入门实战(二)-用TensorFlow训练线性回归
欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...
随机推荐
- ABP大型项目实战(1) - 目录
前面我写了<如何用ABP框架快速完成项目>系列文章,讲述了如何用ABP快速完成项目. 然后我收到很多反馈,其中一个被经常问到的问题就是,“看了你的课程,发现ABP的优势是快速开发,那么 ...
- 小米Max 2获取ROOT超级权限的经验
小米Max 2有么好方法开通了root权限?大家都了解,安卓手机有root权限,如果手机开通了root相关权限,能够实现更完美的功能,打比方大家企业的营销部门的同事,使用某些营销工具都需要在root权 ...
- 使用docker部署skywalking
使用docker部署skywalking Intro 之前在本地搭建过一次 skywalking + elasticsearch ,但是想要迁移到别的机器上使用就很麻烦了,于是 docker 就成了很 ...
- WordCount结对项目
合作者:201631062124,201631062423 代码地址:https://gitee.com/yryx/WordCount 作业地址:https://edu.cnblogs.com/cam ...
- Vue双向绑定原理,教你一步一步实现双向绑定
当今前端天下以 Angular.React.vue 三足鼎立的局面,你不选择一个阵营基本上无法立足于前端,甚至是两个或者三个阵营都要选择,大势所趋. 所以我们要时刻保持好奇心,拥抱变化,只有在不断的变 ...
- php+qrcode类+生成二维码方法
//生成二维码 public function qrcode() { $data = input(); if(!$data['param']){ return json(['code ' => ...
- 本地系统服务例程:Nt和Zw系列函数
Windows本地操作系统服务API由一系列以Nt或Zw为前缀的函数实现的,这些函数以内核模式运行,内核驱动可以直接调用这些函数,而用户层程序只能通过系统进行调用.通常情况下用户层应用程序不会直接调用 ...
- Go语言学习笔记-流程控制(二)
Go语言流程控制 字典类型Map 1.上节遗留:map字典类型 变量声明:var myMap map[string] PersonInfo 其中,myMap是变量名,string是键的类型,Perso ...
- IIS出现The specified module could not be found的解决方法
1.打开IIS 信息服务,在左侧找到自己的计算机,点右键,选择属性,在主属性中选编辑,打开“目录安全性”选项卡,单击“匿名访问和验证控制”里的“编辑”按钮,在弹出的对话框中确保只选中了“匿名访问 ...
- Node、TS、Koa学习笔记
这样定义可以轻松拿到gender属性 这样定义,函数内显示没有gender 这种方法能得到gender但是函数内部没有gender 这种方式能到gender 但是在函数里施symbel属性,外部不能访 ...