tensorflow2.0建议使用tf.keras作为构建神经网络的高级API

接下来我就使用tensorflow实现VGG16去训练数据

背景介绍:

2012年 AlexNet 在 ImageNet 上显著的降低了分类错误率,深度神经网络进入迅速发展阶段。在2014年牛津大学机器人实验室尝试构建了更深的网络,文章中称为"VERY DEEP CONVOLUTIONAL NETWORKS",如VGG16,有16层,虽然现在看起来稀疏平常,但与 AlexNet 相比,翻了几倍。这个阶段,主要是没有解决网络太深梯度反向传播消失的问题,且受限于GPU等硬件设备的性能,所以深度网络不易于训练。不过,VGG 显然是当时最好的图像分类模型,斩获 ILSVRC 比赛冠军。顺便说下,2012年之后,标准数据集主要是ImageNet,到后来又有微软的COCO数据集。



上图为VGG16的网络结构,我们可以从输入层开始数,数到最终的输出层,正好是16层,有兴趣的可以数一数,接下来我们就来实现这样的一个网络

实现:

  1. Stage 1

    卷积层1:conv1

    卷积层2:conv2

    池化层1:pool1

    激活函数1:relu1
		self.conv1 = layers.Conv2D(64,3,1,'same')
self.conv2 = layers.Conv2D(64,3,1,'same')
self.pool1 = layers.MaxPool2D(2,1,'same')
self.relu1 = layers.ReLU()
  1. Stage 2

    卷积层3:conv3

    卷积层4:conv4

    池化层2:pool2

    激活函数2:relu2
		self.conv3 = layers.Conv2D(128, 3, 1, 'same')
self.conv4 = layers.Conv2D(128, 3, 1, 'same')
self.pool2 = layers.MaxPool2D(2, 1, 'same')
self.relu2 = layers.ReLU()
  1. Stage 3

    卷积层5:conv5

    卷积层6:conv6

    卷积层7:conv7

    池化层3:pool3

    激活函数3:relu3

self.conv5 = layers.Conv2D(256,3,1,'same')
self.conv6 = layers.Conv2D(256,3,1,'same')
self.conv7 = layers.Conv2D(256,3,1,'same')
self.pool3 = layers.MaxPool2D(2,1,'same')
self.relu3 = layers.ReLU()
  1. Stage 4

    卷积层8:conv8

    卷积层9:conv9

    卷积层10:conv10

    池化层4:pool4

    激活函数4:relu4
		self.conv8 = layers.Conv2D(512,3,1,'same')
self.conv9 = layers.Conv2D(512,3,1,'same')
self.conv10 = layers.Conv2D(512,3,1,'same')
self.pool4=layers.MaxPool2D(2,1,'same')
self.relu4=layers.ReLU()
  1. Statage 5

    卷积层11:conv11

    卷积层12:conv12

    卷积层13:conv13

    池化层5:pool5

    激活函数5:relu5
		self.conv8 = layers.Conv2D(512,3,1,'same')
self.conv9 = layers.Conv2D(512,3,1,'same')
self.conv10 = layers.Conv2D(512,3,1,'same')
self.pool4=layers.MaxPool2D(2,1,'same')
self.relu4=layers.ReLU()

接着实现call函数:

    def call(self, inputs, training=None, mask=None):
out=self.conv1(inputs)
out=self.conv2(out)
out=self.relu1(out)
out=self.pool1(out) out=self.conv3(out)
out=self.conv4(out)
out=self.relu2(out)
out=self.pool2(out) out=self.conv5(out)
out=self.conv6(out)
out=self.conv7(out)
out=self.relu3(out)
out=self.pool3(out) out=self.conv8(out)
out=self.conv9(out)
out=self.conv10(out)
out=self.relu4(out)
out=self.pool4(out)
out=self.conv11(out)
out=self.conv12(out)
out=self.conv13(out)
out=self.relu5(out)
out=self.pool5(out)
out=self.avgpool(out)
out=self.fc(out)
return out

网络搭建好了之后,我们可以使用model.summary()的方法查看一下网络结构和相关参数

查看网络结构:

训练

1. 数据的预处理

import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,Sequential,layers,optimizers,metrics
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' def preprocess(x,y):
x=2*tf.cast(x,dtype=tf.float32)/255.-1.
y=tf.cast(y,dtype=tf.int32)
return x,y batchsz=128
# load dataset
(x_train,y_train),(x_test,y_test)=datasets.cifar10.load_data()
y_train=tf.squeeze(y_train)
y_test=tf.squeeze(y_test)
y_train=tf.one_hot(y_train,depth=10)
y_test=tf.one_hot(y_test,depth=10)

2. 加载数据

这里为了训练方便,就使用CIFAR10的数据集了,获取该数据集很方便,只需keras.datasets.cifar10.load_data()即可获得

# train data
train_date=tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_date=train_date.map(preprocess).shuffle(10000).batch(batchsz)
# test data
test_data=tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_data=test_data.map(preprocess).batch(batchsz)

3. 搭建网络结构

class VGG(keras.Model):
def __init__(self,num_class=10):
super(VGG, self).__init__()
self.conv1 = layers.Conv2D(64,3,1,'same')
self.conv2 = layers.Conv2D(64,3,1,'same')
self.pool1 = layers.MaxPool2D(2,1,'same')
self.relu1 = layers.ReLU() self.conv3 = layers.Conv2D(128, 3, 1, 'same')
self.conv4 = layers.Conv2D(128, 3, 1, 'same')
self.relu2 = layers.ReLU()
self.pool2 = layers.MaxPool2D(2, 1, 'same') self.conv5 = layers.Conv2D(256,3,1,'same')
self.conv6 = layers.Conv2D(256,3,1,'same')
self.conv7 = layers.Conv2D(256,3,1,'same')
self.relu3 = layers.ReLU()
self.pool3 = layers.MaxPool2D(2,1,'same') self.conv8 = layers.Conv2D(512,3,1,'same')
self.conv9 = layers.Conv2D(512,3,1,'same')
self.conv10 = layers.Conv2D(512,3,1,'same')
self.relu4=layers.ReLU()
self.pool4=layers.MaxPool2D(2,1,'same') self.conv11 = layers.Conv2D(512,3,1,'same')
self.conv12 = layers.Conv2D(512,3,1,'same')
self.conv13 = layers.Conv2D(512,3,1,'same')
self.relu5 = layers.ReLU()
self.pool5 = layers.MaxPool2D(2,1,'same') self.avgpool = layers.GlobalAveragePooling2D()
self.fc = layers.Dense(num_class) def call(self, inputs, training=None, mask=None):
out=self.conv1(inputs)
out=self.conv2(out)
out=self.relu1(out)
out=self.pool1(out) out=self.conv3(out)
out=self.conv4(out)
out=self.relu2(out)
out=self.pool2(out) out=self.conv5(out)
out=self.conv6(out)
out=self.conv7(out)
out=self.relu3(out)
out=self.pool3(out) out=self.conv8(out)
out=self.conv9(out)
out=self.conv10(out)
out=self.relu4(out)
out=self.pool4(out) out=self.conv11(out)
out=self.conv12(out)
out=self.conv13(out)
out=self.relu5(out)
out=self.pool5(out) out=self.avgpool(out)
out=self.fc(out)
return out

可以看到搭建网络的方式和pytorch很相似

4. 预训练

model=VGG()
model.build(input_shape=(None,32,32,3))
model.compile(optimizer=optimizers.Adam(1e-4),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_date, epochs=15, validation_data=test_data, validation_freq=1)

5. 训练数据

这里我们同样使用Tensorflow提供的一个接口compile实现训练,大家也可以改用其他的方法实现数据的更新。

下面为训练的数据,可以看到前半部的效果还是不错的,test_acc达到了82%,但是后面出现了过拟合的现象,笔者这里训练了30个epoch,这30个epoch训练了大概四个小时的样子;大家有兴趣的话,可以将参数进行优化,防止过拟合,再训练以达到更好的效果。最后希望大家多多实践,共同进步。

Epoch 1/30
391/391 [==============================] - 428s 1s/step - loss: 1.6934 - acc: 0.3565 - val_loss: 1.4043 - val_acc: 0.4945
Epoch 2/30
391/391 [==============================] - 424s 1s/step - loss: 1.2890 - acc: 0.5294 - val_loss: 1.1885 - val_acc: 0.5673
Epoch 3/30
391/391 [==============================] - 423s 1s/step - loss: 1.0865 - acc: 0.6059 - val_loss: 1.0395 - val_acc: 0.6334
Epoch 4/30
391/391 [==============================] - 423s 1s/step - loss: 0.9334 - acc: 0.6637 - val_loss: 0.8658 - val_acc: 0.6929
Epoch 5/30
391/391 [==============================] - 422s 1s/step - loss: 0.8382 - acc: 0.7040 - val_loss: 0.8632 - val_acc: 0.7012
Epoch 6/30
391/391 [==============================] - 422s 1s/step - loss: 0.7352 - acc: 0.7410 - val_loss: 0.8306 - val_acc: 0.7171
Epoch 7/30
391/391 [==============================] - 422s 1s/step - loss: 0.6655 - acc: 0.7671 - val_loss: 0.8549 - val_acc: 0.7071
Epoch 8/30
391/391 [==============================] - 421s 1s/step - loss: 0.5957 - acc: 0.7929 - val_loss: 0.6909 - val_acc: 0.7699
Epoch 9/30
391/391 [==============================] - 421s 1s/step - loss: 0.5478 - acc: 0.8103 - val_loss: 0.6670 - val_acc: 0.7755
Epoch 10/30
391/391 [==============================] - 421s 1s/step - loss: 0.5001 - acc: 0.8261 - val_loss: 0.5885 - val_acc: 0.8019
Epoch 11/30
391/391 [==============================] - 421s 1s/step - loss: 0.4494 - acc: 0.8442 - val_loss: 0.6598 - val_acc: 0.7872
Epoch 12/30
391/391 [==============================] - 420s 1s/step - loss: 0.4189 - acc: 0.8534 - val_loss: 0.6492 - val_acc: 0.7789
Epoch 13/30
391/391 [==============================] - 420s 1s/step - loss: 0.3793 - acc: 0.8680 - val_loss: 0.5678 - val_acc: 0.8087
Epoch 14/30
391/391 [==============================] - 421s 1s/step - loss: 0.3489 - acc: 0.8777 - val_loss: 0.6082 - val_acc: 0.8030
Epoch 15/30
391/391 [==============================] - 421s 1s/step - loss: 0.3194 - acc: 0.8881 - val_loss: 0.6338 - val_acc: 0.8107
Epoch 16/30
391/391 [==============================] - 421s 1s/step - loss: 0.2814 - acc: 0.9022 - val_loss: 0.6844 - val_acc: 0.7938
Epoch 17/30
391/391 [==============================] - 420s 1s/step - loss: 0.2697 - acc: 0.9045 - val_loss: 0.6159 - val_acc: 0.8151
Epoch 18/30
391/391 [==============================] - 421s 1s/step - loss: 0.2419 - acc: 0.9153 - val_loss: 0.6183 - val_acc: 0.8203
Epoch 19/30
391/391 [==============================] - 421s 1s/step - loss: 0.2186 - acc: 0.9232 - val_loss: 0.6346 - val_acc: 0.8121
Epoch 20/30
391/391 [==============================] - 420s 1s/step - loss: 0.2103 - acc: 0.9251 - val_loss: 0.7089 - val_acc: 0.8005
Epoch 21/30
391/391 [==============================] - 420s 1s/step - loss: 0.1888 - acc: 0.9328 - val_loss: 0.7297 - val_acc: 0.8097
Epoch 22/30
390/391 [============================>.] - ETA: 1s - loss: 0.1830 - acc: 0.9341Epoch 23/30
391/391 [==============================] - 419s 1s/step - loss: 0.1679 - acc: 0.9417 - val_loss: 0.7396 - val_acc: 0.8165
Epoch 24/30
391/391 [==============================] - 420s 1s/step - loss: 0.1646 - acc: 0.9419 - val_loss: 0.6877 - val_acc: 0.8199
Epoch 25/30
391/391 [==============================] - 420s 1s/step - loss: 0.1406 - acc: 0.9504 - val_loss: 0.7773 - val_acc: 0.8163
Epoch 26/30
391/391 [==============================] - 420s 1s/step - loss: 0.1590 - acc: 0.9438 - val_loss: 0.7408 - val_acc: 0.8143
Epoch 27/30
391/391 [==============================] - 420s 1s/step - loss: 0.1392 - acc: 0.9502 - val_loss: 0.7879 - val_acc: 0.8115
Epoch 28/30
391/391 [==============================] - 420s 1s/step - loss: 0.1309 - acc: 0.9537 - val_loss: 0.8269 - val_acc: 0.8189
Epoch 29/30
391/391 [==============================] - 420s 1s/step - loss: 0.1226 - acc: 0.9582 - val_loss: 0.7843 - val_acc: 0.8182
Epoch 30/30
391/391 [==============================] - 420s 1s/step - loss: 0.1167 - acc: 0.9591 - val_loss: 0.8827 - val_acc: 0.8187

Tensorflow2.0:使用Keras自定义网络实战的更多相关文章

  1. CIFAR10自定义网络实战

    目录 CIFAR10 MyDenseLayer CIFAR10 MyDenseLayer import os import tensorflow as tf from tensorflow.keras ...

  2. TensorFlow2.0(11):tf.keras建模三部曲

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  3. 推荐模型AutoRec:原理介绍与TensorFlow2.0实现

    1. 简介 本篇文章先简单介绍论文思路,然后使用Tensoflow2.0.Keras API复现算法部分.包括: 自定义模型 自定义损失函数 自定义评价指标RMSE 就题目而言<AutoRec: ...

  4. 推荐模型NeuralCF:原理介绍与TensorFlow2.0实现

    1. 简介 NCF是协同过滤在神经网络上的实现--神经网络协同过滤.由新加坡国立大学与2017年提出. 我们知道,在协同过滤的基础上发展来的矩阵分解取得了巨大的成就,但是矩阵分解得到低维隐向量求内积是 ...

  5. 推荐模型DeepCrossing: 原理介绍与TensorFlow2.0实现

    DeepCrossing是在AutoRec之后,微软完整的将深度学习应用在推荐系统的模型.其应用场景是搜索推荐广告中,解决了特征工程,稀疏向量稠密化,多层神经网路的优化拟合等问题.所使用的特征在论文中 ...

  6. tensorflow2.0 在pycharm下提示问题

    tensorflow2.0 使用keras一般通过tensorflow.keras来使用,但是pycharm没有提示,原因是因为实际的keras路径放在tensorflow/python/keras, ...

  7. 一文上手Tensorflow2.0之tf.keras(三)

    系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...

  8. 『TensorFlow2.0正式版』TF2.0+Keras速成教程·零:开篇简介与环境准备

    此篇教程参考自TensorFlow 2.0 + Keras Crash Course,在原文的基础上进行了适当的总结与改编,以适应于国内开发者的理解与使用,水平有限,如果写的不对的地方欢迎大家评论指出 ...

  9. TensorFlow2.0教程-使用keras训练模型

    1.一般的模型构造.训练.测试流程 # 模型构造 inputs = keras.Input(shape=(784,), name='mnist_input') h1 = layers.Dense(64 ...

  10. Servlet3.0的注解自定义原生Listener监听器实战

    简介:监听器介绍和Servlet3.0的注解自定义原生Listener监听器实战 自定义Listener(常用的监听器 servletContextListener.httpSessionListen ...

随机推荐

  1. 万字长文硬核AQS源码分析

    阅读本文前,需要储备的知识点如下,点击链接直接跳转. java线程详解 Java不能操作内存?Unsafe了解一下 一文读懂LockSupport AQS简介 AQS即AbstractQueuedSy ...

  2. 文心一言 VS 讯飞星火 VS chatgpt (82)-- 算法导论8.1 1题

    一.如果用go语言,在一棵比较排序算法的决策树中,一个叶结点可能的最小深度是多少? 文心一言: 在一个排序决策树(如二叉搜索树)中,每个叶节点的最小深度等于输入数据中最大元素与最小元素之间的位距离.这 ...

  3. MIT 6.828 Lab实验记录 —— lab1 Booting PC

    实验参考信息 MIT 6.828 lab1 讲义地址 MIT 6.828 课程 Schedule MIT 6.828 lab 环境搭建参考 MIT 6.828 lab 工具guide Brennan' ...

  4. API数据接口的应用步骤

    API(Application Programming Interface)是现代软件应用程序开发中的一项重要技术,它允许不同的应用程序之间进行通信和数据交换.API接口通过提供统一的访问点,使得应用 ...

  5. 每日一库:pprof简介

    pprof简介 pprof是Go语言的一个性能分析库,它可以帮助开发者找出程序中的性能瓶颈.pprof提供了CPU分析.内存分析.阻塞分析等多种性能分析功能. 以下是pprof的主要特性: CPU分析 ...

  6. 解决SVN死锁问题

    svn执行clean up后出现提示:svn cleanup failed–previous operation has not finished; run cleanup if it was int ...

  7. MQ系列14:MQ如何做到消息延时处理

    MQ系列1:消息中间件执行原理 MQ系列2:消息中间件的技术选型 MQ系列3:RocketMQ 架构分析 MQ系列4:NameServer 原理解析 MQ系列5:RocketMQ消息的发送模式 MQ系 ...

  8. 后浪搞的在线版 Windows 12「GitHub 热点速览」

    本周比较火的莫过于 3 位初中生开源的 Windows 12 网页版,虽然项目完成度不如在线版的 Windows 11,但是不妨一看.除了后生可畏的 win12 之外,开源不到一周的 open-int ...

  9. 造轮子之EventBus

    前面基础管理的功能基本开发完了,接下来我们来优化一下开发功能,来添加EventBus功能.EventBus也是我们使用场景非常广的东西.这里我会实现一个本地的EventBus以及分布式的EventBu ...

  10. js闭包使用之处

    1.循环绑定 No Use:   var lists = document.getElementsByTagName('li');   for(var i=0;i<lists.length;i& ...