import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed() # fake data
x = torch.unsqueeze(torch.linspace(-,,),dim=)
y = x.pow() + 0.2 * torch.rand(x.size())
x, y = Variable(x,requires_grad=False), Variable(y,requires_grad=False) def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range():
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.figure(,figsize=(,))
plt.subplot()
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少) def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot()
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) plt.subplot()
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=)
# 将保存的参数复制到 net3
plt.show() save()
restore_net()
restore_params()

结果和莫烦的不一样,但是找不到问题的所在,,。。。

莫烦PyTorch学习笔记(五)——模型的存取的更多相关文章

  1. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  4. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  5. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

  6. 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )

    CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...

  7. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(一)——torch or numpy

    Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...

随机推荐

  1. 好文 | MySQL 索引B+树原理,以及建索引的几大原则

    Java技术栈 www.javastack.cn 优秀的Java技术公众号 来源:小宝鸽 blog.csdn.net/u013142781/article/details/51706790 MySQL ...

  2. mybatis中的动态SQL语句

    有时候,静态的SQL语句并不能满足应用程序的需求.我们可以根据一些条件,来动态地构建 SQL语句. 例如,在Web应用程序中,有可能有一些搜索界面,需要输入一个或多个选项,然后根据这些已选择的条件去执 ...

  3. 【csp】2017-12

    第一题:游戏 题目: 题意:啊,不多赘述.看的懂. 题解:sort一下直接暴力比较大小. 代码: #include<iostream> #include<cstdio> #in ...

  4. echarts数据变了不重新渲染,以及重新渲染了前后数据会重叠渲染的问题

    1.echarts数据变了但是视图不重新渲染 新建Chart.vue文件 <template>  <p :id="id" :style="style&q ...

  5. YARN框架与MapReduce1.0框架的对比分析

  6. HDFS HA

  7. Docker 尝试安装rabbitmq实践笔记

    docker pull rabbitmq 自定義的rabbitmq Dockerfile # base image FROM rabbitmq:3.7-management # running req ...

  8. thinkphp 字段定义

    通常每个模型类是操作某个数据表,在大多数情况下,系统会自动获取当前数据表的字段信息. 系统会在模型首次实例化的时候自动获取数据表的字段信息(而且只需要一次,以后会永久缓存字段信息,除非设置不缓存或者删 ...

  9. csps模拟93序列,二叉搜索树,走路题解

    题面: 模拟93考得并不理想,二维偏序没看出来,然而看出来了也不会打 序列: 对a,b数列求前缀和,那么题意转化为了满足$suma[i]>=suma[j]$且$sumb[i]>=sumb[ ...

  10. (转)Android中RelativeLayout各个属性的含义

    转:http://blog.csdn.net/softkexin/article/details/5933589 android:layout_above="@id/xxx"  - ...