计算图和autograd是十分强大的工具,可以定义复杂的操作并自动求导;然而对于大规模的网络,autograd太过于底层。

在构建神经网络时,我们经常考虑将计算安排成,其中一些具有可学习的参数,它们将在学习过程中进行优化。

TensorFlow里,有类似KerasTensorFlow-SlimTFLearn这种封装了底层计算图的高度抽象的接口,这使得构建网络十分方便。

在PyTorch中,包nn完成了同样的功能。nn包中定义一组大致等价于层的模块。一个模块接受输入的tesnor,计算输出的tensor,而且还保存了一些内部状态比如需要学习的tensor的参数等。nn包中也定义了一组损失函数(loss functions),用来训练神经网络。同时nn包中不光有一些激活函数和层操作外,还包含常见的损失函数。

代码如下:

  1. import torch
  2.  
  3. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  4.  
  5. N, D_in, H, D_out = 64, 1000, 100, 10
  6.  
  7. #随机生成输入和输出
  8.  
  9. x = torch.randn(N, D_in, device=device)
    y = torch.randn(N, D_out, device=device)
  10.  
  11. # 使用nn包将我们的模型定义为一系列的层。
    # nn.Sequential是包含其他模块的模块,并按顺序应用这些模块来产生其输出。
    # 每个线性模块使用线性函数从输入计算输出,并保存其内部的权重和偏差张量。
    # 在构造模型之后,我们使用.to()方法将其移动到所需的设备。
  12.  
  13. model = torch.nn.Sequential(
    torch.nn.Linear(D_in,H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
    ).to(device)
  14.  
  15. '''
    nn包中还有常用的损失函数的定义
    MSELoss()中参数reducetion 初始为'mean',为均值,我们使用的是'sum'为和
    但是在实践中,通过设置reduction='elementwise_mean'来使用均方误差作为损失更为常见
    '''
    loss_fn = torch.nn.MSELoss(reduction='elementwise_mean')
  16.  
  17. learning_rate = 1e-4
  18.  
  19. for t in range(500):
  20.  
  21. '''
    该操作为前向传播,通过向模型中传入x,进而得到输出y
    同时该模块有__call__属性可以像调用函数一样调用他们
    这样我们输入张量x,得到了输出张量y_pred
    '''
    y_pred = model(x)
    loss = loss_fn(y_pred,y)
    print(t,loss.item())
  22.  
  23. #运算之前清除梯度
    model.zero_grad()
  24.  
  25. '''
    反向传播:计算模型的损失值对模型中可训练参数的梯度
    每个参数是否可训练取决于require_grad的布尔值
    所以此操作可以计算所有可训练参数的梯度
    '''
    loss.backward()
  26.  
  27. #使用梯度下降进行更新
    #利用for循环取出model中的parameters()
    #在对param.data进行操作
    with torch.no_grad():
    for param in model.parameters():
    param.data -= learning_rate * param.grad

Pytorch 初次使用nn包的更多相关文章

  1. PyTorch 中,nn 与 nn.functional 有什么区别?

    作者:infiniteft链接:https://www.zhihu.com/question/66782101/answer/579393790来源:知乎著作权归作者所有.商业转载请联系作者获得授权, ...

  2. pytorch中torch.nn构建神经网络的不同层的含义

    主要是参考这里,写的很好PyTorch 入门实战(四)--利用Torch.nn构建卷积神经网络 卷积层nn.Con2d() 常用参数 in_channels:输入通道数 out_channels:输出 ...

  3. [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList

    1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...

  4. pytorch中的nn.CrossEntropyLoss()

    nn.CrossEntropyLoss()这个损失函数和我们普通说的交叉熵还是有些区别 x是模型生成的结果,class是对应的label 具体代码可参见如下 import torch import t ...

  5. Pytorch并行计算:nn.parallel.replicate, scatter, gather, parallel_apply

    import torch import torch.nn as nn import ipdb class DataParallelModel(nn.Module): def __init__(self ...

  6. pytorch函数之nn.Linear

    class torch.nn.Linear(in_features,out_features,bias = True )[来源] 对传入数据应用线性变换:y = A x+ b 参数: in_featu ...

  7. pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)

    一.BCELoss 二分类损失函数 输入维度为(n, ), 输出维度为(n, ) 如果说要预测二分类值为1的概率,则建议用该函数! 输入比如是3维,则每一个应该是在0--1区间内(随意通常配合sigm ...

  8. 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化

    上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...

  9. [PyTorch入门]之从示例中学习PyTorch

    Learning PyTorch with examples 来自这里. 本教程通过自包含的示例来介绍PyTorch的基本概念. PyTorch的核心是两个主要功能: 可在GPU上运行的,类似于num ...

随机推荐

  1. 基于SILVACO ATLAS的a-IGZO薄膜晶体管二维器件仿真(03)

    今天逛ResearchGate的时候发现了一个不错的Atlas入门教程:Step by step with ATLAS Silvaco点击链接免费下载.. Atlas代码结构 当然可能有一点太基础了. ...

  2. MySQL的多表查询学习笔记

    一.案例准备 create table dept( id int primary key auto_increment, name ) ); insert into dept values(null, ...

  3. C语言:根据以下公式计算s,s=1+1/(1+2)+1/(1+2+3)+...+1/(1+2+3+...+n) -在形参s所指字符串中寻找与参数c相同的字符,并在其后插入一个与之相同的字符,

    //根据一下公式计算s,并将计算结果作为函数返回值,n通过形参传入.s=1+1/(1+2)+1/(1+2+3)+...+1/(1+2+3+...+n) #include <stdio.h> ...

  4. $ git push -u origin master

    我们第一次推送master分支时,由于远程库是空的,加上了-u参数,Git不但会把本地的master分支内容推送的远程新的master分支,还会把本地的master分支和远程的master分支关联起来 ...

  5. Try-Catch无法正确定位异常位置,我推荐2个有效技巧

    宇宙第一开发IDE Visual Studio的调试功能非常强大,平常工作debug帮助我们解决不少问题.今天分享两个异常捕获的技巧,希望能够帮助解决一些问题. 以下两种情况,我相信大家都会遇到过. ...

  6. hbase(待完善)

    1. 应用 <1>  hbase解决海量图片存储 <2>

  7. 【原】rsync使用

    在使用jenkins当跳板机的场景下,有使用git pull 代码到jenkins机器后,需要将代码复制到另一台机器上,常用的复制命令有scp和rsync:现就使用到了rsync进行详解: rsync ...

  8. 【Html 页面布局】

    float:left方式布局 <!DOCTYPE html> <html> <head> <meta charset="utf-8" /& ...

  9. Linux kali安装或更新之后出现乱码

    打开终端,输入以下命令,之后重启. apt-get install ttf-wqy-zenhei

  10. maven搭建ssm 完整过程

    https://blog.csdn.net/qq_28008917/article/details/79755935