在上一节中,我们介绍了如何使用Pytorch来搭建一个经典的分类神经网络。一般情况下,搭建完模型后训练不会一次就能达到比较好的效果,这样,就需要不断的调整和优化模型的各个部分。从而引出了本文的主旨:如何优化模型。

在本节中,我们将介绍从数据集到模型各个部分的调整,从而可以有一个完整的解决思路。

1、数据集部分

1.1 数据集划分

一般情况下,我们会把数据集分成三个部分:训练集,验证集和测试集。依据数据集的大小,如果数据集比较大,数万或数十万个,可以将数据集采用7:2:1或8:1:1的比例来划分。而如果数据集比较小,只有几百条,就不能简单的使用这个方法了。这时,需要使用K折验证法(具体方法可自行百度)。

当然,还有一些需要考虑的问题:数据表征,时间敏感性和数据冗余。在数据表征中,随机打乱(shuffle)是一个不错的选择;时间敏感性主要是针对回归问题象预测股票,不同的月份对回归结果有一个不同的贡献;数据冗余指的是,在数据集中,存在着一些相同的数据会对训练和测试结果产生影响,所以,需要事先过滤掉。

1.2数据预处理

数据向量化:数据源形式各异,需要提前把它转换成框架可以识别的形式,Pytorch统一使用向量(Vector)来表示数据。

正则化:数据的范围大小不一,如果直接使用,训练的收敛会很慢,甚至会出现异常。所以,需要统一数据的范围大小,也就是去除纲量,使用【0,1】区间来统一度量。

缺失数据的处理:如果没有对缺失数据进行处理,训练过程中会直接导致数据的权重分配异常,进而直接影响训练效果。

特征工程:对数据集的特征进行有效提取,是保证模型正常训练的前提。

1.3过拟合与欠拟合

过拟合:训练效果好而验证效果不好。

欠拟合:训练效果不好。

欠拟合的处理相对容易些,针对欠拟合,我们一般采用加大训练周期,降低训练损失,提高训练精度。

过拟合策略:

1、获取更多数据

2、减小网络规模

  1. 原始模型:
  2. class Architecture1(nn.Module):
  3. def __init__(self, input_size, hidden_size, num_classes):
  4. super(Architecture1, self).__init__()
  5. self.fc1 = nn.Linear(input_size, hidden_size)
  6. self.relu = nn.ReLU()
  7. self.fc2 = nn.Linear(hidden_size, num_classes)
  8. self.relu = nn.ReLU()
  9. self.fc3 = nn.Linear(hidden_size, num_classes)
  10. def forward(self, x):
  11. out = self.fc1(x)
  12. out = self.relu(out)
  13. out = self.fc2(out)
  14. out = self.relu(out)
  15. out = self.fc3(out)
  16. return out
  1. 减小规模后的模型:
  2. class Architecture2(nn.Module):
  3. def __init__(self, input_size, hidden_size, num_classes):
  4. super(Architecture2, self).__init__()
  5. self.fc1 = nn.Linear(input_size, hidden_size)
  6. self.relu = nn.ReLU()
  7. self.fc2 = nn.Linear(hidden_size, num_classes)
  8. def forward(self, x):
  9. out = self.fc1(x)
  10. out = self.relu(out)
  11. out = self.fc2(out)
  12. return out

3、使用权重正则化

正则化分为1阶正则化和2阶正则化

     1阶正则化是将权重协相关系数的相差绝对值加入权重。

2阶正则化是将权重协相关系数的相差平方和加入权重。示例如下:

  1. model = Architecture1(10,20,2)
  2. optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

4、使用DROPOUT

在隐藏层中去除某些节点,以达到防止过拟合的问题。

dropout的比率为0.2:

dropout的比率为0.5

  1. nn.dropout(x, training=True)

1.4问题定义与数据集获取

首先需要明确两个事情:问题的类别与数据的输入,确定是分类问题还是回归问题。

不同类别的问题有着不同的处理方法,对数据集的获取也是必须面对的一大难题。

1.5模型评估

对于分类问题,一般采用精度,ROC,AUC等方法来进行评估。

而对于排名问题,一般采用mAp。

2、模型部分

2.1 搭建完基础模型后,为了使该模型能够正常工作,我们需要做以下三部分工作:

1、选择网络输出的最后一层

不同的任务,输出最后一层也不尽相同。一般的回归问题只要输出一个标量就可以;向量回归问题则需要输出相同层的向量;对于BBOX问题,则需要输出四个值;对于

二分类,我们需要使用Sigmoid,对于多分类则使用softmax。

2、选择损失函数

对于分类问题,一般采用交叉熵损失;而对于回归问题,则采用均方差。

3、优化器

如何选择一个优化器及配置相关参数是一件非常有艺术性的事。有时需要通过实验来得到。很多时候:Adam和RMSProp是个不错的选择。

  1. Problem type Activation function Loss function
  2. Binary classification Sigmoid activation nn.CrossEntropyLoss()
  3. Multi-class classification Softmax activation nn.CrossEntropyLoss()
  4. Multi-label classification Sigmoid activation nn.CrossEntropyLoss()
  5. Regression None MSE
  6. Vector regression None MSE

2.2 提高模型规模

对于一个已搭建好的模型,如何提高模型的推理能力。可以从这三方面来提高:

1、增加更多的层

2、加入更多的权重系数

3、提高训练周期

2.3 加入泛化策略

1、加入dropout

2、使用不同的架构,不同的参数,不同的网络层数,权重。

3、使用L1或L2正则化

4、尝试不同的学习率

5、增加更多的数据或特征

2.4学习率的设置

学习率对于模型来说,是一个非常重要的超参数。它的设置很多时候直接决定着模型训练效果的好坏。所以,如何设置该参数就变得非常重要。有大量的研究就是针对于该参数进行的。

在Pytorch中,有一系列的方法:

1、stepLR:

  1. scheduler = StepLR(optimizer, step_size=30, gamma=0.1) #step_size:多少个周期后学习率发生改变 gamma:学习率如何你改变
  2. for epoch in range(100):
  3. scheduler.step()
  4. train(...)
  5. validate(...)

2、MultiStepLR

  1. scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
  1. #milestones:多少个周期后学习率发生改变 gamma:学习率如何你改变
  1. for epoch in range(100): scheduler.step() train(...) validate(...)

3、ExponentialLR

4、ReduceLROnPlateau

  1. optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
  2. momentum=0.9)
  3. scheduler = ReduceLROnPlateau(optimizer, 'min')
  4. for epoch in range(10):
  5. train(...)
  6. val_loss = validate(...)
  7. # Note that step should be called after validate()
  8. scheduler.step(val_loss)

上一篇:

如何入门Pytorch之二:如何搭建实用神经网络

下一篇:

待更新。。。

如何入门Pytorch之三:如何优化神经网络的更多相关文章

  1. 如何入门Pytorch之四:搭建神经网络训练MNIST

    上一节我们学习了Pytorch优化网络的基本方法,本节我们将以MNIST数据集为例,通过搭建一个完整的神经网络,来加深对Pytorch的理解. 一.数据集 MNIST是一个非常经典的数据集,下载链接: ...

  2. 如何入门Pytorch之二:如何搭建实用神经网络

    上一节中,我们介绍了Pytorch的基本知识,如数据格式,梯度,损失等内容. 在本节中,我们将介绍如何使用Pytorch来搭建一个经典的分类神经网络. 搭建一个神经网络并训练,大致有这么四个部分: 1 ...

  3. 60 分钟极速入门 PyTorch

    2017 年初,Facebook 在机器学习和科学计算工具 Torch 的基础上,针对 Python 语言发布了一个全新的机器学习工具包 PyTorch. 因其在灵活性.易用性.速度方面的优秀表现,经 ...

  4. 如何入门Pytorch之一:Pytorch基本知识介绍

    前言 PyTorch和Tensorflow是目前最为火热的两大深度学习框架,Tensorflow主要用户群在于工业界,而PyTorch主要用户分布在学术界.目前视觉三大顶会的论文大多都是基于PyTor ...

  5. 新手如何入门pytorch?

    我最近的文章中,专门为想学Pytorch的新手推荐了一些学习资源,包括教程.视频.项目.论文和书籍.希望能对你有帮助:一.PyTorch学习教程.手册 (1)PyTorch英文版官方手册:https: ...

  6. 【OpenCV入门教程之三】 图像的载入,显示和输出 一站式完全解析(转)

    本系列文章由@浅墨_毛星云 出品,转载请注明出处. 文章链接:http://blog.csdn.net/poem_qianmo/article/details/20537737 作者:毛星云(浅墨)  ...

  7. Asp.Net MVC4.0 官方教程 入门指南之三--添加一个视图

    Asp.Net MVC4.0 官方教程 入门指南之三--添加一个视图 在本节中,您需要修改HelloWorldController类,从而使用视图模板文件,干净优雅的封装生成返回到客户端浏览器HTML ...

  8. PyTorch-Adam优化算法原理,公式,应用

    概念:Adam 是一种可以替代传统随机梯度下降过程的一阶优化算法,它能基于训练数据迭代地更新神经网络权重.Adam 最开始是由 OpenAI 的 Diederik Kingma 和多伦多大学的 Jim ...

  9. 深度学习之入门Pytorch(1)------基础

    目录: Pytorch数据类型:Tensor与Storage 创建张量 tensor与numpy数组之间的转换 索引.连接.切片等 Tensor操作[add,数学运算,转置等] GPU加速 自动求导: ...

随机推荐

  1. 微信小程序的target和currentTarget的区别

    https://www.jb51.net/article/160886.htm 在小程序的事件回调触发时,会接收一个事件对象,事件对象的参数中包含一个target和currentTarget属性,接下 ...

  2. Unity 的 [HideInInspector]

    [HideInInspector] public Transform t; public Transform mm; public Transform nn3; 在变量前面加入,作用:隐藏下一条在In ...

  3. 对ysoserial工具及java反序列化的一个阶段性理解【未完成】

    经过一段时间的琢磨与反思,以及重读了大量之前看不懂的反序列化文章,目前为止算是对java反序列化这块有了一个阶段性的小理解. 目前为止,发送的所有java反序列化的漏洞中.主要需要两个触发条件: 1. ...

  4. ie兼容promise

    引入 <script src = "https://cdn.polyfill.io/v2/polyfill.min.js"></script> 或 < ...

  5. Odoo13 新变化:会计

    Odoo13将于2019年10月发布,本次发布也包含了大量的改进,例如,对会计的重构. 去掉了 account.invoice / account.invoice.line/ account.vouc ...

  6. eNSP——通过Stelnet登录系统

    Stelnet的原理 由于Telnet缺少安全的认证方式,而且传输过程采用TCP进行明文传输,存在很大的安全隐患,单纯提供Telnet服务容易招致主机IP地址欺骗.路由欺骗等恶意攻击.传统的Telne ...

  7. Java基础语法知识你真的都会吗?

    第一阶段 JAVA基础知识 第二章 Java基础语法知识 在我们开始讲解程序之前,命名规范是我们不得不提的一个话题,虽说命名本应该是自由的,但是仍然有一定的"潜规则",通过你对命名 ...

  8. [转帖]iphone11的部分参数 UX

    iPhone 11将于9月11号凌晨发布 靠谱爆料在这 https://www.cnbeta.com/articles/tech/884199.htm iphone的分辨率 非常高.. iphone ...

  9. Ubuntu中使用python3中的venv创建虚拟环境

    以前不知道Python3中内置了venv模块,一直用的就是virtualenv模块,venv相比virtualenv好用不少,可以替代virtualenv 一.安装venv包: $ sudo apt ...

  10. (六)Java秒杀项目之接口优化

    一.Redis预减库存减少数据库访问 思路:减少数据库访问 1.系统初始化,把商品库存数量加载到Redis 2.收到请求,Redis预减库存,库存不足,直接返回,否则进入3 3.请求入队,立即返回排队 ...