import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # torch can only train on Variable, so convert them to Variable
# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x), Variable(y) # plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.predict = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.predict(x) # linear output
return x net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
print(net) # net architecture optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss plt.ion() # something about plotting for t in range(200):
prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # must be (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()

pytorch之 regression的更多相关文章

  1. pytorch 4 regression 回归

    import torch import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) ...

  2. Linear Regression with PyTorch

    Linear Regression with PyTorch Problem Description 初始化一组数据 \((x,y)\),使其满足这样的线性关系 \(y = w x + b\) .然后 ...

  3. Task3.PyTorch实现Logistic regression

    1.PyTorch基础实现代码 import torch from torch.autograd import Variable torch.manual_seed(2) x_data = Varia ...

  4. pytorch之 RNN regression

    关于RNN模型参数的解释,可以参看RNN参数解释 1 import torch from torch import nn import numpy as np import matplotlib.py ...

  5. (转)Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph.

    Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph 2019-04-27 09:33:58 This ...

  6. PyTorch(一)Basics

    PyTorch Basics import torch import torchvision import torch.nn as nn import numpy as np import torch ...

  7. (转) The Incredible PyTorch

    转自:https://github.com/ritchieng/the-incredible-pytorch The Incredible PyTorch What is this? This is ...

  8. Pytorch自定义dataloader以及在迭代过程中返回image的name

    pytorch官方给的加载数据的方式是已经定义好的dataset以及loader,如何加载自己本地的图片以及label? 形如数据格式为 image1 label1 image2 label2 ... ...

  9. 【深度学习】Pytorch 学习笔记

    目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...

随机推荐

  1. 前端Tips#4 - 用 process.hrtime 获取纳秒级的计时精度

    本文同步自 JSCON简时空 - 前端Tips 专栏#4,点击阅读 视频讲解 视频地址 文字讲解 如果去测试代码运行的时长,你会选择哪个时间函数? 一般第一时间想到的函数是 Date.now 或 Da ...

  2. BZOJ4559&P3270[JLoi2016]成绩比较

    题目描述 \(G\)系共有\(n\)位同学,\(M\)门必修课.这\(N\)位同学的编号为\(0\)到\(N-1\)的整数,其中\(B\)神的编号为\(0\)号.这\(M\)门必修课编号为\(0\)到 ...

  3. 2018徐州现场赛A

    题目链接:http://codeforces.com/gym/102012/problem/A 题目给出的算法跑出的数据是真的水 #include<iostream> #include&l ...

  4. 使用远程接口库进一步扩展Robot Framework的测试能力

    引言: Robot Framework的四层结构已经极大的提高了它的扩展性.我们可以使用它丰富的扩展库来完成大部分测试工作.可是碰到下面两种情况,仅靠四层结构就不好使了: 1.有些复杂的测试可能跨越多 ...

  5. Scrapy信号量

    1.类 from scrapy import signals class MySingle(object): def __init__(self): pass @classmethod def fro ...

  6. 前端开发利器 Web Replay

    前端开发人员收到测试发来的 bug 后,通常比较头疼复现的问题. 即使测试人员录了视频,照着一步步操作也不一定能复现,例如bug是与当时的数据相关的. 为了解决这个问题,Firefox 推出了一个重磅 ...

  7. c语言-输出圆形

    #include<stdio.h> #define high 100//定义界面大小 #define wide 100 void Circle(int ridus) //确定坐标 {int ...

  8. CentOS7安装MySQL、Tomcat和GitBlit记录

    一.安装MySQL 1.安装这个发布包 yum localinstall mysql-community-release-el6-5.noarch.rpm 可以通过下面的命令来确认这个仓库被成功添加: ...

  9. 数据库及ORM之Mysql

    1. 数据库介绍 1.1什么是数据库? 数据库(Database)是按照数据结构来组织.存储和管理数据的仓库,每个数据库都有一个或多个不同的API用于创建,访问,管理,搜索和复制所保存的数据.我们也可 ...

  10. CentOS7下部署2套Python版本共存

    参考地址:https://www.cnblogs.com/xuaijun/p/7985245.html 源码的安装一般由3个步骤组成:配置(configure).编译(make).安装(make in ...