各种优化器的比较

莫烦的对各种优化通俗理解的视频

 import torch

 import torch.utils.data as Data

 import torch.nn.functional as F

 from torch.autograd import Variable

 import matplotlib.pyplot as plt

 # 超参数

 LR = 0.01

 BATCH_SIZE = 

 EPOCH = 

 # 生成假数据

 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据

 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )

 # 0.2 * torch.rand(x.size())增加噪点

 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))

 # 输出数据图

 # plt.scatter(x.numpy(), y.numpy())

 # plt.show()

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)

 class Net(torch.nn.Module):

     # 初始化

     def __init__(self):

         super(Net, self).__init__()

         self.hidden = torch.nn.Linear(, )

         self.predict = torch.nn.Linear(, )

     # 前向传递

     def forward(self, x):

         x = F.relu(self.hidden(x))

         x = self.predict(x)

         return x

 net_SGD = Net()

 net_Momentum = Net()

 net_RMSProp = Net()

 net_Adam = Net()

 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)

 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)

 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)

 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

 loss_func = torch.nn.MSELoss()

 loss_his = [[], [], [], []]  # 记录损失

 for epoch in range(EPOCH):

     print(epoch)

     for step, (batch_x, batch_y) in enumerate(loader):

         b_x = Variable(batch_x)

         b_y = Variable(batch_y)

         for net, opt,l_his in zip(nets, optimizers, loss_his):

             output = net(b_x)  # get output for every net

             loss = loss_func(output, b_y)  # compute loss for every net

             opt.zero_grad()  # clear gradients for next train

             loss.backward()  # backpropagation, compute gradients

             opt.step()  # apply gradients

             l_his.append(loss.data.numpy())  # loss recoder

 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']

 for i, l_his in enumerate(loss_his):

     plt.plot(l_his, label=labels[i])

 plt.legend(loc='best')

 plt.xlabel('Steps')

 plt.ylabel('Loss')

 plt.ylim((, 0.2))

 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. [PyTorch 学习笔记] 4.3 优化器

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...

  5. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  6. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

随机推荐

  1. 20130330 printf数组改变 数组指针便利二维数组 二级指针遍历二维数组 ZigZag

    1.为什么printf之后数组的值会改变? #include<stdio.h> ; int * Zigzag() { ,j=,limit=; ; ; int a[N][N]; int (* ...

  2. Adobe Fireworks CS6 win64的安装

    网页三大剑客之一    FW的安装 本人也是找了半天才找到的. (没有视频)这里先感谢原帖给我的链接https://blog.csdn.net/qq_38053395/article/details/ ...

  3. jquery与zend framework编写的联动选项效果

    html部分: <pre name="code" class="html"><!DOCTYPE html PUBLIC "-//W3 ...

  4. 【Neo4j】踩坑大会-Neo4J用中文索引

    正在用的Neo4j是当前最新版:3.1.0,各种踩坑.说一下如何在Neo4j 3.1.0中使用中文索引.选用了IKAnalyzer做分词器. 1. 首先参考文章: https://segmentfau ...

  5. 终于搭好了WinCE上MFC的SDK环境

    终于可以我的嵌入式之旅了,幸福啊...

  6. echarts数据变了不重新渲染,以及重新渲染了前后数据会重叠渲染的问题

    1.echarts数据变了但是视图不重新渲染 新建Chart.vue文件 <template>  <p :id="id" :style="style&q ...

  7. Docker学习の更改Docker的目录

    一.更改虚拟磁盘的目录 虚拟机的默认存储位置是C:\Users\Administrator\.docker\machine\machines ,后期docke镜像文件会不断增加,为了给系统盘减负,最好 ...

  8. Benchmark of Large-scale Unconstrained Face Recognition-blufr 算法的理解

    Many efforts have been made in recent years to tackle the unconstrained face recognition challenge. ...

  9. CSIC_716_20191126【面向对象编程--继承】

    继承 什么是继承:继承是新建类的一种方式,通过此方式生成的类称为子类.或者 派生类,被继承的类称为父类.基类或超类.在python中,一个子类可以继承多个父类. 继承的作用:减少代码的冗余,提高开发效 ...

  10. Perl 环境安装

    Perl 环境安装 在我们开始学习 Perl 语言前,我们需要先安装 Perl 的执行环境. Perl 可以在以下平台下运行: Unix (Solaris, Linux, FreeBSD, AIX, ...