python日记:用pytorch搭建一个简单的神经网络
最近在学习pytorch框架,给大家分享一个最最最最基本的用pytorch搭建神经网络并且训练的方法。本人是第一次写这种分享文章,希望对初学pytorch的朋友有所帮助!
一、任务
首先说下我们要搭建的网络要完成的学习任务: 让我们的神经网络学会逻辑异或运算,异或运算也就是俗称的“相同取0,不同取1” 。再把我们的需求说的简单一点,也就是我们需要搭建这样一个神经网络,让我们在输入(1,1)时输出0,输入(1,0)时输出1(相同取0,不同取1),以此类推。
二、实现思路
因为我们的需求需要有两个输入,一个输出,所以我们需要在输入层设置两个输入节点,输出层设置一个输出节点。因为问题比较简单,所以隐含层我们只需要设置10个节点就可以达到不错的效果了,隐含层的激活函数我们采用ReLU函数,输出层我们用Sigmoid函数,让输出保持在0到1的一个范围,如果输出大于0.5,即可让输出结果为1,小于0.5,让输出结果为0.
三、实现过程
搭建网络时,我们使用简单的快速搭建法,这种方法就像搭积木一样可以让你快速高效地搭建起一个神经网络来,具体实现如下:
第一步:引入必要的库
代码如下:
- import torch
- import torch.nn as nn
- import numpy as np
用pytorch当然要引入torch包,然后为了写代码方便将torch包里的nn用nn来代替,nn这个包就是neural network的缩写,专门用来搭神经网络的一个包。引入numpy是为了创建矩阵作为输入。
第二步:创建输入集
代码如下:
- # 构建输入集
- x = np.mat('0 0;'
- '0 1;'
- '1 0;'
- '1 1')
- x = torch.tensor(x).float()
- y = np.mat('1;'
- '0;'
- '0;'
- '')
- y = torch.tensor(y).float()
我个人比较喜欢用np.mat这种方式构建矩阵,感觉写法比较简单,当然你也可以用其他的方法。但是构建完矩阵一定要有这一步torch.tensor(x).float(),必须要把你所创建的输入转换成tensor变量。
什么是tensor呢?你可以简单地理解他就是pytorch中用的一种变量,你想用pytorch这个框架就必须先把你的变量转换成tensor变量。而我们这个神经网络会要求你的输入和输出必须是float浮点型的,指的是tensor变量中的浮点型,而你用np.mat创建的输入是int型的,转换成tensor也会自动地转换成itensor的int型,所以要在后面加个.float()转换成浮点型。
这样我们就构建完成了输入和输出(分别是x矩阵和y矩阵),x是四行二列的一个矩阵,他的每一行是一个输入,一次输入两个值,这里我们把所有的输入情况都列了出来。输出y是一个四行一列的矩阵,每一行都是一个输出,对应x矩阵每一行的输入。
第三步:搭建网络
代码如下:
- # 搭建网络
- myNet = nn.Sequential(
- nn.Linear(2, 10),
- nn.ReLU(),
- nn.Linear(10, 1),
- nn.Sigmoid()
- )
- print(myNet)
我们使用nn包中的Sequential搭建网络,这个函数就是那个可以让我们像搭积木一样搭神经网络的一个东西。
nn.Linear(2,10)的意思搭建输入层,里面的2代表输入节点个数,10代表输出节点个数。Linear也就是英文的线性,意思也就是这层不包括任何其它的激活函数,你输入了啥他就给你输出了啥 。nn.ReLU()这个就代表把一个激活函数层,把你刚才的输入扔到了ReLU函数中去。 接着又来了一个Linear,最后再扔到Sigmoid函数中去。 2,10,1就分别代表了三个层的个数,简单明了。
第四步:设置优化器
代码如下:
- # 设置优化器
- optimzer = torch.optim.SGD(myNet.parameters(), lr=0.05)
- loss_func = nn.MSELoss()
对这一步的理解就是,你需要有一个优化的方法来训练你的网络,所以这步设置了我们所要采用的优化方法。
torch.optim.SGD的意思就是采用SGD方法训练,你只需要吧你网络的参数和学习率传进去就可以了,分别是myNet.paramets和lr。 loss_func这句设置了代价函数,因为我们的这个问题比较简单,所以采用了MSE,也就是均方误差代价函数。
第五步:训练网络
代码如下:
- for epoch in range(5000):
- out = myNet(x)
- loss = loss_func(out, y) # 计算误差
- optimzer.zero_grad() # 清除梯度
- loss.backward()
- optimzer.step()
我这里设置了一个5000次的循环(可能不需要这么多次),让这个训练的动作迭代5000次。每一次的输出直接用myNet(x),把输入扔进你的网络就得到了输出out(就是这么简单粗暴!),然后用代价函数和你的标准输出y求误差。 清除梯度的那一步是为了每一次重新迭代时清除上一次所求出的梯度,你就把这一步记住就行,初学不用理解太深。 loss.backward()当然就是让误差反向传播,接着optimzer.step()也就是让我们刚刚设置的优化器开始工作。
第六步:测试
代码如下:
- print(myNet(x).data)
输出结果:
- tensor([[0.9424],
- [0.0406],
- [0.0400],
- [0.9590]])
可以看到这个结果已经非常接近我们期待的结果了,当然你也可以换个数据测试,结果也会是相似的。这里简单解释下为什么我们的代码末尾加上了一个.data,因为我们的tensor变量其实是包含两个部分的,一部分是tensor数据,另一部分是tensor的自动求导参数,我们加上.data意思是输出取tensor中的数据,如果不加的话会输出下面这样:
- tensor([[0.9492],
- [0.0502],
- [0.0757],
- [0.9351]], grad_fn=<SigmoidBackward>)
今天的分享就到这里,本人第一次发博客,写得不妥当的地方还望大家批评指正,不知道有没有帮到初学pytorch的你,完整代码如下:
- import torch
- import torch.nn as nn
- import numpy as np
- # 构建输入集
- x = np.mat('0 0;'
- '0 1;'
- '1 0;'
- '1 1')
- x = torch.tensor(x).float()
- y = np.mat('1;'
- '0;'
- '0;'
- '')
- y = torch.tensor(y).float()
- # 搭建网络
- myNet = nn.Sequential(
- nn.Linear(2, 10),
- nn.ReLU(),
- nn.Linear(10, 1),
- nn.Sigmoid()
- )
- print(myNet)
- # 设置优化器
- optimzer = torch.optim.SGD(myNet.parameters(), lr=0.05)
- loss_func = nn.MSELoss()
- for epoch in range(5000):
- out = myNet(x)
- loss = loss_func(out, y) # 计算误差
- optimzer.zero_grad() # 清除梯度
- loss.backward()
- optimzer.step()
- print(myNet(x).data)
python日记:用pytorch搭建一个简单的神经网络的更多相关文章
- pytorch定义一个简单的神经网络
刚学习pytorch,简单记录一下 """ test Funcition """ import torch from torch.autog ...
- Python实现一个简单三层神经网络的搭建并测试
python实现一个简单三层神经网络的搭建(有代码) 废话不多说了,直接步入正题,一个完整的神经网络一般由三层构成:输入层,隐藏层(可以有多层)和输出层.本文所构建的神经网络隐藏层只有一层.一个神经网 ...
- 用nodejs搭建一个简单的服务器
使用nodejs搭建一个简单的服务器 nodejs优点:性能高(读写文件) 数据操作能力强 官网:www.nodejs.org 验证是否安装成功:cmd命令行中输入node -v 如果显示版本号表示安 ...
- 完成一段简单的Python程序,用于实现一个简单的加减乘除计算器功能
#!/bin/usr/env python#coding=utf-8'''完成一段简单的Python程序,用于实现一个简单的加减乘除计算器功能'''try: a=int(raw_input(" ...
- 初学Node(六)搭建一个简单的服务器
搭建一个简单的服务器 通过下面的代码可以搭建一个简单的服务器: var http = require("http"); http.createServer(function(req ...
- 【netty】(2)---搭建一个简单服务器
netty(2)---搭建一个简单服务器 说明:本篇博客是基于学习慕课网有关视频教学.效果:当用户访问:localhost:8088 后 服务器返回 "hello netty"; ...
- 使用gitblit搭建一个简单的局域网服务器
使用gitblit搭建一个简单的局域网服务器 1.使用背景 现在很多使用github管理代码,但是github需要互联网的支持,而且私有的git库需要收费.有一些项目的代码不能外泄,所以,搭建一个局域 ...
- Golang学习-第二篇 搭建一个简单的Go Web服务器
序言 由于本人一直从事Web服务器端的程序开发,所以在学习Golang也想从Web这里开始学起,如果对Golang还不太清楚怎么搭建环境的朋友们可以参考我的上一篇文章 Golang的简单介绍及Wind ...
- python+selenium之自定义封装一个简单的Log类
python+selenium之自定义封装一个简单的Log类 一. 问题分析: 我们需要封装一个简单的日志类,主要有以下内容: 1. 生成的日志文件格式是 年月日时分秒.log 2. 生成的xxx.l ...
随机推荐
- 字符串之————图文讲解字符串排序(LSD、MSD)
本篇文章围绕字符串排序的核心思想,通过图示例子和代码分析的方式讲解了两个经典的字符串排序方法,内容很详细,完整代码放在文章的最后. 一.键索引计数法 在一般排序中,都要用里面的元素不断比较,而字符串这 ...
- query 与 params 使用
这个是路由: { path:'/city/:city', name:'City', component:City } 下面使用query和params分别传参 quer ...
- [工具][vim] vim设置显示行号
转载自:electrocrazy的博客 在linux环境下,vim是常用的代码查看和编辑工具.在程序编译出错时,一般会提示出错的行号,但是用vim打开的代码确不显示行号,错误语句的定位非常不便.那么怎 ...
- 单元测试框架Uinttest一文详解
一谈及unittest,大家都知道,unittest是Python中自带的单元测试框架,它里面封装好了一些校验返回的结果方法和一些用例执行前的初始化操作.unittest单元测试框架不仅可以适用于单元 ...
- Python学习-函数,函数参数,作用域
一.函数介绍 函数定义:函数时组织好的,可重复使用的,用来实现单一,或相关联功能的代码段. 我们已经知道python提供了许多内建函数,print(), type()等.我们也可以自己创建函数,这被叫 ...
- filebeat的@timestamp字段时区问题
最近使用filebeat进行日志采集,并通过logstash对日志进行格式化处理. filebeat采集数据后,会给日志增加字段@timestamp,@timestamp是UTC时间,查看日志很不方便 ...
- 【已解决】关于IDEA中 Tomcat 控制台打印日志中文乱码的解决
在 Idea 上面使用 Tomcat 时,发现控制台打印信息的时候,出行中文乱码问题; 可以通过以下几种解决办法 1:在-Dfile.encoding=UTF-8 在vm中设置编码方式 2.然后从Fi ...
- Spring Security 梳理 - DelegatingFilterProxy
可能你会觉得奇怪,我们在web应用中使用Spring Security时只在web.xml文件中定义了如下这样一个Filter,为什么你会说是一系列的Filter呢? <filter> ...
- 阿里云服务器CentOS6.9 tomcat配置https安全访问
应用场景 上线微信小程序的时候,域名要求https安全格式,否则获取数据异常. 第一步.SSL证书获取 获取SSL证书方式很多种,包括网页生成.工具生成等,这里我使用阿里云平台获取免费ssl证书的方法 ...
- 玩转 SpringBoot 2 之整合 JWT 下篇
前言 在<玩转 SpringBoot 2 之整合 JWT 上篇> 中介绍了关于 JWT 相关概念和JWT 基本使用的操作方式.本文为 SpringBoot 整合 JWT 的下篇,通过解决 ...