import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed() # fake data
x = torch.unsqueeze(torch.linspace(-,,),dim=)
y = x.pow() + 0.2 * torch.rand(x.size())
x, y = Variable(x,requires_grad=False), Variable(y,requires_grad=False) def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range():
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.figure(,figsize=(,))
plt.subplot()
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少) def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot()
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) plt.subplot()
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=)
# 将保存的参数复制到 net3
plt.show() save()
restore_net()
restore_params()

结果和莫烦的不一样,但是找不到问题的所在,,。。。

莫烦PyTorch学习笔记(五)——模型的存取的更多相关文章

  1. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  4. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  5. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

  6. 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )

    CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...

  7. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(一)——torch or numpy

    Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...

随机推荐

  1. Python3 From Zero——{最初的意识:005~文件和I/O}

    一.输出重定向到文件 >>> with open('/home/f/py_script/passwd', 'rt+') as f1: ... print('Hello Dog!', ...

  2. USACO2012 Haybale stacking /// 区间表示法 oj21556

    题目大意:N个方块 标号1~N  K个操作 操作a b 表示标号a~b区间每位多加一个方块 Input * Line 1: Two space-separated integers, N  K. * ...

  3. 请求参数MD5加密---函数助手

  4. c# 编写windows 服务,并制作安装包

    对服务的认识有很多个阶段. 第一阶段:当时还在用c++,知道在一个进程里while(True){},然后里面做很多很多事情,这就叫做服务了,界面可能当时还用Console控制台程序. 第二阶段:知道了 ...

  5. JDBC_数据库连接池DRUID

    /** * @Description: TODO(这里用一句话描述这个类的作用) * @Author aikang * @Date 2019/8/26 20:12 */ /* 1.数据库连接池: 1. ...

  6. go 简介与包

    简介 Go语言是一种新的语言,一种并发的.带垃圾回收的.快速编译的语言.它具有以下特点: 1.它可以在一台计算机上用几秒钟的时间编译一个大型的Go程序. 2.Go语言为软件构造提供了一种模型,它使依赖 ...

  7. 搜索框下面显示提示数据(数据是ajax读取)

    1.前台页面 <div style="margin: 0 auto"> <input type="text" id="wenxian ...

  8. Yii2 使用 npm 安装的包

    转载自: yii2.0.15 使用 npm 替换 bower,加速 composer 安装速度 [ 2.0 版本 ] 修改 ommon/config/main.php <?php return ...

  9. ReentrantLock中的公平锁与非公平锁

    简介 ReentrantLock是一种可重入锁,可以等同于synchronized的使用,但是比synchronized更加的强大.灵活. 一个可重入的排他锁,它具有与使用 synchronized ...

  10. python 之单例模式

    单例模式1 单例=>只有一个单例2 静态方法+静态字段3 所有实例中等转的内容相同时 用单例模式class Sqllite: __instance=None def __init__(self) ...