pytorch之 sava_reload_model
- import torch
- import matplotlib.pyplot as plt
- # torch.manual_seed(1) # reproducible
- # fake data
- x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
- y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
- # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
- # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
- def save():
- # save net1
- net1 = torch.nn.Sequential(
- torch.nn.Linear(1, 10),
- torch.nn.ReLU(),
- torch.nn.Linear(10, 1)
- )
- optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
- loss_func = torch.nn.MSELoss()
- for t in range(100):
- prediction = net1(x)
- loss = loss_func(prediction, y)
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- # plot result
- plt.figure(1, figsize=(10, 3))
- plt.subplot(131)
- plt.title('Net1')
- plt.scatter(x.data.numpy(), y.data.numpy())
- plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
- # 2 ways to save the net
- torch.save(net1, 'net.pkl') # save entire net
- torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters
- def restore_net():
- # restore entire net1 to net2
- net2 = torch.load('net.pkl')
- prediction = net2(x)
- # plot result
- plt.subplot(132)
- plt.title('Net2')
- plt.scatter(x.data.numpy(), y.data.numpy())
- plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
- def restore_params():
- # restore only the parameters in net1 to net3
- net3 = torch.nn.Sequential(
- torch.nn.Linear(1, 10),
- torch.nn.ReLU(),
- torch.nn.Linear(10, 1)
- )
- # copy net1's parameters into net3
- net3.load_state_dict(torch.load('net_params.pkl'))
- prediction = net3(x)
- # plot result
- plt.subplot(133)
- plt.title('Net3')
- plt.scatter(x.data.numpy(), y.data.numpy())
- plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
- plt.show()
- # save net1
- save()
- # restore entire net (may slow)
- restore_net()
- # restore only the net parameters
- restore_params()
pytorch之 sava_reload_model的更多相关文章
- Ubutnu16.04安装pytorch
1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...
- 解决运行pytorch程序多线程问题
当我使用pycharm运行 (https://github.com/Joyce94/cnn-text-classification-pytorch ) pytorch程序的时候,在Linux服务器 ...
- 基于pytorch实现word2vec
一.介绍 word2vec是Google于2013年推出的开源的获取词向量word2vec的工具包.它包括了一组用于word embedding的模型,这些模型通常都是用浅层(两层)神经网络训练词向量 ...
- 基于pytorch的CNN、LSTM神经网络模型调参小结
(Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...
- pytorch实现VAE
一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...
- PyTorch教程之Training a classifier
我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...
- PyTorch教程之Neural Networks
我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...
- PyTorch教程之Autograd
在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...
- Linux安装pytorch的具体过程以及其中出现问题的解决办法
1.安装Anaconda 安装步骤参考了官网的说明:https://docs.anaconda.com/anaconda/install/linux.html 具体步骤如下: 首先,在官网下载地址 h ...
随机推荐
- Scala 学习(3)之「类——基本概念1」
类 小提示:可以通过:paste进入 Scala 的多行模式,输入对应的代码块之后,按ctrl + D退出多行模式,然后再调用刚才输入的函数或者方法进行测试 //定义类,包含 field 以及方法 c ...
- 最大流入门题目 - poj 1273
Every time it rains on Farmer John's fields, a pond forms over Bessie's favorite clover patch. This ...
- dp - 求符合题意的序列的个数
The sequence of integers a1,a2,…,ak is called a good array if a1=k−1 and a1>0. For example, the s ...
- transient简介
当一个对象实现了Serilizable接口,这个对象就可以被序列化,java的这种序列化模式为开发者提供了很多便利,我们可以不必关系具体序列化的过程,只要这个类实现了Serilizable接口,这个的 ...
- 【javaScript】获取某年某月的的最后一天(即当月天数) 妙用
javaScript里 面的new Date("xxxx/xx/xx")这个日期的构造方法有一个妙处,当你传入的是"xxxx/xx/0"(0号)的话,得到的日期 ...
- python实现数据结构-队列
注:本文档主要是学习<Python核心编程(第二版)>时的练习题. 队列是一种"先进先出"的数据结构(FIFO),是一种操作受限的线性结构,先进队列的成员先出队列.示意 ...
- Ubuntu16手动安装OpenStack——nova篇。。转
前言: 本文转自https://www.voidking.com/dev-ubuntu16-manual-openstack-nova/ ,过程非常的详细,作者也说本实验最终失败,因为课程要求我们只要 ...
- 3分钟接入socket.io使用
WebSocket 简介 传统的客户端和服务器通信协议是HTTP:客户端发起请求,服务端进行响应,服务端从不主动勾搭客户端. 这种模式有个明显软肋,就是同步状态.而实际应用中有大量需要客户端和服务器实 ...
- Day4-Python3基础-装饰器、迭代器
今日内容: 1.高阶函数 2.嵌套函数 3.装饰器 4.生成器 5.迭代器 1.高阶函数 定义: a:把一个函数名当作实参传给函数 a:返回值包含函数名(不修改函数的调用方式) import time ...
- 画布 canvas 的相关内容
1.什么是canvas canvas也被叫做画布,是在JavaScript中完成网页图像制作的一个重要的途径,画布是一个矩形区域,在这个矩形区域中你可以利用好这里的每一个像素.同样在canvas中也有 ...