一、Tensor

Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组。Tensor可以是一个标量、一维数组(向量)、二维数组(矩阵)或者高维数组等。Tensor和numpy的ndarrays相似。

import torch as t

构建矩阵:x = t.Tensor(m, n)

注意这种情况下只分配了空间,并没有初始化。

使用[0,1]均匀分布随机初始化矩阵:x = t.rand(m, n)

查看x的形状:x.size()

加法:

(1)x + y

(2)t.add(x, y)

(3)t.add(x, y, out = res)

(4)y.add(x) #不改变y的内容

(5)y.add_(x) #改变y的内容

注意,函数名后面带下划线_的函数会修改Tensor本身,而x.add(y)等则会返回一个新的Tensor,而x不变。

Tensor可以进行数学运算、线性代数、选择、切片等。而且Tensor与numpy的数组间互操作十分方便。对于Tensor不支持的操作可以先转为numpy处理,再转为Tensor。

  1. a = t.ones(5)
  2. b = a.numpy() # Tensor -> Numpy
  3.  
  4. a = np.ones(5)
  5. b = t.from_numpy(a) # Numpy -> Tensor

Tensor和numpy的对象是共享内存的,如果其中一个改变,那么另一个也会随之改变。

Tensor可以通过.cuda方法转为GPU的Tensor。GPU加速:

  1. if t.cuda.is_available():
  2. x = x.cuda()
  3. y = y.cuda()
  4. x + y

二、Autograd:自动微分

Pytorch的Autograd模块实现了求导功能,在Tensor上的所有操作,Autograd都能自动为它们提供微分。

autograd.Variable是Autograd的核心类,简单封装了Tensor,并几乎支持所有的Tensor操作。当Tensor被封装为Variable后,可以通过调用.backward实现反向传播,自动计算所有的梯度。

Variable包括三个属性:

(1)data:保存Variable包含的Tensor;

(2)grad:保存data对应的梯度,grad也是个Variable;

(3)grad_fn:指向一个Function对象,该Function用于反向传播计算输入的梯度。

  1. from torch.autograd import Variable
  2. x = Variable(t.ones(2, 2), requires_grad = True)
  3. y = x.sum()
  4. print(y.grad_fn)
  5. y.backward()
  6. print(x.grad)
  7. y.backward()
  8. print(x.grad)

注意,grad在反向传播中是累加的,也就是说每次运行反向传播时梯度都会累加之前的梯度,因此需要反向传播之前要把梯度清零。

  1. x.grad.data.zero_() # inplace操作,把x的data对应的梯度值清零
  2.  
  3. VariableTensor的转换:
  4. x = Variable(t.ones(4, 5))
  5. y = t.cos(x)
  6. x_tensor_cos = t.cos(x.data) # Variable x的data的cosine对应的Tensor

三、神经网络

torch.nn是专门为神经网络设计的模块化接口。nn.Module可以看做是一个网络的封装,包括网络各层的定义以及forward方法,调用forward(input)方法,可返回前向传播的结果。

LeNet网络结构如下图:

定义网络时,要继承nn.Module,并实现它的forward方法,把网络中具有可学习参数的层放在构造函数__init__中。如果某层不具有可学习的参数,则既可以放在构造函数中,也可以不放。

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3.  
  4. class Net(nn.Module):
  5. def __init__(self):
  6. # nn.Module子类的函数必须在构造函数中执行父类的构造函数
  7. # 下式等价于nn.Module.__init__(self)
  8. super(Net, self).__init__()
  9. # '1'表示输入图像为单通道,'6'表示输出通道数
  10. # '5'表示卷积核大小为5*5
  11. self.conv1 = nn.Conv2d(1, 6, 5)
  12. # 卷积层
  13. self.conv2 = nn.Conv2d(6, 16, 5)
  14. # 仿射层/全连接层,y = Wx + b
  15. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  16. self.fc2 = nn.Linear(120, 84)
  17. self.fc3 = nn.Linear(84, 10)
  18.  
  19. def forward(self, x):
  20. # 卷积 -> 激活 -> 池化
  21. x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
  22. x = F.max_pool2d(F.relu(self.conv2(x)), 2)
  23. # reshape, '-1'表示自适应
  24. x = x.view(x.size()[0], -1)
  25. x = F.relu(self.fc1(x))
  26. x = F.relu(self.fc2(x))
  27. x = self.fc3(x)
  28.  
  29. return x
  30.  
  31. net = Net()
  32. print(net)

运行结果:

只要在nn.Module的子类中定义了forward函数,backward函数就会被自动实现。网络的可学习参数通过net.parameters()返回,而net.named_parameters可同时返回可学习的参数及名称。

  1. params = list(net.parameters())
  2. print(len(params))
  3.  
  4. for name, parameters in net.named_parameters():
  5. print(name, ':', parameters.size())

运行结果:

注意,输入和输出都必须是Variable。

torch.nn只支持mini-batches,不支持一次只输入一个样本,也就是说一次必须是一个batch。如果只想输入一个样本,那么要用input.unsqueeze(0)将batch_size设为1。

损失函数:

(1)nn.MSELoss用于计算均方误差

(2)nn.CrossEntropyLoss用于计算交叉熵损失

  1. output = net(input)
  2. target = Variable(t.arange(0, 10))
  3. target = target.float()
  4. criterion = nn.MSELoss()
  5. loss = criterion(output, target)

在反向传播计算所有参数的梯度后,还需要使用优化方法更新网络的权重和参数,如SGD:

weight = weight - lr * gradient

手工实现如下:

  1. lr = 0.01
  2. for f in net.parameters():
  3. f.data.sub_(f.grad.data * lr)
  4. torch.optim中实现了深度学习中的绝大多数优化方法,如RMSPropAdamSGD等。
  5.  
  6. # 新建一个优化器,指定要调整的参数和学习率
  7. optimizer = optim.SGD(net.parameters(), lr = 0.01)
  8.  
  9. # 训练过程中,先梯度清零
  10. optimizer.zero_grad()
  11.  
  12. # 计算损失
  13. output = net(input)
  14. loss = criterion(output, target)
  15.  
  16. # 反向传播
  17. loss.backward()
  18. optimizer.step()

torchvision实现了常用的图像数据加载功能。

四、CIFAR-10 分类

  1. import torch.nn as nn
  2. import torch as t
  3. from torch.autograd import Variable
  4. import torch.nn.functional as F
  5. import torch.optim as optim
  6. import torchvision as tv
  7. import torchvision.transforms as transforms
  8. from torchvision.transforms import ToPILImage
  9.  
  10. show = ToPILImage() # 把Tensor转换成Image,方便可视化
  11.  
  12. # 数据预处理
  13. transform = transforms.Compose([transforms.ToTensor(), # 转为Tensor
  14. transforms.Normalize((0.5, 0.5, 0.5), (0.5,0.5, 0.5)),])
  15.  
  16. # 训练集
  17. trainset = tv.datasets.CIFAR10(root = 'F:/PycharmProjects/', train = True, download =
  18. True, transform = transform)
  19. trainloader = t.utils.data.DataLoader(trainset, batch_size = 4, shuffle = True, num_workers = 2)
  20.  
  21. # 测试集
  22. testset = tv.datasets.CIFAR10(root = 'F:/PycharmProjects/', train = False, download =
  23. True, transform = transform)
  24. testloader = t.utils.data.DataLoader(testset, batch_size = 4, shuffle = False, num_workers = 2)
  25.  
  26. classes = ('plane', 'car', 'bird', 'cat',
  27. 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  28.  
  29. (data, label) = trainset[100]
  30. print(classes[label])
  31.  
  32. show((data + 1) / 2).resize(100, 100)
  33.  
  34. # 将dataset返回的每一条数据样本拼接成一个batch
  35. dataiter = iter(trainloader)
  36. images, labels = dataiter.next()
  37. print(' '.join('%11s' % classes[labels[j]] for j in range(4)))
  38. show(tv.utils.make_grid((images + 1) / 2)).resize((400, 100))
  39.  
  40. class Net(nn.Module):
  41. def __init__(self):
  42. super(Net, self).__init__()
  43. self.conv1 = nn.Conv2d(3, 6, 5)
  44. self.conv2 = nn.Conv2d(6, 16, 5)
  45. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  46. self.fc2 = nn.Linear(120, 84)
  47. self.fc3 = nn.Linear(84, 10)
  48.  
  49. def forward(self, x):
  50. x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
  51. x = F.max_pool2d(F.relu(self.conv2(x)), 2)
  52. x = x.view(x.size()[0], -1)
  53. x = F.relu(self.fc1(x))
  54. x = F.relu(self.fc2(x))
  55. x = self.fc3(x)
  56. return x
  57.  
  58. net = Net()
  59. criterion = nn.CrossEntropyLoss()
  60. optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum = 0.9)
  61.  
  62. for epoch in range(2):
  63. running_loss = 0.0
  64. for i, data in enumerate(trainloader, 0):
  65. inputs, labels = data
  66. inputs, labels = Variable(inputs), Variable(labels)
  67.  
  68. optimizer.zero_grad()
  69.  
  70. outputs = net(inputs)
  71. loss = criterion(outputs, labels)
  72. loss.backward()
  73. # 参数更新
  74. optimizer.step()
  75.  
  76. running_loss += loss.data[0]
  77. if i % 2000 == 1999:
  78. print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
  79. running_loss = 0.0
  80. print('Finished Training')
  81.  
  82. dataiter = iter(testloader)
  83. images, labels = dataiter.next()
  84. outputs = net(Variable(images))
  85. _, predicted = t.max(outputs.data, 1)
  86. print('predicted result', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
  87.  
  88. correct = 0
  89. total = 0
  90. for data in testloader:
  91. images, labels = data
  92. outputs = net(Variable(images))
  93. _, predicted = t.max(outputs.data, 1)
  94. total += labels.size(0)
  95. correct += (predicted ==labels).sum
  96.  
  97. print('The accuracy is: %d %%' % (100 * correct / total))
  98.  
  99. #GPU加速操作:
  100. if t.cuda.is_available():
  101. net.cuda()
  102. images = images.cuda()
  103. labels = labels.cuda()
  104. output = net(Variable(images))
  105. loss = criterion(output, Variable(labels))

  

Pytorch学习笔记(一)——简介的更多相关文章

  1. Linux内核学习笔记-1.简介和入门

    原创文章,转载请注明:Linux内核学习笔记-1.简介和入门 By Lucio.Yang 部分内容来自:Linux Kernel Development(Third Edition),Robert L ...

  2. React学习笔记 - JSX简介

    React Learn Note 2 React学习笔记(二) 标签(空格分隔): React JavaScript 一.JSX简介 像const element = <h1>Hello ...

  3. CUBRID学习笔记 1 简介 cubrid教程

    CUBRID 是一个全面开源,且完全免费的关系数据库管理系统.CUBRID为高效执行Web应用进行了高度优化,特别是需要处理大数据量和高并发请求的复杂商务服务.通过提供独特的最优化特性,CUBRID可 ...

  4. [PyTorch 学习笔记] 1.1 PyTorch 简介与安装

    PyTorch 的诞生 2017 年 1 月,FAIR(Facebook AI Research)发布了 PyTorch.PyTorch 是在 Torch 基础上用 python 语言重新打造的一款深 ...

  5. shiro学习笔记_0100_shiro简介

    前言:第一次知道shiro是2016年夏天,做项目时候我要写springmvc的拦截器,申哥看到后,说这个不安全,就给我捣鼓了shiro,我就看了下,从此认识了shiro.此笔记是根据网上的视频教程记 ...

  6. Mybatis-Plus 实战完整学习笔记(一)------简介

    第一章    简介      1. 什么是MybatisPlus                MyBatis-Plus(简称 MP)是一个 MyBatis 的增强工具,在 MyBatis 的基础上只 ...

  7. ElasticSearch学习笔记-01 简介、安装、配置与核心概念

    一.简介 ElasticSearch是一个基于Lucene构建的开源,分布式,RESTful搜索引擎.设计用于云计算中,能够达到实时搜索,稳定,可靠,快速,安装使用方便.支持通过HTTP使用JSON进 ...

  8. python学习笔记1--python简介和第一行代码编写

    一.什么是python? python是一种面向对象,解释型语言,它语法简介,容易学习.本节博客就来说说本人学习python的心得体会. 二.python环境安装 目前python版本有python2 ...

  9. Pytorch学习笔记(二)---- 神经网络搭建

    记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...

随机推荐

  1. Spring Boot利用poi导出Excel

    至于poi的用法就不多说了,网上多得很,但是发现spring boot结合poi的就不多了,而且大多也有各种各样的问题. public class ExcelData implements Seria ...

  2. java Long、Integer 、Double、Boolean类型 不能直接比较

    测试: System.out.println(new Long(1000)==new Long(1000)); System.out.println(new Integer(1000)==new In ...

  3. RF和GBDT的区别

    Random Forest ​采用bagging思想,即利用bootstrap抽样,得到若干个数据集,每个数据集都训练一颗树. 构建决策树时,每次分类节点时,并不是考虑全部特征,而是从特征候选集中选取 ...

  4. CALayer, CoreGraphics与CABasicAnimation介绍

    今天我们来看一下CALayer.CoreGraphics和CABasicAnimation.这些东西在处理界面绘制.动画效果上非常有用. 本篇博文就讲介绍CALayer的基本概念,使用CoreGrap ...

  5. 说一下自己对于 Linux 哲学的理解

    查阅了一些资料,官方的哲学思想貌似是: 一切皆文件 由众多单一目的的小程序,一个程序只实现一个功能,多个程序组合完成复杂任务 文本文件保存配置信息 尽量避免与用户交互 什么,你问我的理解?哲学思想?E ...

  6. 8、Semantic-UI之其他按钮样式

    8.1 其他按钮样式定义 示例:定义其他按钮样式 定义圆形图标按钮样式 <div class="ui circular icon button"><i class ...

  7. WinAPI 字符及字符串函数(10): lstrcpy - 复制字符串

    unit Unit1; interface uses   Windows, Messages, SysUtils, Variants, Classes, Graphics, Controls, For ...

  8. Linq to Entities基础之需要熟知14个linq关键字(from,where,select,group,let,on,by...)

    1.Linq基础 <1> 关键词: from,in,group,by,where..... MSDN上总结的有14个关键词法... from xxxx in xxxx select =&g ...

  9. 从源代码分析DbSet如何通过ObjectStateManager管理entity lifecycle的生命周期

    一:Savechange的时候,怎么知道哪些entity被add,modify,delete,unchange ???? 如何来辨别... 在entity中打上标记来做表示...已经被跟踪了...当每 ...

  10. oracle数据库sqlldr命令的使用

    将数据导入 oracle 的方法应该很多 , 对于不同需求有不同的导入方式 , 最近使用oracle的sqlldr命令 导入数据库数据感觉是个挺不错的技术点 .  使用sqlldr命令 将文本文件导入 ...