1 前言

理论上,网络层数越深,拟合效果越好。但是,层数加深也会导致梯度消失或梯度爆炸现象产生。当网络层数已经过深时,深层网络表现为“恒等映射”。实践表明,神经网络对残差的学习比对恒等关系的学习表现更好,因此,残差网络在深层模型中广泛应用。

本文以MNIST手写数字分类为例,为方便读者认识残差网络,网络中只有全连接层,没有卷积层。关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类

笔者工作空间如下:

代码资源见-->残差网络(ResNet)案例分析

2 实验

renet.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import add,Input,Dense,Activation #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #残差块
def ResBlock(x,hidden_size1,hidden_size2):
r=Dense(hidden_size1,activation='relu')(x) #第一隐层
r=Dense(hidden_size2)(r) #第二隐层
if x.shape[1]==hidden_size2:
shortcut=x
else:
shortcut=Dense(hidden_size2)(x) #shortcut(捷径)
o=add([r,shortcut])
o=Activation('relu')(o) #激活函数
return o #残差网络
def ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y):
inputs=Input(shape=(784,))
x=ResBlock(inputs,30,30)
x=ResBlock(x,30,30)
x=ResBlock(x,20,20)
x=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=x)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=50,verbose=2,validation_data=(valid_x,valid_y))
#评估模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
print('test_loss:',pre[0],'- test_acc:',pre[1]) train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 784) 0
__________________________________________________________________________________________________
dense_1 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 30) 930 dense_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 30) 0 dense_2[0][0]
dense_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 30) 0 add_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 30) 930 activation_1[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 30) 930 dense_4[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 30) 0 dense_5[0][0]
activation_1[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 30) 0 add_2[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 20) 420 dense_6[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 20) 0 dense_7[0][0]
dense_8[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 20) 0 add_3[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 10) 210 activation_3[0][0]
==================================================================================================
Total params: 51,760
Trainable params: 51,760
Non-trainable params: 0

网络训练结果:

Epoch 48/50
- 1s - loss: 0.0019 - acc: 0.9999 - val_loss: 0.1463 - val_acc: 0.9706
Epoch 49/50
- 1s - loss: 0.0016 - acc: 0.9999 - val_loss: 0.1502 - val_acc: 0.9722
Epoch 50/50
- 1s - loss: 0.0013 - acc: 0.9999 - val_loss: 0.1542 - val_acc: 0.9728
test_loss: 0.16228994959965348 - test_acc: 0.9721000045537949

​ 声明:本文转自基于keras的残差网络

基于keras的残差网络的更多相关文章

  1. 基于 Keras 用 LSTM 网络做时间序列预测

    目录 基于 Keras 用 LSTM 网络做时间序列预测 问题描述 长短记忆网络 LSTM 网络回归 LSTM 网络回归结合窗口法 基于时间步的 LSTM 网络回归 在批量训练之间保持 LSTM 的记 ...

  2. BERT实战——基于Keras

    1.keras_bert 和 kert4keras keras_bert 是 CyberZHG 大佬封装好了Keras版的Bert,可以直接调用官方发布的预训练权重. github:https://g ...

  3. 使用dlib中的深度残差网络(ResNet)实现实时人脸识别

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  4. [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88)

    [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88) 个人主页--> https://xiaosongshine.github.io/ 项目g ...

  5. 残差网络ResNet笔记

    发现博客园也可以支持Markdown,就把我之前写的博客搬过来了- 欢迎转载,请注明出处:http://www.cnblogs.com/alanma/p/6877166.html 下面是正文: Dee ...

  6. 深度残差网络(DRN)ResNet网络原理

    一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...

  7. Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。

    如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...

  8. 关于深度残差网络(Deep residual network, ResNet)

    题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...

  9. 深度残差网络——ResNet学习笔记

    深度残差网络—ResNet总结 写于:2019.03.15—大连理工大学 论文名称:Deep Residual Learning for Image Recognition 作者:微软亚洲研究院的何凯 ...

  10. Resnet——深度残差网络(一)

    我们都知道随着神经网络深度的加深,训练过程中会很容易产生误差的积累,从而出现梯度爆炸和梯度消散的问题,这是由于随着网络层数的增多,在网络中反向传播的梯度会随着连乘变得不稳定(特别大或特别小),出现最多 ...

随机推荐

  1. [转帖]time_zone 是怎么打爆你的MySQL的

    https://plantegg.github.io/2023/10/03/time_zone%E6%98%AF%E6%80%8E%E4%B9%88%E6%89%93%E7%88%86%E4%BD%A ...

  2. [转帖]linux中批量多行缩进与添加空格

    用vim打开修改python脚本的时候,将代码整体向后移动4个空格操作如下: ESC之后,ctrl+v进入多行行首选中模式 使用上下键进行上下移动,选中多行行首 shift+i,进入插入模式 连续敲击 ...

  3. ext4 磁盘扩容

    目录 ext4文件系统磁盘扩容 目标 途径 操作步骤 改变前的现状 操作和改变后的状态 ext4文件系统磁盘扩容 一个磁盘有多个分区,分别创建了物理卷.卷组.逻辑卷.通过虚拟机软件对虚拟机的磁盘/de ...

  4. [转帖]GC日志分析工具——GCViewer案例

    原创 石页粑粑 来自zxsk的码农 2020-09-28 06:18 一.GCViewer介绍 业界较为流行分析GC日志的两个工具--GCViewer.GCEasy.GCEasy部分功能还是要收费的, ...

  5. [转帖]JVM NativeMemoryTracking ;jcmd process_id VM.native_memory;Native memory tracking is not enabled

    目录 一.Native Memory Tracking (NMT) 是Hotspot VM用来分析VM内部内存使用情况的一个功能.我们可以利用jcmd(jdk自带)这个工具来访问NMT的数据. 1.N ...

  6. [转帖]技术派-汇编语言指令集(intel X86系列)

    针对汇编语言指令集(intel X86系列8086/80186/80286/80386/80486) AAA - Ascii Adjust for Addition        AAD - Asci ...

  7. 西门子PLC高校作业以及创新项目

    抢答器 在主持人按下启动按钮,3秒内

  8. STM32CubeMX教程27 SDIO - 读写SD卡

    1.准备材料 正点原子stm32f407探索者开发板V2.4 STM32CubeMX软件(Version 6.10.0) keil µVision5 IDE(MDK-Arm) ST-LINK/V2驱动 ...

  9. MetaGPT( The Multi-Agent Framework):颠覆AI开发的革命性多智能体元编程框架

    "MetaGPT( The Multi-Agent Framework):颠覆AI开发的革命性多智能体元编程框架" 一个多智能体元编程框架,给定一行需求,它可以返回产品文档.架构设 ...

  10. NLP涉及技术原理和应用简单讲解【一】:paddle(梯度裁剪、ONNX协议、动态图转静态图、推理部署)

    参考链接: https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/advanced/gradient_clip_cn.html 1. ...