1、直接奉献代码,后期有入门更新,之前一直在学的是TensorFlow,

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np x_data = np.arange(-2*np.pi,2*np.pi,0.1).reshape(-1,1)
y_data = np.sin(x_data).reshape(-1,1) x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1) # 将1维的数据转换为2维数据
# y = x.pow(2) + 0.2 * torch.rand(x.size())
y = torch.cos(x)
# 将tensor置入Variable中
x, y = Variable(torch.from_numpy(x_data)).float(), Variable(torch.from_numpy(y_data)).float()
print(x.shape,y.shape) # plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show() # 定义一个构建神经网络的类
class Net(torch.nn.Module): # 继承torch.nn.Module类
def __init__(self):
super(Net, self).__init__() # 获得Net类的超类(父类)的构造方法
# 定义神经网络的每层结构形式
# 各个层的信息都是Net类对象的属性
self.hidden = torch.nn.Linear(1, 10) # 隐藏层线性输出
self.centre_1 = torch.nn.Linear(10,20)
self.predict = torch.nn.Linear(20, 1) # 输出层线性输出 # 将各层的神经元搭建成完整的神经网络的前向通路
def forward(self, x):
x = F.tanh(self.hidden(x)) # 对隐藏层的输出进行relu激活
x_1 = F.tanh(self.centre_1(x))
x =F.tanh(self.predict(x_1))
return x # 定义神经网络 net = Net()
print(net) # 打印输出net的结构 # 定义优化器和损失函数
optimizer = torch.optim.SGD(net.parameters(), lr=0.5) # 传入网络参数和学习率
loss_function = torch.nn.MSELoss() # 最小均方误差
acc = lambda y1,y2: np.sqrt(np.sum(y1**2+y2**2)/len(y1)) # 神经网络训练过程
plt.ion() # 动态学习过程展示
plt.show() for t in range(100):
prediction = net(x) # 把数据x喂给net,输出预测值
loss = loss_function(prediction, y) # 计算两者的误差,要注意两个参数的顺序
optimizer.zero_grad() # 清空上一步的更新参数值
loss.backward() # 误差反相传播,计算新的更新参数值
optimizer.step() # 将计算得到的更新值赋给net.parameters() # 可视化训练过程
if (t + 1) % 2 == 0:
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2)
aucc = acc(prediction.data.numpy(),y.data.numpy())
print("loss={} aucc={}".format(loss.data.numpy(),aucc))
plt.text(-4.5, 1,
'echo=%sL=%.4f acc=%s' % (t+1,loss.data.numpy(),aucc),
fontdict={'size': 15, 'color': 'red'})
plt.pause(0.1)
print("训练结束")
plt.ioff()
plt.show()

  

python pytorch numpy DNN 线性回归模型的更多相关文章

  1. Python机器学习/LinearRegression(线性回归模型)(附源码)

    LinearRegression(线性回归) 2019-02-20  20:25:47 1.线性回归简介 线性回归定义: 百科中解释 我个人的理解就是:线性回归算法就是一个使用线性函数作为模型框架($ ...

  2. 莫烦python教程学习笔记——线性回归模型的属性

    #调用查看线性回归的几个属性 # Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg # ...

  3. 02_利用numpy解决线性回归问题

    02_利用numpy解决线性回归问题 目录 一.引言 二.线性回归简单介绍 2.1 线性回归三要素 2.2 损失函数 2.3 梯度下降 三.解决线性回归问题的五个步骤 四.利用Numpy实战解决线性回 ...

  4. 线性回归模型(Linear Regression)及Python实现

    线性回归模型(Linear Regression)及Python实现 http://www.cnblogs.com/sumai 1.模型 对于一份数据,它有两个变量,分别是Petal.Width和Se ...

  5. 吴裕雄 python 机器学习——线性回归模型

    import numpy as np from sklearn import datasets,linear_model from sklearn.model_selection import tra ...

  6. 【scikit-learn】scikit-learn的线性回归模型

     内容概要 怎样使用pandas读入数据 怎样使用seaborn进行数据的可视化 scikit-learn的线性回归模型和用法 线性回归模型的评估測度 特征选择的方法 作为有监督学习,分类问题是预 ...

  7. scikit-learn的线性回归模型

    来自 http://blog.csdn.net/jasonding1354/article/details/46340729 内容概要 如何使用pandas读入数据 如何使用seaborn进行数据的可 ...

  8. 用C++调用tensorflow在python下训练好的模型(centos7)

    本文主要参考博客https://blog.csdn.net/luoyexuge/article/details/80399265 [1] bazel安装参考:https://blog.csdn.net ...

  9. [tensorflow] 线性回归模型实现

    在这一篇博客中大概讲一下用tensorflow如何实现一个简单的线性回归模型,其中就可能涉及到一些tensorflow的基本概念和操作,然后因为我只是入门了点tensorflow,所以我只能对部分代码 ...

随机推荐

  1. IView 使用Table组件时实现给某一列添加click事件

    通过给 columns 数据的项,设置一个函数 render,可以自定义渲染当前列,包括渲染自定义组件,它基于 Vue 的 Render 函数. render 函数传入两个参数,第一个是 h,第二个是 ...

  2. 使用Nginx的proxy_cache缓存功能取代Squid[原创]

    使用Nginx的proxy_cache缓存功能取代Squid[原创] [文章作者:张宴 本文版本:v1.2 最后修改:2009.01.12 转载请注明原文链接:http://blog.zyan.cc/ ...

  3. 2017 网易游戏互娱游戏研发4.21(offer)

    网易游戏互娱(offer) 去年这个时候就参加过网易游戏的实习生招聘,到今年总共收到了4次拒信.不过这次运气好,终于get了最想要的offer.去年实习生互娱笔试挂,秋招笔试挂,今年春招互娱投了连笔试 ...

  4. mvc api 关于 post 跟get 请求的一些想法[FromUri] 跟[FromBody] 同一个控制器如何实现共存

    wep api  在设置接收请求参数的时候,会自动根据模型进行解析. [FromUrl] 跟[FromBody] 不可以同时使用. 要拆分开: [HttpGet] public object GetP ...

  5. jdk,jre下载安装

    JDK安装https://blog.csdn.net/u012934325/article/details/73441617/jre需要手动生成在JDK安装目录下,的bin cmd执行bin\ jli ...

  6. Android_(服务)Vibrator振动器

    Vibrator振动器是Android给我们提供的用于机身震动的一个服务,例如当收到推送消息的时候我们可以设置震动提醒,也可以运用到游戏当中增强玩家互动性 运行截图: 程序结构 <?xml ve ...

  7. JDBC连接数据库遇到的“驱动程序无法通过使用安全套接字层(SSL)加密与 SQL Server 建立安全连接。

    要从旧算法列表中删除3DES: 在JDK 8及更早版本中,编辑该 /lib/security/java.security文件并3DES_EDE_CBC从jdk.tls.legacyAlgorithms ...

  8. LeetCode 300. 最长上升子序列(Longest Increasing Subsequence)

    题目描述 给出一个无序的整形数组,找到最长上升子序列的长度. 例如, 给出 [10, 9, 2, 5, 3, 7, 101, 18], 最长的上升子序列是 [2, 3, 7, 101],因此它的长度是 ...

  9. centos7 php5.5 mongodb安装

    1.下载最新php MongoDB扩展源码 https://pecl.php.net/package/mongodb 最新的1.6不支持PHP5.5,得用老版本,1.5.5 wget https:// ...

  10. JavaScript getClass() 函数

    定义和用法 getClass() 函数可返回一个 JavaObject 的 JavaClass. 语法 getClass(javaobj) 参数 描述 javaobj 一个 JavaObject 对象 ...