基于上一篇resnet网络结构进行实战。

再来贴一下resnet的基本结构方便与代码进行对比

resnet的自定义类如下:

import tensorflow as tf
from tensorflow import keras class BasicBlock(keras.layers.Layer): # filter_num指定通道数,stride指定步长
def __init__(self,filter_num,stride=1):
super(BasicBlock, self).__init__() # 注意padding=same并不总使得输入维度等于输出维度,而是对不同的步长有不同的策略,使得滑动更加完整
self.conv1 = keras.layers.Conv2D(filter_num,(3,3),strides=stride,padding='same')
self.bn1 = keras.layers.BatchNormalization()
self.relu = keras.layers.Activation('relu') self.conv2 = keras.layers.Conv2D(filter_num,(3,3),strides=1,padding='same')
self.bn2 = keras.layers.BatchNormalization() if stride!=1:
self.dowmsample = keras.Sequential()
self.dowmsample.add(keras.layers.Conv2D(filter_num,(1,1),strides=stride))
else:
self.dowmsample = lambda x:x def call(self, inputs, training=None): out = self.conv1(inputs)
out = self.bn1(out)
out = self.relu(out) out = self.conv2(out)
out = self.bn2(out) identity = self.dowmsample(inputs) output = keras.layers.add([out,identity])
output = tf.nn.relu(output) return output class ResNet(keras.Model): # resnet基本结构为[2,2,2,2],即分为四个部分,每个部分又分两个小部分
def __init__(self,layer_dims,num_classes=100):
super(ResNet,self).__init__() # 预处理层
self.stem = keras.Sequential([
keras.layers.Conv2D(64,(3,3),strides=(1,1)),
keras.layers.BatchNormalization(),
keras.layers.Activation('relu'),
keras.layers.MaxPool2D(pool_size=(2,2),strides=(1,1),padding='same')
]) self.layer1 = self.build_resblock(64,layer_dims[0])
self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) # 自适应输出,方便送入全连层进行分类
self.avgpool = keras.layers.GlobalAveragePooling2D()
self.fc = keras.layers.Dense(num_classes) def call(self, inputs, training=None):
x = self.stem(inputs) x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x) x = self.avgpool(x)
x = self.fc(x) return x def build_resblock(self,filter_num,blocks,stride=1):
res_blocks = keras.Sequential();
res_blocks.add(BasicBlock(filter_num,stride)) for _ in range(1,blocks):
res_blocks.add(BasicBlock(filter_num,1)) return res_blocks def resnet18():
return ResNet([2,2,2,2])

训练过程如下:

import tensorflow as tf
from tensorflow import keras
import os
from resnet import resnet18 os.environ['TF_CPP_MIN_LOG'] = '' def preprocess(x,y):
x = 2*tf.cast(x,dtype=tf.float32)/255.-1
y = tf.cast(y,dtype=tf.int32)
return x,y (x,y),(x_test,y_test) = keras.datasets.cifar100.load_data()
y = tf.squeeze(y,axis=1)
y_test = tf.squeeze(y_test,axis=1)
print(x.shape,y.shape,x_test.shape,y_test.shape) train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.shuffle(1000).map(preprocess).batch(64) test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = train_db.map(preprocess).batch(64) def main():
model = resnet18()
model.build(input_shape=(None,32,32,3))
optimizer = keras.optimizers.Adam(lr=1e-3)
model.summary() for epoch in range(50):
for step,(x,y) in enumerate(train_db):
with tf.GradientTape() as tape:
logits = model(x)
y_onehot = tf.one_hot(y,depth=10)
loss = tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
loss = tf.reduce_mean(loss) gradient = tape.gradient(loss,model.trainable_variables)
optimizer.apply_gradients(zip(gradient,model.trainable_variables)) if step % 100 == 0:
print(epoch,step,'loss:',float(loss)) total_num = 0
total_correct = 0
for x,y in test_db:
logits = model(x)
prob = tf.nn.softmax(logits,axis=1)
pred = tf.argmax(prob,axis=1)
pred = tf.cast(pred,dtype=tf.int32) correct = tf.cast(tf.equal(pred,y),dtype=tf.int32)
correct = tf.reduce_sum(correct) total_num += x.shape[0]
total_correct += correct
acc = total_correct/total_num print("acc:",acc) if __name__ == '__main__':
main()

打印网络结构和参数量如下:

Resnet——深度残差网络(二)的更多相关文章

  1. Resnet——深度残差网络(一)

    我们都知道随着神经网络深度的加深,训练过程中会很容易产生误差的积累,从而出现梯度爆炸和梯度消散的问题,这是由于随着网络层数的增多,在网络中反向传播的梯度会随着连乘变得不稳定(特别大或特别小),出现最多 ...

  2. 使用dlib中的深度残差网络(ResNet)实现实时人脸识别

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  3. 深度残差网络(DRN)ResNet网络原理

    一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...

  4. Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。

    如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...

  5. CNN卷积神经网络_深度残差网络 ResNet——解决神经网络过深反而引起误差增加的根本问题,Highway NetWork 则允许保留一定比例的原始输入 x。(这种思想在inception模型也有,例如卷积是concat并行,而不是串行)这样前面一层的信息,有一定比例可以不经过矩阵乘法和非线性变换,直接传输到下一层,仿佛一条信息高速公路,因此得名Highway Network

    from:https://blog.csdn.net/diamonjoy_zone/article/details/70904212 环境:Win8.1 TensorFlow1.0.1 软件:Anac ...

  6. 关于深度残差网络(Deep residual network, ResNet)

    题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...

  7. 深度残差网络(ResNet)

    引言 对于传统的深度学习网络应用来说,网络越深,所能学到的东西越多.当然收敛速度也就越慢,训练时间越长,然而深度到了一定程度之后就会发现越往深学习率越低的情况,甚至在一些场景下,网络层数越深反而降低了 ...

  8. 深度残差网络——ResNet学习笔记

    深度残差网络—ResNet总结 写于:2019.03.15—大连理工大学 论文名称:Deep Residual Learning for Image Recognition 作者:微软亚洲研究院的何凯 ...

  9. ResNet(深度残差网络)

    注:平原改为简单堆叠网络 一般x是恒等映射,当x与fx尺寸不同的时候,w作用就是将x变成和fx尺寸相同. 过程: 先用w将x进行恒等映射.扩维映射或者降维映射d得到wx.(没有参数,不需要优化器训练) ...

随机推荐

  1. css的字体单位

    在css中的字体单位主要以px.em.rem为主.其中px也就是像素,是一种字体长度,它的长度是相对于显示器的品目分辨率而言的.一般情况下在浏览器中默认字体的大小是16px.其中em是相对字体.em的 ...

  2. ios---CoreLocation框架实现定位功能

    CoreLocation框架实现定位功能(iOS8.0之后) // // ViewController.m // 定位 // // Created by admin on 2017/9/20. // ...

  3. 用set、map等存储自定义结构体时容器内部判别各元素是否相同的注意事项

    STL作为通用模板极大地方便了C++使用者的编程,因为它可以存储任意数据类型的元素 如果我们想用set与map来存储自定义结构体时,如下 struct pp { double xx; double y ...

  4. Git详解之安装

    前言 是时候动手尝试下 Git 了,不过得先安装好它.有许多种安装方式,主要分为两种,一种是通过编译源代码来安装:另一种是使用为特定平台预编译好的安装包. 从源代码安装 若是条件允许,从源代码安装有很 ...

  5. Shell常用语句及结构

    条件判断语句之if if 语句通过关系运算符判断表达式的真假来决定执行哪个分支:shell有三种if语句样式,如下: 语句1 if [ expression ] then Statement(s) t ...

  6. ROS之服务

    服务(service)是另一种在节点之间传递数据的方法,服务其实就是同步的跨进程函数调用,它能够让一个节点调用运行在另一个节点中的函数. 我们就像之前消息类型一样定义这个函数的输入/输出.服务端(提供 ...

  7. python day01学习

    1.python语言 # 89年 龟叔2.python的特点 # 优点 : 简明 简单 跨平台性好 # 缺点 : 慢 -执行速度相对其他语言慢 # 编程语言的分类:  # 编译型语言: c c++ j ...

  8. 尝试用 Python 写了个病毒传播模拟程序

    病毒扩散仿真程序,用 python 也可以. 概述 事情是这样的,B 站 UP 主 @ele 实验室,写了一个简单的疫情传播仿真程序,告诉大家在家待着的重要性,视频相信大家都看过了,并且 UP 主也放 ...

  9. lua学习之语句篇

    语句 赋值 修改一个变量或者修改 table 中的一个字段的值 多重赋值,lua 先对等号右边的所有元素进行求值,然后再赋值 值的个数小于变量的个数,那么多余的变量就置为 nil 初始化变量,应该为每 ...

  10. sql中常量和变量的引用

    String name =jtf.getText().trim(); String sql="select * from stu where stuname='  "+name+& ...