【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用
一、前述
VGG16是由16层神经网络构成的经典模型,包括多层卷积,多层全连接层,一般我们改写的时候卷积层基本不动,全连接层从后面几层依次向前改写,因为先改参数较小的。
二、具体
1、因为本文中代码需要依赖OpenCV,所以第一步先安装OpenCV
因为VGG要求输入244*244,而数据集是28*28的,所以需要通过OpenCV在代码里去改变。
2、把模型下载后离线放入用户的管理目录下面,这样训练的时候就不需要从网上再下载了
3、我们保留的是除了全连接的所有层。
4、选择数据生成器,在真正使用的时候才会生成数据,加载到内存,前面yield只是做了一个标记
代码:
- # 使用迁移学习的思想,以VGG16作为模板搭建模型,训练识别手写字体
- # 引入VGG16模块
- from keras.applications.vgg16 import VGG16
- # 其次加载其他模块
- from keras.layers import Input
- from keras.layers import Flatten
- from keras.layers import Dense
- from keras.layers import Dropout
- from keras.models import Model
- from keras.optimizers import SGD
- # 加载字体库作为训练样本
- from keras.datasets import mnist
- # 加载OpenCV(在命令行中窗口中输入pip install opencv-python),这里为了后期对图像的处理,
- # 大家使用pip install C:\Users\28542\Downloads\opencv_python-3.4.1+contrib-cp35-cp35m-win_amd64.whl
- # 比如尺寸变化和Channel变化。这些变化是为了使图像满足VGG16所需要的输入格式
- import cv2
- import h5py as h5py
- import numpy as np
- # 建立一个模型,其类型是Keras的Model类对象,我们构建的模型会将VGG16顶层(全连接层)去掉,只保留其余的网络
- # 结构。这里用include_top = False表明我们迁移除顶层以外的其余网络结构到自己的模型中
- # VGG模型对于输入图像数据要求高宽至少为48个像素点,由于硬件配置限制,我们选用48个像素点而不是原来
- # VGG16所采用的224个像素点。即使这样仍然需要24GB以上的内存,或者使用数据生成器
- model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(48, 48, 3))#输入进来的数据是48*48 3通道
- #选择imagnet,会选择当年大赛的初始参数
- #include_top=False 去掉最后3层的全连接层看源码可知
- for layer in model_vgg.layers:
- layer.trainable = False#别去调整之前的卷积层的参数
- model = Flatten(name='flatten')(model_vgg.output)#去掉全连接层,前面都是卷积层
- model = Dense(4096, activation='relu', name='fc1')(model)
- model = Dense(4096, activation='relu', name='fc2')(model)
- model = Dropout(0.5)(model)
- model = Dense(10, activation='softmax')(model)#model就是最后的y
- model_vgg_mnist = Model(inputs=model_vgg.input, outputs=model, name='vgg16')
- #把model_vgg.input X传进来
- #把model Y传进来 就可以训练模型了
- # 打印模型结构,包括所需要的参数
- model_vgg_mnist.summary()
- #以下是原版的模型结构 224*224
- model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
- for layer in model_vgg.layers:
- layer.trainable = False#别去调整之前的卷积层的参数
- model = Flatten()(model_vgg.output)
- model = Dense(4096, activation='relu', name='fc1')(model)
- model = Dense(4096, activation='relu', name='fc2')(model)
- model = Dropout(0.5)(model)
- model = Dense(10, activation='softmax', name='prediction')(model)
- model_vgg_mnist_pretrain = Model(model_vgg.input, model, name='vgg16_pretrain')
- model_vgg_mnist_pretrain.summary()
- # 新的模型不需要训练原有卷积结构里面的1471万个参数,但是注意参数还是来自于最后输出层前的两个
- # 全连接层,一共有1.2亿个参数需要训练
- sgd = SGD(lr=0.05, decay=1e-5)#lr 学习率 decay 梯度的逐渐减小 每迭代一次梯度就下降 0.05*(1-(10的-5))这样来变
- #随着越来越下降 学习率越来越小 步子越小
- model_vgg_mnist.compile(loss='categorical_crossentropy',
- optimizer=sgd, metrics=['accuracy'])
- # 因为VGG16对网络输入层需要接受3通道的数据的要求,我们用OpenCV把图像从32*32变成224*224,把黑白图像转成RGB图像
- # 并把训练数据转化成张量形式,供keras输入
- (X_train, y_train), (X_test, y_test) = mnist.load_data("../test_data_home")
- X_train, y_train = X_train[:1000], y_train[:1000]#训练集1000条
- X_test, y_test = X_test[:100], y_test[:100]#测试集100条
- X_train = [cv2.cvtColor(cv2.resize(i, (48, 48)), cv2.COLOR_GRAY2RGB)
- for i in X_train]#变成彩色的
- #np.concatenate拼接到一起把
- X_train = np.concatenate([arr[np.newaxis] for arr in X_train]).astype('float32')
- X_test = [cv2.cvtColor(cv2.resize(i, (48, 48)), cv2.COLOR_GRAY2RGB)
- for i in X_test]
- X_test = np.concatenate([arr[np.newaxis] for arr in X_test]).astype('float32')
- print(X_train.shape)
- print(X_test.shape)
- X_train = X_train / 255
- X_test = X_test / 255
- def tran_y(y):
- y_ohe = np.zeros(10)
- y_ohe[y] = 1
- return y_ohe
- y_train_ohe = np.array([tran_y(y_train[i]) for i in range(len(y_train))])
- y_test_ohe = np.array([tran_y(y_test[i]) for i in range(len(y_test))])
- model_vgg_mnist.fit(X_train, y_train_ohe, validation_data=(X_test, y_test_ohe),
- epochs=100, batch_size=50)
结果:
自定义的网络层:
【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用的更多相关文章
- 利用神经网络算法的C#手写数字识别(二)
利用神经网络算法的C#手写数字识别(二) 本篇主要内容: 让项目编译通过,并能打开图片进行识别. 1. 从上一篇<利用神经网络算法的C#手写数字识别>中的源码地址下载源码与资源, ...
- 利用神经网络算法的C#手写数字识别(一)
利用神经网络算法的C#手写数字识别 转发来自云加社区,用于学习机器学习与神经网络 欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwri ...
- 利用c++编写bp神经网络实现手写数字识别详解
利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...
- 利用神经网络算法的C#手写数字识别
欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...
- NN:利用深度学习之神经网络实现手写数字识别(数据集50000张图片)—Jason niu
import mnist_loader import network training_data, validation_data, test_data = mnist_loader.load_dat ...
- 手写数字识别——利用keras高层API快速搭建并优化网络模型
在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...
- mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)
前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...
- keras框架的MLP手写数字识别MNIST,梳理?
keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...
- keras—多层感知器MLP—MNIST手写数字识别
一.手写数字识别 现在就来说说如何使用神经网络实现手写数字识别. 在这里我使用mind manager工具绘制了要实现手写数字识别需要的模块以及模块的功能: 其中隐含层节点数量(即神经细胞数量)计算 ...
随机推荐
- SVN学习之windows下svn的安装
svn是apache的一个开源项目,全称为subversion.是一个基于版本的项目管理软件,一般在多人开发的项目中使用,目前svn已经替代了原来的cvs.大多数情况下,svn服务安装在linux服务 ...
- Java与Kotlin, 哪个是开发安卓应用的首选语言?
Java是很多开发者创建安卓应用的首选语言.但它在 Android 界的领导地位正受到各种新语言的挑战,Kotlin就是其一.虽然Kotlin最近才开始受到热捧,但有为数不少的人相信 Kotlin 在 ...
- Dubbo中订阅和通知解析
Dubbo中关于服务的订阅和通知主要发生在服务提供方暴露服务的过程和服务消费方初始化时候引用服务的过程中. 2345678910111213141516171819 public <T> ...
- bzoj 4565 状压区间dp
我还以为我状压很好...... 噗!!! 果然我区间很差... f[i][j][s]表示i~j段,合并后的状态为s所得的最大收益 枚举i,j,k,s. f[i][j][s<<1]=max( ...
- 【最小生成树】Bzoj1232 [Usaco2008Nov]安慰奶牛cheer
Description Farmer John变得非常懒, 他不想再继续维护供奶牛之间供通行的道路. 道路被用来连接N (5 <= N <= 10,000)个牧场, 牧场被连续地编号为1. ...
- 【双连通分量】Bzoj2730 HNOI2012 矿场搭建
Description 煤矿工地可以看成是由隧道连接挖煤点组成的无向图.为安全起见,希望在工地发生事故时所有挖煤点的工人都能有一条出路逃到救援出口处.于是矿主决定在某些挖煤点设立救援出口,使得无论哪一 ...
- BZOJ_2303_[Apio2011]方格染色 _并查集
BZOJ_2303_[Apio2011]方格染色 _并查集 Description Sam和他的妹妹Sara有一个包含n × m个方格的 表格.她们想要将其的每个方格都染成红色或蓝色. 出于个人喜好, ...
- ELK---日志分析系统
ELK就是一套完整的日志分析系统 ELK=Logstash+Elasticsearch+Kibana 统一官网https://www.elastic.co/products ELK模块说明 Logst ...
- Mysql8.0命令
1.创建用户 create user 'username'@'localhost' identified by 'pwd' 2.修改访问权限 在mysql数据下修改user表用户host为'%' up ...
- ASP.NET Core的实时库: SignalR简介及使用
大纲 本系列会分为2-3篇文章. 第一篇介绍了SignalR的预备知识和原理 本文介绍SignalR以及ASP.NET Core里使用SignalR. 本文的内容: 介绍SignalR 在ASP.NE ...