使用pytorch构建神经网络的流程以及一些问题
使用PyTorch构建神经网络十分的简单,下面是我总结的PyTorch构建神经网络的一般过程以及我在学习当中遇到的一些问题,期望对你有所帮助。
PyTorch构建神经网络的一般过程
下面的程序是PyTorch官网60分钟教程上面构建神经网络的例子,版本0.4.1:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# 第一步:准备数据
# Compose是将两个转换的过程组合起来,ToTensor将numpy等数据类型转换为Tensor,将值变为0到1之间
# Normalize用公式(input-mean)/std 将值进行变换。这里mean=0.5,std=0.5,是将[0,1]区间转换为[-1,1]区间
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# trainloader 是一个将数据集和采样策略结合起来的,并提供在数据集上面迭代的方法
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 第二步:构建神经网络框架,继承nn.Module类
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
# 第三步:进行训练
# 定义损失策略和优化方法
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练神经网络
for epoch in range(4):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
# 训练过程1:前向过程,计算输入到输出的结果
outputs = net(inputs)
# 训练过程2:由结果和label计算损失
loss = criterion(outputs, labels)
# 训练过程3:在图的层次上面计算所有变量的梯度
# 每次计算梯度的时候,其实是有一个动态的图在里面的,求导数就是对图中的参数w进行求导的过程
# 每个参数计算的梯度值保存在w.grad.data上面,在参数更新时使用
loss.backward()
# 训练过程4:进行参数的更新
# optimizer不计算梯度,它利用已经计算好的梯度值对参数进行更新
optimizer.step()
running_loss += loss.item() # item 返回的是一个数字
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' %
(epoch+1, i+1, running_loss/2000))
running_loss = 0.0
print('Finished Training')
# 第四步:在测试集上面进行测试
total = 0
correct = 0
with torch.no_grad():
for data in testloader:
images, label = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == label).sum().item()
print("Accuracy of networkd on the 10000 test images: %d %%" % (100*correct/total))
这个例子说明了构建神经网络的四个步骤:1:准备数据集 。2:构建神经网络框架,实现神经网络的类。 3:在训练集上进行训练。 4:在测试集上面进行测试。
而在第三步的训练阶段,也可以分为四个步骤:1:前向过程,计算输入到输出的结果。2:由结果和labels计算损失。3:后向过程,由损失计算各个变量的梯度。4:优化器根据梯度进行参数的更新。
训练过程中第loss和optim是怎么联系在一起的
loss是训练阶段的第三步,计算参数的梯度。optim是训练阶段的第四步,对参数进行更新。在optimizer初始化的时候,optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
,获取了参数的指针,可以对参数进行修改。当loss计算好参数的梯度以后,把值放在参数w.grad.data上面,然后optimizer直接利用这个值对参数进行更新。
以SGD为例,它进行step的时候的基本操作是这样的: p.data.add_(-group['lr'], d_p)
,其中 d_p = p.grad.data
为什么要进行梯度清零
在backward每次计算梯度的时候,会将新的梯度值加到原来旧的梯度值上面,这叫做梯度累加。下面的程序可以说明什么是梯度累加:
import torch
x = torch.rand(2, requires_grad=True)
y = x.mean() # y = (x_1 + x_2) / 2 所以求梯度后应是0.5
y.backward()
print(x.grad.data) # 输出结果:tensor([0.5000, 0.5000])
y.backward()
print(x.grad.data) # 输出结果:tensor([1., 1.]) 说明进行了梯度累积
求解梯度过程和参数更新过程是分开的,这对于那些需要多次求导累计梯度,然后一次更新的神经网络可能是有帮助的,比如RNN,对于DNN和CNN不需要进行梯度累加,所以需要进行梯度清零。
如何使用GPU进行训练
旧版本:
use_cuda = True if torch.cuda.is_available() else False # 是否使用cuda
if use_cuda:
model = model.cuda() # 将模型的参数放入GPU
if use_cuda:
inputs, labels = inputs.cuda(), labels.cuda() # 将数据放入到GPU
0.4版本以后推荐新方法 to(device),
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device) #将模型的参数放入GPU中
inputs, labels = inputs.to(device), labels.to(device) # 将数据放入到GPU中
参考:
Pytorch内部中optim和loss是如何交互的? - 罗若天的回答 - 知乎
pytorch学习笔记(二):gradient
使用pytorch构建神经网络的流程以及一些问题的更多相关文章
- 使用PyTorch构建神经网络以及反向传播计算
使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...
- 使用PyTorch构建神经网络模型进行手写识别
使用PyTorch构建神经网络模型进行手写识别 PyTorch是一种基于Torch库的开源机器学习库,应用于计算机视觉和自然语言处理等应用,本章内容将从安装以及通过Torch构建基础的神经网络,计算梯 ...
- TFLearn构建神经网络
TFLearn构建神经网络 Building the network TFLearn lets you build the network by defining the layers. Input ...
- 在IDEA中构建Tomcat项目流程
在IDEA中构建Web项目流程 打开你的IDEA,跟着我走! 第一步:新建项目 第二步:找到Artifacts 点击绿色的+号,如图所示,点一下 这一步很关键,目的是设置输出格式为war包,如果你的项 ...
- pytorch构建自己的数据集
现在需要在json文件里面读取图片的URL和label,这里面可能会出现某些URL地址无效的情况. python读取json文件 此处只需要将json文件里面的内容读取出来就可以了 with open ...
- tensorflow之神经网络实现流程总结
tensorflow之神经网络实现流程总结 1.数据预处理preprocess 2.前向传播的神经网络搭建(包括activation_function和层数) 3.指数下降的learning_rate ...
- Tensorflow BatchNormalization详解:2_使用tf.layers高级函数来构建神经网络
Batch Normalization: 使用tf.layers高级函数来构建神经网络 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearningai课程 课程笔 ...
- GeneXus DevOps 自动化构建和部署流程
以下视频详细介绍了GeneXus DevOps自动化构建和部署流程,包括通过MS Bulid来管理自动化流程,自动化的架构,以及在GeneXus Server上使用Jenkins做为自动化引擎. 视频 ...
- 使用 Visual Studio 2015 + Python3.6 + tensorflow 构建神经网络时报错:'utf-8' codec can't decode byte 0xcc in position 78: invalid continuation byte
使用 Visual Studio 2015 + Python3.6 + tensorflow 构建神经网络时报错:'utf-8' codec can't decode byte 0xcc in pos ...
随机推荐
- bzoj千题计划179:bzoj1237: [SCOI2008]配对
http://www.lydsy.com/JudgeOnline/problem.php?id=1237 如果没有相同的数不能配对的限制 那就是排好序后 Σ abs(ai-bi) 相同的数不能配对 交 ...
- ngx_lua_API 指令详解(六)ngx.thread.spawn、ngx.thread.wait、ngx.thread.kill介绍
摘要:通过lua-nginx-module中的ngx.thread同时执行多个任务. ngx_lua中访问多个第三方服务 ngx_lua中提供了ngx.socket API,可以方便的访问第三方网络服 ...
- elementUI 通用确认框
Util.vue <script> import VueResource from 'vue-resource' function confirm(_this, operate, fun) ...
- 20155206 2016-2017-2 《Java程序设计》第6周学习总结
20155206 2016-2017-2 <Java程序设计>第6周学习总结 教材学习内容总结 串流设计 流(Stream)是对「输入输出」的抽象,注意「输入输出」是相对程序而言的. Ja ...
- CString 与其它数据类型转换问题
CString 头文件#include <afx.h> string 头文件#include <string.h> CString 转char * CString cstr; ...
- python中的*号
from:https://www.douban.com/note/231603832/ 传递实参和定义形参(所谓实参就是调用函数时传入的参数,形参则是定义函数是定义的参数)的时候,你还可以使用两个特殊 ...
- .NET 的 WCF 和 WebService 有什么区别?(转载)
[0]问题: WCF与 Web Service的区别是什么? 和ASP.NET Web Service有什么关系? WCF与ASP.NET Web Service的区别是什么? 这是很多.NET开发人 ...
- 【坐在马桶上看算法】算法7:Dijkstra最短路算法
上周我们介绍了神奇的只有五行的Floyd最短路算法,它可以方便的求得任意两点的最短路径,这称为“多源最短路”.本周来来介绍指定一个点(源点)到其余各个顶点的最短路径,也叫做“单源最短路径 ...
- elasticsearch代码片段,及工具类SearchEsUtil.java
ElasticSearchClient.java package com.zbiti.framework.elasticsearch.utils; import java.util.Arrays; i ...
- android 调用系统照相机拍照后保存到系统相册,在系统图库中能看到
需求: 调用系统照相机进行拍照,并且保存到系统相册,调用系统相册的时候能看到 系统相册的路径:String cameraPath= Environment.getExternalStorageD ...