各种优化器的比较

莫烦的对各种优化通俗理解的视频

 import torch

 import torch.utils.data as Data

 import torch.nn.functional as F

 from torch.autograd import Variable

 import matplotlib.pyplot as plt

 # 超参数

 LR = 0.01

 BATCH_SIZE = 

 EPOCH = 

 # 生成假数据

 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据

 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )

 # 0.2 * torch.rand(x.size())增加噪点

 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))

 # 输出数据图

 # plt.scatter(x.numpy(), y.numpy())

 # plt.show()

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)

 class Net(torch.nn.Module):

     # 初始化

     def __init__(self):

         super(Net, self).__init__()

         self.hidden = torch.nn.Linear(, )

         self.predict = torch.nn.Linear(, )

     # 前向传递

     def forward(self, x):

         x = F.relu(self.hidden(x))

         x = self.predict(x)

         return x

 net_SGD = Net()

 net_Momentum = Net()

 net_RMSProp = Net()

 net_Adam = Net()

 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)

 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)

 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)

 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

 loss_func = torch.nn.MSELoss()

 loss_his = [[], [], [], []]  # 记录损失

 for epoch in range(EPOCH):

     print(epoch)

     for step, (batch_x, batch_y) in enumerate(loader):

         b_x = Variable(batch_x)

         b_y = Variable(batch_y)

         for net, opt,l_his in zip(nets, optimizers, loss_his):

             output = net(b_x)  # get output for every net

             loss = loss_func(output, b_y)  # compute loss for every net

             opt.zero_grad()  # clear gradients for next train

             loss.backward()  # backpropagation, compute gradients

             opt.step()  # apply gradients

             l_his.append(loss.data.numpy())  # loss recoder

 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']

 for i, l_his in enumerate(loss_his):

     plt.plot(l_his, label=labels[i])

 plt.legend(loc='best')

 plt.xlabel('Steps')

 plt.ylabel('Loss')

 plt.ylim((, 0.2))

 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. [PyTorch 学习笔记] 4.3 优化器

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...

  5. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  6. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

随机推荐

  1. 剑指offer——25合并两个排序的链表

    题目描述 输入两个单调递增的链表,输出两个链表合成后的链表,当然我们需要合成后的链表满足单调不减规则.   题解: 使用普通方法,或者递归,注意新的头节点即可. //使用普通的合并方法 class S ...

  2. Ant属性文件

    直接在构建文件中设置属性是好的,如果你使用的是少数属性.然而,对于一个大型项目,是要存储在一个单独的属性文件中. 存储在一个单独的文件中的属性可以让你重复使用相同的编译文件,针对不同的执行环境不同的属 ...

  3. Centos 7 ping 不通外网

    首先检查添加DNS是否正常,如不存在则添加dns: [root@cgls]# vim /etc/resolv.conf nameserver 114.114.114.114 nameserver 8. ...

  4. 协方差及matlib绘制

    转自http://www.cnblogs.com/chaosimple/p/3182157.html 一.统计学的基本概念 统计学里最基本的概念就是样本的均值.方差.标准差.首先,我们给定一个含有n个 ...

  5. 数组模拟stack

    package com.cxy.springdataredis.data; import java.util.Scanner; public class StackDemo { public stat ...

  6. 科普帖:深度学习中GPU和显存分析

    知乎的一篇文章: https://zhuanlan.zhihu.com/p/31558973 关于如何使用nvidia-smi查看显存与GPU使用情况,参考如下链接: https://blog.csd ...

  7. mysql DOS中中文乱码 ERROR 1366 (HY000): Incorrect string value: '\xC4\xEA\xBC\xB6' for column 'xxx' at row 1

    问题:ERROR (HY000): Incorrect string value: 在DOS中插入或查询中文出现乱码 登入mysql,输入命令:show variables like '%char%' ...

  8. JavaScript 数组函数 map()

    JavaScript 数组函数 map() 学习心得 map()函数是一个数组函数: 它对数组每个原素进行操作,不对空数组进行操作: 不改变原本的数组,返回新数组: arr.map(function( ...

  9. 【笔记篇】斜率优化dp(五) USACO08MAR土地购(征)买(用)Land Acquisition

    好好的题目连个名字都不统一.. 看到这种最大最小的就先排个序嘛= =以x为第一关键字, y为第二关键字排序. 然后有一些\(x_i<=x_{i+1},且y_i<=y_{i+1}\)的土地就 ...

  10. CF 540D Bad Luck Island

    一看就是DP题(很水的一道紫题) 设\(dp[i][j][k]\)为留下\(i\)个\(r\)族的人,死去\(j\)个\(s\)族的人,死去\(k\)个\(p\)族的人的概率(跟其他的题解有点差别,但 ...