基于上一篇resnet网络结构进行实战。

再来贴一下resnet的基本结构方便与代码进行对比

resnet的自定义类如下:

import tensorflow as tf
from tensorflow import keras class BasicBlock(keras.layers.Layer): # filter_num指定通道数,stride指定步长
def __init__(self,filter_num,stride=1):
super(BasicBlock, self).__init__() # 注意padding=same并不总使得输入维度等于输出维度,而是对不同的步长有不同的策略,使得滑动更加完整
self.conv1 = keras.layers.Conv2D(filter_num,(3,3),strides=stride,padding='same')
self.bn1 = keras.layers.BatchNormalization()
self.relu = keras.layers.Activation('relu') self.conv2 = keras.layers.Conv2D(filter_num,(3,3),strides=1,padding='same')
self.bn2 = keras.layers.BatchNormalization() if stride!=1:
self.dowmsample = keras.Sequential()
self.dowmsample.add(keras.layers.Conv2D(filter_num,(1,1),strides=stride))
else:
self.dowmsample = lambda x:x def call(self, inputs, training=None): out = self.conv1(inputs)
out = self.bn1(out)
out = self.relu(out) out = self.conv2(out)
out = self.bn2(out) identity = self.dowmsample(inputs) output = keras.layers.add([out,identity])
output = tf.nn.relu(output) return output class ResNet(keras.Model): # resnet基本结构为[2,2,2,2],即分为四个部分,每个部分又分两个小部分
def __init__(self,layer_dims,num_classes=100):
super(ResNet,self).__init__() # 预处理层
self.stem = keras.Sequential([
keras.layers.Conv2D(64,(3,3),strides=(1,1)),
keras.layers.BatchNormalization(),
keras.layers.Activation('relu'),
keras.layers.MaxPool2D(pool_size=(2,2),strides=(1,1),padding='same')
]) self.layer1 = self.build_resblock(64,layer_dims[0])
self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) # 自适应输出,方便送入全连层进行分类
self.avgpool = keras.layers.GlobalAveragePooling2D()
self.fc = keras.layers.Dense(num_classes) def call(self, inputs, training=None):
x = self.stem(inputs) x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x) x = self.avgpool(x)
x = self.fc(x) return x def build_resblock(self,filter_num,blocks,stride=1):
res_blocks = keras.Sequential();
res_blocks.add(BasicBlock(filter_num,stride)) for _ in range(1,blocks):
res_blocks.add(BasicBlock(filter_num,1)) return res_blocks def resnet18():
return ResNet([2,2,2,2])

训练过程如下:

import tensorflow as tf
from tensorflow import keras
import os
from resnet import resnet18 os.environ['TF_CPP_MIN_LOG'] = '' def preprocess(x,y):
x = 2*tf.cast(x,dtype=tf.float32)/255.-1
y = tf.cast(y,dtype=tf.int32)
return x,y (x,y),(x_test,y_test) = keras.datasets.cifar100.load_data()
y = tf.squeeze(y,axis=1)
y_test = tf.squeeze(y_test,axis=1)
print(x.shape,y.shape,x_test.shape,y_test.shape) train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.shuffle(1000).map(preprocess).batch(64) test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = train_db.map(preprocess).batch(64) def main():
model = resnet18()
model.build(input_shape=(None,32,32,3))
optimizer = keras.optimizers.Adam(lr=1e-3)
model.summary() for epoch in range(50):
for step,(x,y) in enumerate(train_db):
with tf.GradientTape() as tape:
logits = model(x)
y_onehot = tf.one_hot(y,depth=10)
loss = tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
loss = tf.reduce_mean(loss) gradient = tape.gradient(loss,model.trainable_variables)
optimizer.apply_gradients(zip(gradient,model.trainable_variables)) if step % 100 == 0:
print(epoch,step,'loss:',float(loss)) total_num = 0
total_correct = 0
for x,y in test_db:
logits = model(x)
prob = tf.nn.softmax(logits,axis=1)
pred = tf.argmax(prob,axis=1)
pred = tf.cast(pred,dtype=tf.int32) correct = tf.cast(tf.equal(pred,y),dtype=tf.int32)
correct = tf.reduce_sum(correct) total_num += x.shape[0]
total_correct += correct
acc = total_correct/total_num print("acc:",acc) if __name__ == '__main__':
main()

打印网络结构和参数量如下:

Resnet——深度残差网络(二)的更多相关文章

  1. Resnet——深度残差网络(一)

    我们都知道随着神经网络深度的加深,训练过程中会很容易产生误差的积累,从而出现梯度爆炸和梯度消散的问题,这是由于随着网络层数的增多,在网络中反向传播的梯度会随着连乘变得不稳定(特别大或特别小),出现最多 ...

  2. 使用dlib中的深度残差网络(ResNet)实现实时人脸识别

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  3. 深度残差网络(DRN)ResNet网络原理

    一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...

  4. Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。

    如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...

  5. CNN卷积神经网络_深度残差网络 ResNet——解决神经网络过深反而引起误差增加的根本问题,Highway NetWork 则允许保留一定比例的原始输入 x。(这种思想在inception模型也有,例如卷积是concat并行,而不是串行)这样前面一层的信息,有一定比例可以不经过矩阵乘法和非线性变换,直接传输到下一层,仿佛一条信息高速公路,因此得名Highway Network

    from:https://blog.csdn.net/diamonjoy_zone/article/details/70904212 环境:Win8.1 TensorFlow1.0.1 软件:Anac ...

  6. 关于深度残差网络(Deep residual network, ResNet)

    题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...

  7. 深度残差网络(ResNet)

    引言 对于传统的深度学习网络应用来说,网络越深,所能学到的东西越多.当然收敛速度也就越慢,训练时间越长,然而深度到了一定程度之后就会发现越往深学习率越低的情况,甚至在一些场景下,网络层数越深反而降低了 ...

  8. 深度残差网络——ResNet学习笔记

    深度残差网络—ResNet总结 写于:2019.03.15—大连理工大学 论文名称:Deep Residual Learning for Image Recognition 作者:微软亚洲研究院的何凯 ...

  9. ResNet(深度残差网络)

    注:平原改为简单堆叠网络 一般x是恒等映射,当x与fx尺寸不同的时候,w作用就是将x变成和fx尺寸相同. 过程: 先用w将x进行恒等映射.扩维映射或者降维映射d得到wx.(没有参数,不需要优化器训练) ...

随机推荐

  1. 基于django的会议室预订系统

    会议室预订系统 一.目标及业务流程 期望效果: 业务流程: 用户注册 用户登录 预订会议室 退订会议室 选择日期:今日以及以后日期 二.表结构设计和生成 1.models.py(用户继承Abstrac ...

  2. 笔记常用Linux命令(三) 查看服务器日志

    服务器日志 用于记录服务器的运行情况 查看服务器日志 tail:查看后面几行 n 显示行数 f 持续侦测后面的内容,查看服务器日志常用 查看最新的服务日志(静态) 命令格式:tail -n 行数 日志 ...

  3. Linux守护进程之systemd

    介绍 历史上,Linux 的启动一直采用init进程:下面的命令用来启动服务. $ sudo /etc/init.d/apache2 start # 或者 $ service apache2 star ...

  4. Docker基础内容之仓库

    前言 Docker提供了开放的中央仓库dockerhub,同时也允许我们使用registry搭建本地私有仓库.搭建私有仓库有如下的优点: 节省网络带宽,提升Docker部署速度,不用每个镜像从Dock ...

  5. JDK源码之Double类&Float类分析

    一 概述 Double 类是基本类型double的包装类,fainl修饰,在对象中包装了一个基本类型double的值.Double继承了Number抽象类,具有了转化为基本double类型的功能. 此 ...

  6. 从App.config中读取数据库连接字符串

    1.首先在App.config文件中添加如下代码注意<connectionStrings>插入位置. <connectionStrings> <add name=&quo ...

  7. centos7下oracle11g详细的安装与建表操作

    一.oracle的安装,在官网下载oracle11g R2 1.在桌面单击右键,选择“在终端中打开”,进入终端 输入命令:su 输入ROOT密码: 创建用户组oinstall:groupadd oin ...

  8. mysql 记录一次内存清理

    摘自:https://blog.csdn.net/wyzxg/article/details/7279986/ 摘要:Linux对内存的管理与Windows不同,free小并不是说内存不够用了,应该看 ...

  9. 使用visual studio 2013读取.mat文件

    现在有一个T.mat 文件需要在c++中处理然后以.mat 或是.txt形式返回 T.mat中存储了十个cell,每个cell中会有一个不等长的数组 1.以下是相关配置过程: 参考:http://we ...

  10. uml-类图书写指南

    说明 类图是最常用的UML图,面向对象建模的主要组成部分.它用于描述系统的结构化设计,显示出类.接口以及它们之间的静态结构和关系. 类图主要产出于面向对象设计的分析和设计阶段,用来描述系统的静态结构. ...