pytorch1.0网络保存、提取、加载

import torch
import torch.nn.functional as F # 包含激励函数
import matplotlib.pyplot as plt # 假数据
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) def save():
# save net1
# 建网络
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
# 训练
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() # plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 2 ways to save the net
torch.save(net1, 'net.pkl') # save entire net # 保存整个网络
torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters # 只保存网络中的参数 (速度快, 占内存少) # 提取网络
def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x) # plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 只提取网络参数
def restore_params():
# 新建 net3
# restore only the parameters in net1 to net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# 将保存的参数复制到 net3
# copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) # plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show() # 保存 net1 (1. 整个网络, 2. 只有参数)
# save net1
save()
# 提取整个网络
# restore entire net (may slow)
restore_net()
# 提取网络参数, 复制到新网络
# restore only the net parameters
restore_params()

pytorch1.0神经网络保存、提取、加载的更多相关文章

  1. tensorflow 之模型的保存与加载(三)

    前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...

  2. tensorflow 之模型的保存与加载(一)

    怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...

  3. tensorflow模型的保存与加载

    模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...

  4. MindSpore保存与加载模型

    技术背景 近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等 ...

  5. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  6. tensorflow 之模型的保存与加载(二)

    上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...

  7. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

  8. TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结

    写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...

  9. [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...

随机推荐

  1. 解释下Http请求头和常见响应状态码

    Accept-Charset:指出浏览器可以接受的字符编码.英文浏览器的默认值是ISO-8859-1.ccept:指浏览器或其他客户可以接爱的MIME文件格式.可以根据它判断并返回适当的文件格式. A ...

  2. for循环实战性能优化之使用Map集合优化

           笔者在<for循环实战性能优化>中提出了五种提升for循环性能的优化策略,这次我们在其中嵌套循环优化小循环驱动大循环的基础上,借助Map集合高效的查询性能来优化嵌套for循环 ...

  3. LVM卷

    sdb和sdc创建为LVM并且挂载到/benet/ 将sdd扩展到之前的lvm卷 新建2块1G的磁盘和1块2G的磁盘 将分区ID改为8e 创建PV阶段 pvcreate /dev/sdb1 /dev/ ...

  4. ansys meshing当中共享节点的操作

    原视频下载地址: https://yunpan.cn/cYKuN9mIHMUxa  访问密码 5858

  5. 如何快速关联/修改Git远程仓库地址

    如何快速关联/修改Git远程仓库地址?按照如下步骤即可快速实现关联/修改Git远程仓库地址: 删除本地仓库当前关联的无效远程地址,再为本地仓库添加新的远程仓库地址 git remote -v //查看 ...

  6. Percona,MariaDB,MySQL衍生版如何取舍

    缘起 自从甲骨文公司收购了MySQL后,有将MySQL闭源的潜在风险.而且Oracle对培养MySQL这个免费的儿子并不太用心,漏洞修补和版本升级的速度一段时间非常缓慢,所以业界对MySQL的未来普遍 ...

  7. ubuntu之路——day11.4 定位数据不匹配与人工合成数据

    1.人工检验train和dev/test之间的区别: 比如:汽车语音识别中的噪音.地名难以识别等等 2.使得你的训练集更靠近(相似于)dev/test,收集更多类似于dev的数据: 比如:dev中存在 ...

  8. mysql innodb与myisam存储文件的区别

    myisam: .frm: 存储表定义 .myd(MYData):存储数据 .MYI(MYindex):存储引擎 innodb: .frm:存储表定义 .idb:存储数据和索引,在同一个文件中

  9. 【转】Android检查手机是否被root

    目前来说Android平台并没有提供能够root检查的工具.但是我们可以通过两种方式来判断 手机里面是否有su文件 这个su文件是不是能够执行 但是这两种检查方式都存在缺点. 第一种存在误测和漏测的情 ...

  10. Phpstudy 无法启动mysql

    原因: 两个mysql版本冲突 本地已经有一个mysql服务(3306)默认开启,再装了phpstudy又会自带一个mysqlla服务(3306) phpstudy启动后会启动mysqlla  发现3 ...