各种优化器的比较

莫烦的对各种优化通俗理解的视频

 import torch

 import torch.utils.data as Data

 import torch.nn.functional as F

 from torch.autograd import Variable

 import matplotlib.pyplot as plt

 # 超参数

 LR = 0.01

 BATCH_SIZE = 

 EPOCH = 

 # 生成假数据

 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据

 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )

 # 0.2 * torch.rand(x.size())增加噪点

 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))

 # 输出数据图

 # plt.scatter(x.numpy(), y.numpy())

 # plt.show()

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)

 class Net(torch.nn.Module):

     # 初始化

     def __init__(self):

         super(Net, self).__init__()

         self.hidden = torch.nn.Linear(, )

         self.predict = torch.nn.Linear(, )

     # 前向传递

     def forward(self, x):

         x = F.relu(self.hidden(x))

         x = self.predict(x)

         return x

 net_SGD = Net()

 net_Momentum = Net()

 net_RMSProp = Net()

 net_Adam = Net()

 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)

 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)

 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)

 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

 loss_func = torch.nn.MSELoss()

 loss_his = [[], [], [], []]  # 记录损失

 for epoch in range(EPOCH):

     print(epoch)

     for step, (batch_x, batch_y) in enumerate(loader):

         b_x = Variable(batch_x)

         b_y = Variable(batch_y)

         for net, opt,l_his in zip(nets, optimizers, loss_his):

             output = net(b_x)  # get output for every net

             loss = loss_func(output, b_y)  # compute loss for every net

             opt.zero_grad()  # clear gradients for next train

             loss.backward()  # backpropagation, compute gradients

             opt.step()  # apply gradients

             l_his.append(loss.data.numpy())  # loss recoder

 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']

 for i, l_his in enumerate(loss_his):

     plt.plot(l_his, label=labels[i])

 plt.legend(loc='best')

 plt.xlabel('Steps')

 plt.ylabel('Loss')

 plt.ylim((, 0.2))

 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. [PyTorch 学习笔记] 4.3 优化器

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...

  5. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  6. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

随机推荐

  1. AtCoder ABC 126F XOR Matching

    题目链接:https://atcoder.jp/contests/abc126/tasks/abc126_f 题目大意 给定两个整数 M 和 K ,用小于 2M 的的所有自然数,每个两个,用这些数排成 ...

  2. IPv6 三个访问本地地址的小Tips

    最近发现家里宽带支持IPv6了,这里分享三个利用IPv6访问本地地址(内网地址)的方法. 通常来说,我们用localhost来代表本地地址127.0.0.1.其实在IPv6中有他自己的表示方法ip6- ...

  3. 『BASH』——Learn BashScript from Daniel Robbins——[001-002]

    ABSTRACT: Daniel Robbins is best known as the creator of Gentoo Linux and author of many IBM develop ...

  4. windows下装LINUX后,进不了系统

    在网上找了一款叫"DisckGenius"的软件,运行后选“硬盘”/“重建主引导记录(MBR)”,然后重启,就正常了. 还有系统盘最好是FAT32格式的.

  5. mybatis 3 批量插入返回主键 Parameter 'id' not found

    @Insert("<script>INSERT INTO scp_activity_gift (activity_id,type,gift_id,status,limit_num ...

  6. python+tushare获取A股所有股票代码和名称列表

    接口:stock_basic 描述:获取基础信息数据,包括股票代码.名称.上市日期.退市日期等 注:tushare模块下载和安装教程,请查阅我之前的文章 输入参数 名称      |      类型  ...

  7. 【2018ACM/ICPC网络赛】徐州赛区

    呃.自闭了自闭了.我才不会说我写D写到昏天黑地呢. I  Characters with Hash 题目链接:https://nanti.jisuanke.com/t/31461 题意:给你一个字符串 ...

  8. C9 vs 三星

    我还是更喜欢 C9, 可惜当年的牛B人物 LemonNation 不在了,C9 赢 三星 一局的机会都没有了. 伟大的 LemonNation ,软件工程学硕士, 2014年,LemonNation ...

  9. ES6 学习 -- 箭头函数(=>)

    (1).只有一个参数且只有一句表达式语句的,函数表达式的花括号可以不写let test = a => a; // 只有一个参数a,这里的表达式相当于 "return a" ( ...

  10. QQ空间删除日志

    按下F12,贴上如下代码 var delay = 2000; function del() { document.querySelector(".app_canvas_frame" ...