各种优化器的比较

莫烦的对各种优化通俗理解的视频

 import torch

 import torch.utils.data as Data

 import torch.nn.functional as F

 from torch.autograd import Variable

 import matplotlib.pyplot as plt

 # 超参数

 LR = 0.01

 BATCH_SIZE = 

 EPOCH = 

 # 生成假数据

 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据

 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )

 # 0.2 * torch.rand(x.size())增加噪点

 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))

 # 输出数据图

 # plt.scatter(x.numpy(), y.numpy())

 # plt.show()

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)

 class Net(torch.nn.Module):

     # 初始化

     def __init__(self):

         super(Net, self).__init__()

         self.hidden = torch.nn.Linear(, )

         self.predict = torch.nn.Linear(, )

     # 前向传递

     def forward(self, x):

         x = F.relu(self.hidden(x))

         x = self.predict(x)

         return x

 net_SGD = Net()

 net_Momentum = Net()

 net_RMSProp = Net()

 net_Adam = Net()

 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)

 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)

 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)

 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

 loss_func = torch.nn.MSELoss()

 loss_his = [[], [], [], []]  # 记录损失

 for epoch in range(EPOCH):

     print(epoch)

     for step, (batch_x, batch_y) in enumerate(loader):

         b_x = Variable(batch_x)

         b_y = Variable(batch_y)

         for net, opt,l_his in zip(nets, optimizers, loss_his):

             output = net(b_x)  # get output for every net

             loss = loss_func(output, b_y)  # compute loss for every net

             opt.zero_grad()  # clear gradients for next train

             loss.backward()  # backpropagation, compute gradients

             opt.step()  # apply gradients

             l_his.append(loss.data.numpy())  # loss recoder

 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']

 for i, l_his in enumerate(loss_his):

     plt.plot(l_his, label=labels[i])

 plt.legend(loc='best')

 plt.xlabel('Steps')

 plt.ylabel('Loss')

 plt.ylim((, 0.2))

 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. [PyTorch 学习笔记] 4.3 优化器

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...

  5. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  6. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

随机推荐

  1. (二十三)Http请求的处理过程

  2. 9款很棒的网页绘制图表JavaScript框架脚本

    推荐9款很棒的可在网页中绘制图表的JavaScript脚本,这些有趣的JS脚本可以帮助你快速方便的绘制图表(线.面.饼.条…),其中包括jQuery.MooTools.Prototype和一些其它的J ...

  3. 网页开发人员收藏的16款HTML5工具

    本文收集的20款优秀的 HTML5 Web 应用程序,值得添加到您的 HTML5 的工具箱中,他们能够帮助你开发前端项目更快.更容易. Initializr Initializr 是一个可以让你创建 ...

  4. Java面试(1)

    一.Java基础 什么是字符串常量池? Java中的字符串常量池(String Pool)是存储在Java堆内存中的字符串池: String是java中比较特殊的类,我们可以使用new运算符创建Str ...

  5. linux 系统优化,调优

    1.系统安装前的规则 a.分区:不同环境不同分法,按自己的需求来 以硬盘500G为例 /boot 100M-200M(只存放grub和启动相关文件,不存放其他) /  80G-100G (因为很多人默 ...

  6. 深度探索C++对象模型之第二章:构造函数语意学之Copy constructor的构造操作

    C++ Standard将copy constructor分为trivial 和nontrivial两种:只有nontrivial的实例才会被合成于程序之中.决定一个copy constructor是 ...

  7. vue v-model :

    v-model : 通过v-model 进行双向绑定 ,将data的数据与input 绑定在一起,呈现在页面上 <!DOCTYPE html> <html lang="en ...

  8. Java 学习 时间格式化(SimpleDateFormat)与历法类(Calendar)用法详解

    基于Android一些时间创建的基本概念 获取当前时间 方式一: Date date = new Date(); Log.e(TAG, "当前时间="+date); 结果: E/T ...

  9. 【JZOJ6346】ZYB和售货机

    description analysis 其实这个连出来的东西叫基环内向树 先考虑很多森林的情况,也就是树根连回自己 明显树根物品是可以被取完的,那么买树根的价钱要是儿子中价钱最小的那个 或者把那个叫 ...

  10. 廖雪峰Java16函数式编程-2Stream-4map

    1. map()简介 Stream.map()是一个Stream的转换方法,把一个stream转换为另一个Stream,这2个Stream是按照映射函数一一对应的. 所谓map操作,就是把一种操作运算 ...