pytorch之 regression
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # torch can only train on Variable, so convert them to Variable
# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x), Variable(y) # plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show() class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.predict = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.predict(x) # linear output
return x net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
print(net) # net architecture optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss() # this is for regression mean squared loss plt.ion() # something about plotting for t in range(200):
prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # must be (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1) plt.ioff()
plt.show()
pytorch之 regression的更多相关文章
- pytorch 4 regression 回归
import torch import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) ...
- Linear Regression with PyTorch
Linear Regression with PyTorch Problem Description 初始化一组数据 \((x,y)\),使其满足这样的线性关系 \(y = w x + b\) .然后 ...
- Task3.PyTorch实现Logistic regression
1.PyTorch基础实现代码 import torch from torch.autograd import Variable torch.manual_seed(2) x_data = Varia ...
- pytorch之 RNN regression
关于RNN模型参数的解释,可以参看RNN参数解释 1 import torch from torch import nn import numpy as np import matplotlib.py ...
- (转)Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph.
Extracting knowledge from knowledge graphs using Facebook Pytorch BigGraph 2019-04-27 09:33:58 This ...
- PyTorch(一)Basics
PyTorch Basics import torch import torchvision import torch.nn as nn import numpy as np import torch ...
- (转) The Incredible PyTorch
转自:https://github.com/ritchieng/the-incredible-pytorch The Incredible PyTorch What is this? This is ...
- Pytorch自定义dataloader以及在迭代过程中返回image的name
pytorch官方给的加载数据的方式是已经定义好的dataset以及loader,如何加载自己本地的图片以及label? 形如数据格式为 image1 label1 image2 label2 ... ...
- 【深度学习】Pytorch 学习笔记
目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...
随机推荐
- 前端Tips#4 - 用 process.hrtime 获取纳秒级的计时精度
本文同步自 JSCON简时空 - 前端Tips 专栏#4,点击阅读 视频讲解 视频地址 文字讲解 如果去测试代码运行的时长,你会选择哪个时间函数? 一般第一时间想到的函数是 Date.now 或 Da ...
- BZOJ4559&P3270[JLoi2016]成绩比较
题目描述 \(G\)系共有\(n\)位同学,\(M\)门必修课.这\(N\)位同学的编号为\(0\)到\(N-1\)的整数,其中\(B\)神的编号为\(0\)号.这\(M\)门必修课编号为\(0\)到 ...
- 2018徐州现场赛A
题目链接:http://codeforces.com/gym/102012/problem/A 题目给出的算法跑出的数据是真的水 #include<iostream> #include&l ...
- 使用远程接口库进一步扩展Robot Framework的测试能力
引言: Robot Framework的四层结构已经极大的提高了它的扩展性.我们可以使用它丰富的扩展库来完成大部分测试工作.可是碰到下面两种情况,仅靠四层结构就不好使了: 1.有些复杂的测试可能跨越多 ...
- Scrapy信号量
1.类 from scrapy import signals class MySingle(object): def __init__(self): pass @classmethod def fro ...
- 前端开发利器 Web Replay
前端开发人员收到测试发来的 bug 后,通常比较头疼复现的问题. 即使测试人员录了视频,照着一步步操作也不一定能复现,例如bug是与当时的数据相关的. 为了解决这个问题,Firefox 推出了一个重磅 ...
- c语言-输出圆形
#include<stdio.h> #define high 100//定义界面大小 #define wide 100 void Circle(int ridus) //确定坐标 {int ...
- CentOS7安装MySQL、Tomcat和GitBlit记录
一.安装MySQL 1.安装这个发布包 yum localinstall mysql-community-release-el6-5.noarch.rpm 可以通过下面的命令来确认这个仓库被成功添加: ...
- 数据库及ORM之Mysql
1. 数据库介绍 1.1什么是数据库? 数据库(Database)是按照数据结构来组织.存储和管理数据的仓库,每个数据库都有一个或多个不同的API用于创建,访问,管理,搜索和复制所保存的数据.我们也可 ...
- CentOS7下部署2套Python版本共存
参考地址:https://www.cnblogs.com/xuaijun/p/7985245.html 源码的安装一般由3个步骤组成:配置(configure).编译(make).安装(make in ...