文章目录:

1 任务

首先说下我们要搭建的网络要完成的学习任务:

让我们的神经网络学会逻辑异或运算,异或运算也就是俗称的“相同取0,不同取1” 。再把我们的需求说的简单一点,也就是我们需要搭建这样一个神经网络,让我们在输入(1,1)时输出0,输入(1,0)时输出1(相同取0,不同取1),以此类推。

2 实现思路

因为我们的需求需要有两个输入,一个输出,所以我们需要在输入层设置两个输入节点,输出层设置一个输出节点。因为问题比较简单,所以隐含层我们只需要设置10个节点就可以达到不错的效果了,隐含层的激活函数我们采用ReLU函数,输出层我们用Sigmoid函数,让输出保持在0到1的一个范围,如果输出大于0.5,即可让输出结果为1,小于0.5,让输出结果为0.

3 实现过程

我们使用的简单的快速搭建法。

3.1 引入必要库

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np

用pytorch当然要引入torch包,然后为了写代码方便将torch包里的nn用nn来代替,nn这个包就是neural network的缩写,专门用来搭神经网络的一个包。引入numpy是为了创建矩阵作为输入。

3.2 创建训练集

  1. # 构建输入集
  2. x = np.mat('0 0;'
  3. '0 1;'
  4. '1 0;'
  5. '1 1')
  6. x = torch.tensor(x).float()
  7. y = np.mat('1;'
  8. '0;'
  9. '0;'
  10. '1')
  11. y = torch.tensor(y).float()

我个人比较喜欢用np.mat这种方式构建矩阵,感觉写法比较简单,当然你也可以用其他的方法。但是构建完矩阵一定要有这一步torch.tensor(x).float(),必须要把你所创建的输入转换成tensor变量。

什么是tensor呢?你可以简单地理解他就是pytorch中用的一种变量,你想用pytorch这个框架就必须先把你的变量转换成tensor变量。而我们这个神经网络会要求你的输入和输出必须是float浮点型的,指的是tensor变量中的浮点型,而你用np.mat创建的输入是int型的,转换成tensor也会自动地转换成tensor的int型,所以要在后面加个.float()转换成浮点型。

这样我们就构建完成了输入和输出(分别是x矩阵和y矩阵),x是四行二列的一个矩阵,他的每一行是一个输入,一次输入两个值,这里我们把所有的输入情况都列了出来。输出y是一个四行一列的矩阵,每一行都是一个输出,对应x矩阵每一行的输入。

3.3 搭建网络

  1. myNet = nn.Sequential(
  2. nn.Linear(2,10),
  3. nn.ReLU(),
  4. nn.Linear(10,1),
  5. nn.Sigmoid()
  6. )
  7. print(myNet)

输出结果:

我们使用nn包中的Sequential搭建网络,这个函数就是那个可以让我们像搭积木一样搭神经网络的一个东西。

nn.Linear(2,10)的意思搭建输入层,里面的2代表输入节点个数,10代表输出节点个数。Linear也就是英文的线性,意思也就是这层不包括任何其它的激活函数,你输入了啥他就给你输出了啥。nn.ReLU()这个就代表把一个激活函数层,把你刚才的输入扔到了ReLU函数中去。 接着又来了一个Linear,最后再扔到Sigmoid函数中去。 2,10,1就分别代表了三个层的个数,简单明了。

3.4 设置优化器

  1. optimzer = torch.optim.SGD(myNet.parameters(),lr=0.05)
  2. loss_func = nn.MSELoss()

对这一步的理解就是,你需要有一个优化的方法来训练你的网络,所以这步设置了我们所要采用的优化方法。

torch.optim.SGD的意思就是采用SGD(随机梯度下降)方法训练,你只需要把你网络的参数和学习率传进去就可以了,分别是myNet.parametslrloss_func这句设置了代价函数,因为我们的这个问题比较简单,所以采用了MSE,也就是均方误差代价函数。

3.5 训练网络

  1. for epoch in range(5000):
  2. out = myNet(x)
  3. loss = loss_func(out,y)
  4. optimzer.zero_grad()
  5. loss.backward()
  6. optimzer.step()

我这里设置了一个5000次的循环(可能不需要这么多次),让这个训练的动作迭代5000次。每一次的输出直接用myNet(x),把输入扔进你的网络就得到了输出out(就是这么简单粗暴!),然后用代价函数和你的标准输出y求误差。 清除梯度的那一步是为了每一次重新迭代时清除上一次所求出的梯度,你就把这一步记住就行,初学不用理解太深。 loss.backward()当然就是让误差反向传播,接着optimzer.step()也就是让我们刚刚设置的优化器开始工作。

3.6 测试

  1. print(myNet(x).data)

运行结果:

可以看到这个结果已经非常接近我们期待的结果了,当然你也可以换个数据测试,结果也会是相似的。这里简单解释下为什么我们的代码末尾加上了一个.data,因为我们的tensor变量其实是包含两个部分的,一部分是tensor数据,另一部分是tensor的自动求导参数,我们加上.data意思是输出取tensor中的数据,如果不加的话会输出下面这样:

【小白学PyTorch】1 搭建一个超简单的网络的更多相关文章

  1. 搭建一个超好用的 cmdb 系统

    10 分钟为你搭建一个超好用的 cmdb 系统 CMDB 是什么,作为 IT 工程师的你想必已经听说过了,或者已经烂熟了,容我再介绍一下,以防有读者还不知道.CMDB 的全称是 Configurati ...

  2. 【小白学PyTorch】20 TF2的eager模式与求导

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  3. 搭建一个最简单的node服务器

    搭建一个最简单的node服务器 1.创建一个Http服务并监听8888端口 2.使用url模块 获取请求的路由和请求参数 var http = require('http'); var url = r ...

  4. Dubbo入门介绍---搭建一个最简单的Demo框架

    Dubbo入门---搭建一个最简单的Demo框架 置顶 2017年04月17日 19:10:44 是Guava不是瓜娃 阅读数:320947 标签: dubbozookeeper 更多 个人分类: D ...

  5. 【小白学PyTorch】15 TF2实现一个简单的服装分类任务

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  6. 打造支持apk下载和html5缓存的 IIS(配合一个超简单的android APP使用)具体解释

    为什么要做这个看起来不靠谱的东西呢? 由于刚学android开发,还不能非常好的熟练控制android界面的编辑和操作,所以我的一个急着要的运用就改为html5版本号了,反正这个运用也是须要从serv ...

  7. 【小白学PyTorch】18 TF2构建自定义模型

    [机器学习炼丹术]的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617 参考目录: 目录 1 创建自定义网络层 2 创建一个完整的CNN 2.1 keras.Model vs ke ...

  8. linux搭建一个配置简单的nginx反向代理服务器 2个tomcat

    1.我们只要实现访问nginx服务器能跳转到不同的服务器即可,我本地测试是这样的, 在nginx服务器里面搭建了2个tomcat,2个tomcat端口分别是8080和8081,当我输入我nginx服务 ...

  9. Dubbo入门—搭建一个最简单的Demo框架

    一.Dubbo背景和简介 1.电商系统的演进 Dubbo开始于电商系统,因此在这里先从电商系统的演变讲起. a.单一应用框架(ORM) 当网站流量很小时,只需一个应用,将所有功能如下单支付等都部署在一 ...

随机推荐

  1. luogu P1712 [NOI2016]区间 贪心 尺取法 线段树 二分

    LINK:区间 没想到尺取法. 先说暴力 可以发现答案一定可以转换到端点处 所以在每个端点从小到大扫描线段就能得到答案 复杂度\(n\cdot m\) 再说我的做法 想到了二分 可以进行二分答案 从左 ...

  2. jmeter如何设置全局变量

    场景:性能测试或者接口测试,如果想跨线程引用(案例:A线程组里面的一个输出,是B线程组里面的一个输入,这个时候如果要引用),这个时候你就必须要设置全局变量;全链路压测也需要分不同场景,通常情况,一个场 ...

  3. NCoreCoder.Aop详解

    于今天,功能终于完善度到比较满意的程度了 准备好好写一篇文章,而不是之前的流水账,分享一下最近这些天的踩坑 一开始AOP选的微软提供的DispatchProxy 关于这个,有大佬的文章,可以看看,了解 ...

  4. 实验08——java百文百鸡

    package cn.tedu.demo; /** * @author 赵瑞鑫 E-mail:1922250303@qq.com * @version 1.0 * @创建时间:2020年7月17日 下 ...

  5. 使用opencv在Qt控件上播放mp4文件

    文章目录 简介 核心代码 运行结果 简介 opencv是一个开源计算机视觉库,功能非常多,这里简单介绍一下OpenCV解码播放Mp4文件,并将图像显示到Qt的QLabel上面. 核心代码 头文件 #i ...

  6. qt中使用dll库的方法

    使用dll文件时首先通过dll文件导出符号表,如下面介绍 1. 制作def 直接调用 pexports mylib.dll > mylib.def 2. 生成a 需要mylib.dll和myli ...

  7. CI4框架应用二 - 项目目录

    我们之前搭建好了CI4的开发环境,下面我们来看一下CI4的目录结构. Administrator@PC- MINGW64 /c/wamp64/www/ci4 $ ls -l total drwxr-x ...

  8. sql server 存储过程的(包含事务)方法里面,采用游标循环,批量删除(修改)数据

    sqlserver 数据库 1.下面是完整的 在存储过程中 使用游标进行 循环删除的实例(包括存储过程中,事务的应用) 2.有问题的话,欢迎随时讨饶我,相信大家看下注释应该就能明白了,很简单的一个,小 ...

  9. Python 错误 异常

    8 错误,调试和测试 8.1错误处理 所有的异常来自 BaseException 记录错误 : # err_logging.py import logging def foo(s): return 1 ...

  10. python3.x与2.x中print输出不换行

    python3.x: print(i,end=' ') 循环输出: ... ------------------------- print(i,end='!') 循环输出:!!!... end=单引号 ...