基于上一篇resnet网络结构进行实战。

再来贴一下resnet的基本结构方便与代码进行对比

resnet的自定义类如下:

import tensorflow as tf
from tensorflow import keras class BasicBlock(keras.layers.Layer): # filter_num指定通道数,stride指定步长
def __init__(self,filter_num,stride=1):
super(BasicBlock, self).__init__() # 注意padding=same并不总使得输入维度等于输出维度,而是对不同的步长有不同的策略,使得滑动更加完整
self.conv1 = keras.layers.Conv2D(filter_num,(3,3),strides=stride,padding='same')
self.bn1 = keras.layers.BatchNormalization()
self.relu = keras.layers.Activation('relu') self.conv2 = keras.layers.Conv2D(filter_num,(3,3),strides=1,padding='same')
self.bn2 = keras.layers.BatchNormalization() if stride!=1:
self.dowmsample = keras.Sequential()
self.dowmsample.add(keras.layers.Conv2D(filter_num,(1,1),strides=stride))
else:
self.dowmsample = lambda x:x def call(self, inputs, training=None): out = self.conv1(inputs)
out = self.bn1(out)
out = self.relu(out) out = self.conv2(out)
out = self.bn2(out) identity = self.dowmsample(inputs) output = keras.layers.add([out,identity])
output = tf.nn.relu(output) return output class ResNet(keras.Model): # resnet基本结构为[2,2,2,2],即分为四个部分,每个部分又分两个小部分
def __init__(self,layer_dims,num_classes=100):
super(ResNet,self).__init__() # 预处理层
self.stem = keras.Sequential([
keras.layers.Conv2D(64,(3,3),strides=(1,1)),
keras.layers.BatchNormalization(),
keras.layers.Activation('relu'),
keras.layers.MaxPool2D(pool_size=(2,2),strides=(1,1),padding='same')
]) self.layer1 = self.build_resblock(64,layer_dims[0])
self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) # 自适应输出,方便送入全连层进行分类
self.avgpool = keras.layers.GlobalAveragePooling2D()
self.fc = keras.layers.Dense(num_classes) def call(self, inputs, training=None):
x = self.stem(inputs) x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x) x = self.avgpool(x)
x = self.fc(x) return x def build_resblock(self,filter_num,blocks,stride=1):
res_blocks = keras.Sequential();
res_blocks.add(BasicBlock(filter_num,stride)) for _ in range(1,blocks):
res_blocks.add(BasicBlock(filter_num,1)) return res_blocks def resnet18():
return ResNet([2,2,2,2])

训练过程如下:

import tensorflow as tf
from tensorflow import keras
import os
from resnet import resnet18 os.environ['TF_CPP_MIN_LOG'] = '' def preprocess(x,y):
x = 2*tf.cast(x,dtype=tf.float32)/255.-1
y = tf.cast(y,dtype=tf.int32)
return x,y (x,y),(x_test,y_test) = keras.datasets.cifar100.load_data()
y = tf.squeeze(y,axis=1)
y_test = tf.squeeze(y_test,axis=1)
print(x.shape,y.shape,x_test.shape,y_test.shape) train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.shuffle(1000).map(preprocess).batch(64) test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = train_db.map(preprocess).batch(64) def main():
model = resnet18()
model.build(input_shape=(None,32,32,3))
optimizer = keras.optimizers.Adam(lr=1e-3)
model.summary() for epoch in range(50):
for step,(x,y) in enumerate(train_db):
with tf.GradientTape() as tape:
logits = model(x)
y_onehot = tf.one_hot(y,depth=10)
loss = tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
loss = tf.reduce_mean(loss) gradient = tape.gradient(loss,model.trainable_variables)
optimizer.apply_gradients(zip(gradient,model.trainable_variables)) if step % 100 == 0:
print(epoch,step,'loss:',float(loss)) total_num = 0
total_correct = 0
for x,y in test_db:
logits = model(x)
prob = tf.nn.softmax(logits,axis=1)
pred = tf.argmax(prob,axis=1)
pred = tf.cast(pred,dtype=tf.int32) correct = tf.cast(tf.equal(pred,y),dtype=tf.int32)
correct = tf.reduce_sum(correct) total_num += x.shape[0]
total_correct += correct
acc = total_correct/total_num print("acc:",acc) if __name__ == '__main__':
main()

打印网络结构和参数量如下:

Resnet——深度残差网络(二)的更多相关文章

  1. Resnet——深度残差网络(一)

    我们都知道随着神经网络深度的加深,训练过程中会很容易产生误差的积累,从而出现梯度爆炸和梯度消散的问题,这是由于随着网络层数的增多,在网络中反向传播的梯度会随着连乘变得不稳定(特别大或特别小),出现最多 ...

  2. 使用dlib中的深度残差网络(ResNet)实现实时人脸识别

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  3. 深度残差网络(DRN)ResNet网络原理

    一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...

  4. Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。

    如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...

  5. CNN卷积神经网络_深度残差网络 ResNet——解决神经网络过深反而引起误差增加的根本问题,Highway NetWork 则允许保留一定比例的原始输入 x。(这种思想在inception模型也有,例如卷积是concat并行,而不是串行)这样前面一层的信息,有一定比例可以不经过矩阵乘法和非线性变换,直接传输到下一层,仿佛一条信息高速公路,因此得名Highway Network

    from:https://blog.csdn.net/diamonjoy_zone/article/details/70904212 环境:Win8.1 TensorFlow1.0.1 软件:Anac ...

  6. 关于深度残差网络(Deep residual network, ResNet)

    题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...

  7. 深度残差网络(ResNet)

    引言 对于传统的深度学习网络应用来说,网络越深,所能学到的东西越多.当然收敛速度也就越慢,训练时间越长,然而深度到了一定程度之后就会发现越往深学习率越低的情况,甚至在一些场景下,网络层数越深反而降低了 ...

  8. 深度残差网络——ResNet学习笔记

    深度残差网络—ResNet总结 写于:2019.03.15—大连理工大学 论文名称:Deep Residual Learning for Image Recognition 作者:微软亚洲研究院的何凯 ...

  9. ResNet(深度残差网络)

    注:平原改为简单堆叠网络 一般x是恒等映射,当x与fx尺寸不同的时候,w作用就是将x变成和fx尺寸相同. 过程: 先用w将x进行恒等映射.扩维映射或者降维映射d得到wx.(没有参数,不需要优化器训练) ...

随机推荐

  1. FileUpload实现图片的无重上传

    //如果确认了上传文件,则判断文件类型是否符合要求        if (this.FileUpload1.HasFile)        {            //获取上传文件的后缀       ...

  2. chrome 安装

    Centos7 yum安装chrome浏览器   跟着这个教程安装的:Centos7安装chrome浏览器 (点击) 1. 配置yum源 在目录 /etc/yum.repos.d/ 下新建文件 goo ...

  3. linux下误清用户/home下的文件怎么办?

    2016-08-19 10:38:10   有时候我们不小心把home目录下的用户目录删除了,出现上图情况,每行开头直接变成-bash-3.2$这种形式而不是[lyp@centos7 ~]$这种,这时 ...

  4. ios--->OC中Protocol理解及在代理模式中的使用

    OC中Protocol理解及在代理模式中的使用 Protocol基本概念 Protocol翻译过来, 叫做"协议",其作用就是用来声明一些方法: Protocol(协议)的作用 定 ...

  5. Docker基础内容之容器

    前言 容器是独立运行的一个或一组应用以及它们的运行态环境. 相关命令 启动容器相关命令 docker run 运行一个ubuntu14.04版本的容器,如果这个镜像本地不存在则会去默认仓库中下载 do ...

  6. Unicode标准以及其常见的编码方案

    目录 基本概念 码位 码位的类型 编码方案 UTF-32 UTF-16 UTF-8 参考资料 Unicode标准为每一个字符提供一个唯一的数字,而不用区分平台.语言等因素. The Unicode S ...

  7. Linux之时间同步操作

    Linux之时间同步操作 时间同步操作应用的命令 yum进行软件安装,软件安装过程中如遇到询问,一律选择y,ntp是时间同步命令 [root@localhost ~]# yum -y install ...

  8. 《windows程序设计》第三章学习心得

    第三章是基于对一个windows窗口的学习,来达到对windows程序运行机制的理解. 从语言的角度看消息机制,Windows给程序发消息的本质就是调用"窗口过程"函数. Don' ...

  9. 使用卷影拷贝提取ntds.dit

    一.简介 通常情况下,即使拥有管理员权限,也无法读取域控制器中的C:\Windows\NTDS\ntds.dit文件.使用windows本地卷影拷贝服务,就可以获得该文件的副本. 在活动目录中,所有的 ...

  10. 浅谈二分—— by hyl天梦

    二分 解决范围 二分法可以用来解决这一系列具有单调性质的题,例如求单调函数的零点 其实在小学奥数中就用到了二分法 例如手动开根号,再比如猜数游戏 二分的具体过程就是先取一个中间值,判定一下正确答案在哪 ...