各种优化器的比较

莫烦的对各种优化通俗理解的视频

 import torch

 import torch.utils.data as Data

 import torch.nn.functional as F

 from torch.autograd import Variable

 import matplotlib.pyplot as plt

 # 超参数

 LR = 0.01

 BATCH_SIZE = 

 EPOCH = 

 # 生成假数据

 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据

 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )

 # 0.2 * torch.rand(x.size())增加噪点

 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))

 # 输出数据图

 # plt.scatter(x.numpy(), y.numpy())

 # plt.show()

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)

 class Net(torch.nn.Module):

     # 初始化

     def __init__(self):

         super(Net, self).__init__()

         self.hidden = torch.nn.Linear(, )

         self.predict = torch.nn.Linear(, )

     # 前向传递

     def forward(self, x):

         x = F.relu(self.hidden(x))

         x = self.predict(x)

         return x

 net_SGD = Net()

 net_Momentum = Net()

 net_RMSProp = Net()

 net_Adam = Net()

 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)

 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)

 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)

 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

 loss_func = torch.nn.MSELoss()

 loss_his = [[], [], [], []]  # 记录损失

 for epoch in range(EPOCH):

     print(epoch)

     for step, (batch_x, batch_y) in enumerate(loader):

         b_x = Variable(batch_x)

         b_y = Variable(batch_y)

         for net, opt,l_his in zip(nets, optimizers, loss_his):

             output = net(b_x)  # get output for every net

             loss = loss_func(output, b_y)  # compute loss for every net

             opt.zero_grad()  # clear gradients for next train

             loss.backward()  # backpropagation, compute gradients

             opt.step()  # apply gradients

             l_his.append(loss.data.numpy())  # loss recoder

 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']

 for i, l_his in enumerate(loss_his):

     plt.plot(l_his, label=labels[i])

 plt.legend(loc='best')

 plt.xlabel('Steps')

 plt.ylabel('Loss')

 plt.ylim((, 0.2))

 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. [PyTorch 学习笔记] 4.3 优化器

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...

  5. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  6. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

随机推荐

  1. Day 9 :初识函数

    Python函数:1.函数是组织好的,可重复使用的,用来实现单一,或相关联功能的代码段. 2.函数能提高应用的模块性,和代码的重复利用率. Python提供了许多内建函数,比如print().但你也可 ...

  2. ECMAScript1.1 js书写位置 | 声明变量 | 基本数据类型 | 数据类型转换 | 操作符 | 布尔类型的隐式转换

    js书写位置 由于在写css样式时使用的时双引号,所以我们在写js代码时建议使用单引号(‘’)! 行内式 <input type="button" value="点 ...

  3. 注意:字符串substring方法在jkd6,7,8中的差异。

    标题中的substring方法指的是字符串的substring(int beginIndex, int endIndex)方法,这个方法在jdk6,7是有差异的. substring有什么用? sub ...

  4. Android笔记之从图库选择图片

    Demo链接:https://pan.baidu.com/s/1T4T2pTEswmbcYYfpN3OwDw,提取码:pzqy 参考链接:[Android Example] Pick Image fr ...

  5. [笔记]Android开发环境配置及HelloWorld程序

    Android的开发须要下面四个工具: 1.JDK 2.Eclipse 3.Android SDK 4.ADT   具体功能: 1.JDK.JDK即Java Development Kit(Java开 ...

  6. 2019-4-29-WPF-如何判断一个控件在滚动条的里面是用户可见

    title author date CreateTime categories WPF 如何判断一个控件在滚动条的里面是用户可见 lindexi 2019-4-29 9:42:2 +0800 2019 ...

  7. 2018-8-10-WPF-使用-Direct2D1-画图-绘制基本图形

    title author date CreateTime categories WPF 使用 Direct2D1 画图 绘制基本图形 lindexi 2018-08-10 19:16:53 +0800 ...

  8. C 二维数组与指针

    http://c.biancheng.net/view/2022.html 1. 区分指针数组和数组指针 指针数组:存放指针的数组,如 int *pstr[5] = NULL; 数组中每个元素存放的是 ...

  9. Java Queue队列

    前言 Queue队列是一种特殊的线性表,它只允许在表的前端进行删除操作,而在表的后端进行插入操作,LinkedList类实现了Queue接口,因此我们可以把LinkedList当成Queue来用.  ...

  10. jquery实现文字由下到上循环滚动的实例代码

    <div id="oDiv"> <ul id="oUl"> <li>第1个li元素</li> <li> ...