【chainer框架】【pytorch框架】
教程:
https://bennix.github.io/
https://bennix.github.io/blog/2017/12/14/chain_basic/
https://bennix.github.io/blog/2017/12/18/Chain_Tutorial1/
模块 | 功能 |
---|---|
datasets | 输入数据可以被格式化为这个类的模型输入。它涵盖了大部分输入数据结构的用例。 |
variable | 它是一个函数/连接/Chain的输出。 |
functions | 支持深度学习中广泛使用的功能的框架,例如 sigmoid, tanh, ReLU等 |
links | 支持深度学习中广泛使用的层的框架,例如全连接层,卷积层等 |
Chain | 连接和函数(层)连接起来形成一个“模型”。 |
optimizers | 指定用于调整模型参数的什么样的梯度下降方法,例如 SGD, AdaGrad, Adam. |
serializers | 保存/加载训练状态。例如 model, optimizer 等 |
iterators | 定义训练器使用的每个小批量数据。 |
training.updater | 定义Trainer中使用的每个前向、反向传播的参数更新过程。 |
training.Trainer | 管理训练器 |
下面是chainer模块的导入语句。
# Initial setup following
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
示例 minst:
MNIST数据集由70,000个尺寸为28×28(即784个像素)的灰度图像和相应的数字标签组成。数据集默认分为6万个训练图像和10,000个测试图像。我们可以通过datasets.get_mnist()
获得矢量化版本(即一组784维向量)。
train, test = datasets.get_mnist()
此代码自动下载MNIST数据集并将NumPy数组保存到 $(HOME)/.chainer
目录中。返回的训练集和测试集可以看作图像标签配对的列表(严格地说,它们是TupleDataset的实例)。
我们还必须定义如何迭代这些数据集。我们想要在数据集的每次扫描开始时对每个epoch的训练数据集进行重新洗牌。在这种情况下,我们可以使用iterators.SerialIterator
。
train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
另一方面,我们不必洗牌测试数据集。在这种情况下,我们可以通过shuffle = False来禁止混洗。当底层数据集支持快速切片时,它使迭代速度更快。
test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)
当所有的例子被访问时,我们停止迭代通过设定 repeat=False 。测试/验证数据集通常需要此选项;没有这个选项,迭代进入一个无限循环。
接下来,我们定义架构。我们使用一个简单的三层网络,每层100个单元。
class MLP(Chain):
def __init__(self, n_units, n_out):
super(MLP, self).__init__()
with self.init_scope():
# the size of the inputs to each layer will be inferred
self.l1 = L.Linear(None, n_units) # n_in -> n_units # 第一个参数为输入维数,第二个参数为输出维数。这里注意,第一个参数填none时,则输入维数将在第一次前向传播时确定
self.l2 = L.Linear(None, n_units) # n_units -> n_units
self.l3 = L.Linear(None, n_out) # n_units -> n_out
def __call__(self, x):
h1 = F.relu(self.l1(x))
h2 = F.relu(self.l2(h1))
y = self.l3(h2)
return y
该链接使用relu()作为激活函数。请注意,“l3”链接是最终的全连接层,其输出对应于十个数字的分数。
为了计算损失值或评估预测的准确性,我们在上面的MLP连接的基础上定义一个分类器连接:
class Classifier(Chain):
def __init__(self, predictor):
super(Classifier, self).__init__()
with self.init_scope():
self.predictor = predictor
def __call__(self, x, t):
y = self.predictor(x)
loss = F.softmax_cross_entropy(y, t)
accuracy = F.accuracy(y, t)
report({'loss': loss, 'accuracy': accuracy}, self)
return loss
这个分类器类计算准确性和损失,并返回损失值。参数对x和t对应于数据集中的每个示例(图像和标签的元组)。 softmax_cross_entropy()
计算给定预测和基准真实标签的损失值。 accuracy()
计算预测准确度。我们可以为分类器的一个实例设置任意的预测器连接。
report()
函数向训练器报告损失和准确度。收集训练统计信息的具体机制参见 Reporter
. 您也可以采用类似的方式收集其他类型的观测值,如激活统计。
请注意,类似上面的分类器的类被定义为chainer.links.Classifier
。因此,我们将使用此预定义的Classifier
连接而不是使用上面的示例。
model = L.Classifier(MLP(100, 10)) # the input size, 784, is inferred
optimizer = optimizers.SGD()
optimizer.setup(model)
现在我们可以建立一个训练器对象。
updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (20, 'epoch'), out='result')
第二个参数(20,’epoch’)表示训练的持续时间。我们可以使用epoch或迭代作为单位。在这种情况下,我们通过遍历训练集20次来训练多层感知器。
为了调用训练循环,我们只需调用run()方法。
这个方法执行整个训练序列。
上面的代码只是优化了参数。在大多数情况下,我们想看看培训的进展情况,我们可以在调用run方法之前使用扩展插入。
trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()
epoch main/accuracy validation/main/accuracy
[J total [..................................................] 0.83%
this epoch [########..........................................] 16.67%
100 iter, 0 epoch / 20 epochs
inf iters/sec. Estimated time to finish: 0:00:00.
[4A[J total [..................................................] 1.67%
this epoch [################..................................] 33.33%
200 iter, 0 epoch / 20 epochs
270.19 iters/sec. Estimated time to finish: 0:00:43.672168.
[4A[J total [#.................................................] 2.50%
this epoch [#########################.........................] 50.00%
300 iter, 0 epoch / 20 epochs
271.99 iters/sec. Estimated time to finish: 0:00:43.017048.
[4A[J total [#.................................................] 3.33%
this epoch [#################################.................] 66.67%
400 iter, 0 epoch / 20 epochs
【chainer框架】【pytorch框架】的更多相关文章
- 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题
在学习陈云的教程<深度学习框架PyTorch:入门与实践>的损失函数构建时代码如下: 可我运行如下代码: output = net(input) target = Variable(t.a ...
- PyTorch框架+Python 3面向对象编程学习笔记
一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...
- 手写数字识别 卷积神经网络 Pytorch框架实现
MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- Pyinstaller打包Pytorch框架所遇到的问题
目录 前言 基本流程 一.安装Pyinstaller 和 测试Hello World 二.打包整个项目,在本机上调试生成exe 三.在新电脑上测试 参考资料 前言 第一次尝试用Pyinstalle ...
- PyTorch框架起步
PyTorch框架基本处理操作 part1:pytorch简介与安装 CPU版本安装:pip install torch1.3.0+cpu torchvision0.4.1+cpu -f https: ...
- 介绍开源的.net通信框架NetworkComms框架 源码分析
原文网址: http://www.cnblogs.com/csdev Networkcomms 是一款C# 语言编写的TCP/UDP通信框架 作者是英国人 以前是收费的 售价249英镑 我曾经花了 ...
- web 框架的本质及自定义web框架 模板渲染jinja2 mvc 和 mtv框架 Django框架的下载安装 基于Django实现的一个简单示例
Django基础一之web框架的本质 本节目录 一 web框架的本质及自定义web框架 二 模板渲染JinJa2 三 MVC和MTV框架 四 Django的下载安装 五 基于Django实现的一个简单 ...
- 那些年读过的书《Java并发编程实战》和《Java并发编程的艺术》三、任务执行框架—Executor框架小结
<Java并发编程实战>和<Java并发编程的艺术> Executor框架小结 1.在线程中如何执行任务 (1)任务执行目标: 在正常负载情况下,服务器应用 ...
随机推荐
- mysql 创建连接是 Cannot create PoolableConnectionFactory (Unknown character set: 'utf8mb4')
Cannot create PoolableConnectionFactory (Unknown character set: 'utf8mb4') maven 依赖换版本 <dependenc ...
- Fast RDP Brute暴力破解3389口令
http://www.tuicool.com/articles/b67rQfr 下载地址:https://www.rekings.com/fast-rdp-brute-gui-v2-0/
- java strtus2 拦截器(Interceptors)
在strtus2 中有一个比较重要的东西就是拦截器(Interceptors) 拦截器可以做到在已有的业务中插入一块共通的,比如在一个业务中,直接插入一串登录功能,就不用去每个页面一个个去显示是否登录 ...
- 字符集导致乱码问题,gi安装问题
今天是2014-4-24,今天中午收到一个天津网友问的一个安装gi的问题,和一个网友问的字符集问题:在此整理一下 问题一: gi安装问题: 问题描写叙述: 在安装gi的时候提示:"INS-2 ...
- C#调用Oracle存储过程
C#调用Oracle存储过程的代码如下所示: using System; using System.Collections.Generic; using System.Collections.Obje ...
- 二分图匹配 + 最小点覆盖 - Vertex Cover
Vertex Cover Problem's Link Mean: 给你一个无向图,让你给图中的结点染色,使得:每条边的两个顶点至少有一个顶点被染色.求最少的染色顶点数. analyse: 裸的最小点 ...
- EasyUI Window和Layout
我们建立tabs内容. <div class="easyui-window" title="Layout Window" icon="icon- ...
- printf 字体颜色打印
为了给printf着色方便, 我们可以定义一些宏: view plain copy to clipboard print ? #define NONE "/033[m&qu ...
- 【BZOJ】1650: [Usaco2006 Dec]River Hopscotch 跳石子(二分+贪心)
http://www.lydsy.com/JudgeOnline/problem.php?id=1650 看到数据和最小最大时一眼就是二分... 但是仔细想想好像判断时不能贪心? 然后看题解还真是贪心 ...
- [转]Shell脚本中发送html邮件的方法
<span "="">作为运维人员,免不了要编写一些监控脚本,并将监控结果及时的发送出来.那么通过邮件发送是比较常用的一种通知方式了.通常的,如果需要发送的内 ...