1. import torch
  2. import matplotlib.pyplot as plt
  3.  
  4. # torch.manual_seed(1) # reproducible
  5.  
  6. # fake data
  7. x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
  8. y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
  9.  
  10. # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
  11. # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
  12.  
  13. def save():
  14. # save net1
  15. net1 = torch.nn.Sequential(
  16. torch.nn.Linear(1, 10),
  17. torch.nn.ReLU(),
  18. torch.nn.Linear(10, 1)
  19. )
  20. optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
  21. loss_func = torch.nn.MSELoss()
  22.  
  23. for t in range(100):
  24. prediction = net1(x)
  25. loss = loss_func(prediction, y)
  26. optimizer.zero_grad()
  27. loss.backward()
  28. optimizer.step()
  29.  
  30. # plot result
  31. plt.figure(1, figsize=(10, 3))
  32. plt.subplot(131)
  33. plt.title('Net1')
  34. plt.scatter(x.data.numpy(), y.data.numpy())
  35. plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  36.  
  37. # 2 ways to save the net
  38. torch.save(net1, 'net.pkl') # save entire net
  39. torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters
  40.  
  41. def restore_net():
  42. # restore entire net1 to net2
  43. net2 = torch.load('net.pkl')
  44. prediction = net2(x)
  45.  
  46. # plot result
  47. plt.subplot(132)
  48. plt.title('Net2')
  49. plt.scatter(x.data.numpy(), y.data.numpy())
  50. plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  51.  
  52. def restore_params():
  53. # restore only the parameters in net1 to net3
  54. net3 = torch.nn.Sequential(
  55. torch.nn.Linear(1, 10),
  56. torch.nn.ReLU(),
  57. torch.nn.Linear(10, 1)
  58. )
  59.  
  60. # copy net1's parameters into net3
  61. net3.load_state_dict(torch.load('net_params.pkl'))
  62. prediction = net3(x)
  63.  
  64. # plot result
  65. plt.subplot(133)
  66. plt.title('Net3')
  67. plt.scatter(x.data.numpy(), y.data.numpy())
  68. plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  69. plt.show()
  70.  
  71. # save net1
  72. save()
  73.  
  74. # restore entire net (may slow)
  75. restore_net()
  76.  
  77. # restore only the net parameters
  78. restore_params()

pytorch之 sava_reload_model的更多相关文章

  1. Ubutnu16.04安装pytorch

    1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...

  2. 解决运行pytorch程序多线程问题

    当我使用pycharm运行  (https://github.com/Joyce94/cnn-text-classification-pytorch )  pytorch程序的时候,在Linux服务器 ...

  3. 基于pytorch实现word2vec

    一.介绍 word2vec是Google于2013年推出的开源的获取词向量word2vec的工具包.它包括了一组用于word embedding的模型,这些模型通常都是用浅层(两层)神经网络训练词向量 ...

  4. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  5. pytorch实现VAE

    一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...

  6. PyTorch教程之Training a classifier

    我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...

  7. PyTorch教程之Neural Networks

    我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...

  8. PyTorch教程之Autograd

    在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...

  9. Linux安装pytorch的具体过程以及其中出现问题的解决办法

    1.安装Anaconda 安装步骤参考了官网的说明:https://docs.anaconda.com/anaconda/install/linux.html 具体步骤如下: 首先,在官网下载地址 h ...

随机推荐

  1. Scala 学习(3)之「类——基本概念1」

    类 小提示:可以通过:paste进入 Scala 的多行模式,输入对应的代码块之后,按ctrl + D退出多行模式,然后再调用刚才输入的函数或者方法进行测试 //定义类,包含 field 以及方法 c ...

  2. 最大流入门题目 - poj 1273

    Every time it rains on Farmer John's fields, a pond forms over Bessie's favorite clover patch. This ...

  3. dp - 求符合题意的序列的个数

    The sequence of integers a1,a2,…,ak is called a good array if a1=k−1 and a1>0. For example, the s ...

  4. transient简介

    当一个对象实现了Serilizable接口,这个对象就可以被序列化,java的这种序列化模式为开发者提供了很多便利,我们可以不必关系具体序列化的过程,只要这个类实现了Serilizable接口,这个的 ...

  5. 【javaScript】获取某年某月的的最后一天(即当月天数) 妙用

    javaScript里 面的new Date("xxxx/xx/xx")这个日期的构造方法有一个妙处,当你传入的是"xxxx/xx/0"(0号)的话,得到的日期 ...

  6. python实现数据结构-队列

    注:本文档主要是学习<Python核心编程(第二版)>时的练习题. 队列是一种"先进先出"的数据结构(FIFO),是一种操作受限的线性结构,先进队列的成员先出队列.示意 ...

  7. Ubuntu16手动安装OpenStack——nova篇。。转

    前言: 本文转自https://www.voidking.com/dev-ubuntu16-manual-openstack-nova/ ,过程非常的详细,作者也说本实验最终失败,因为课程要求我们只要 ...

  8. 3分钟接入socket.io使用

    WebSocket 简介 传统的客户端和服务器通信协议是HTTP:客户端发起请求,服务端进行响应,服务端从不主动勾搭客户端. 这种模式有个明显软肋,就是同步状态.而实际应用中有大量需要客户端和服务器实 ...

  9. Day4-Python3基础-装饰器、迭代器

    今日内容: 1.高阶函数 2.嵌套函数 3.装饰器 4.生成器 5.迭代器 1.高阶函数 定义: a:把一个函数名当作实参传给函数 a:返回值包含函数名(不修改函数的调用方式) import time ...

  10. 画布 canvas 的相关内容

    1.什么是canvas canvas也被叫做画布,是在JavaScript中完成网页图像制作的一个重要的途径,画布是一个矩形区域,在这个矩形区域中你可以利用好这里的每一个像素.同样在canvas中也有 ...