import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed() # fake data
x = torch.unsqueeze(torch.linspace(-,,),dim=)
y = x.pow() + 0.2 * torch.rand(x.size())
x, y = Variable(x,requires_grad=False), Variable(y,requires_grad=False) def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range():
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.figure(,figsize=(,))
plt.subplot()
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少) def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot()
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) plt.subplot()
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=)
# 将保存的参数复制到 net3
plt.show() save()
restore_net()
restore_params()

结果和莫烦的不一样,但是找不到问题的所在,,。。。

莫烦PyTorch学习笔记(五)——模型的存取的更多相关文章

  1. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  4. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  5. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

  6. 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )

    CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...

  7. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(一)——torch or numpy

    Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...

随机推荐

  1. python爬取文件时,内容为空

    解决方式: img_res = requests.get(src,headers=header)在header中加上referer防盗链加上防盗链header的例子: header = {" ...

  2. 【JUC】JDK1.8源码分析之ConcurrentHashMap

    一.前言 最近几天忙着做点别的东西,今天终于有时间分析源码了,看源码感觉很爽,并且发现ConcurrentHashMap在JDK1.8版本与之前的版本在并发控制上存在很大的差别,很有必要进行认真的分析 ...

  3. jq页面换肤效果

    <!DOCTYPE html> <html lang="en"> <head> <script src="http://code ...

  4. Markdown 语法大全

    1 强调 星号与下划线都可以,单是斜体,双是粗体,符号可跨行,符号可加空格 **一个人来到田纳西** __毫无疑问__ *我做的馅饼 是全天下* _最好吃的_ 效果: 一个人来到田纳西 毫无疑问 我做 ...

  5. Selenium(二)---无界面模式+滑动底部

    一.使用无界面模式 1.正常情况启动 selenium 是有界面的 2.有些情况下,需要不显示界面,这时只要设置一下参数就可以实现了 # 不想显示界面可以用 Chrome——配置一下参数就好 from ...

  6. vue3+node全栈项目部署到云服务器

    一.前言 最近在B站学习了一下全栈开发,使用到的技术栈是Vue+Element+Express+MongoDB,为了让自己学的第一个全栈项目落地,于是想着把该项目部署到阿里云服务器.经过网上一番搜索和 ...

  7. Hive中SQL查询转换成MapReduce作业的过程

  8. Binary XML file line #23: Error inflating class android.widget.TextView

    分析一波,报错23行TextView的问题,但是检查了xml没有发现23行又TextView相关代码,就不应该继续纠结xml了,代码是通过R文件拿到xml资源的,你就应该怀疑是R文件的问题,R文件编译 ...

  9. 解决vagrant上使用Homestead很慢(响应速度10s+)

    说明: 使用vagrant和Homestead 在vBox上面跑laravel, 响应速度非常缓慢(大概在10+s), 尝试过增加虚拟机配置, 但是没有任何效果, 经验证也不是数据库的原因 . 通过网 ...

  10. 应用上云新模式,Aliware 全家桶亮相杭州云栖大会

    全面上云带来的变化,不仅是上云企业数量上的攀升,也是企业对云的使用方式的转变,越来越多的企业用户不仅将云作为一种弹性资源,更是开始在云上部署架构和应用,借助 Serverless 等技术,开发人员只需 ...