Basic Classification Network: VGG
VGG16 was proposed in 2014 by the Oxford Visual Geometry Group, from which the model takes its name.
In the ImageNet Large Scale Visual Recognition Challenge 2014 (ILSVRC2014), VGG won the localization task and finished second in classification, just behind GoogLeNet.
Paper link: https://arxiv.org/abs/1409.1556
The Visual Geometry Group has released the model files it trained for the ImageNet competition at http://www.robots.ox.ac.uk/~vgg/research/very_deep/.
There are many VGG16 implementations available online, for example:
- https://github.com/machrisaa/tensorflow-vgg/blob/master/vgg16.py is an inference implementation: it loads the weight file and runs image prediction.
- https://github.com/ppplinday/tensorflow-vgg16-train-and-test/blob/master/train_vgg.py is a training implementation: it shows how the weight file is produced in the first place.
The structure of the VGG model is summarized below.
Every convolutional layer uses 3×3 kernels.
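Since the layer configuration is easy to write down, here is a compact sketch of VGG16 (configuration D in the paper); the Python listing below is an illustration, not code from the paper or from Keras:

# VGG16 ("configuration D"): each number is the output-channel count of a
# 3x3 convolution (stride 1, 'same' padding); 'M' marks a 2x2 max-pool.
VGG16_CFG = [64, 64, 'M',
             128, 128, 'M',
             256, 256, 256, 'M',
             512, 512, 512, 'M',
             512, 512, 512, 'M']
# The convolutional stack is followed by three fully-connected layers,
# 4096 -> 4096 -> 1000, with a softmax over the 1000 ImageNet classes.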
Keras already ships with many of these models; see the Keras documentation for the full list:
https://keras.io/applications/#models-for-image-classification-with-weights-trained-on-imagenet
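For instance, several of the bundled architectures can be imported in the same way (a short sketch; the module paths are the Keras 2.x ones):

from keras.applications.vgg16 import VGG16
from keras.applications.vgg19 import VGG19
from keras.applications.resnet50 import ResNet50
from keras.applications.inception_v3 import InceptionV3
# Each of these returns a ready-to-use Model, optionally with ImageNet weights.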
Below is the implementation from keras_applications/vgg16.py, which is easier to follow than the TensorFlow versions.
"""VGG16 model for Keras.
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from . import get_submodules_from_kwargs
from . import imagenet_utils
from .imagenet_utils import decode_predictions
from .imagenet_utils import _obtain_input_shape
preprocess_input = imagenet_utils.preprocess_input
WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/'
                'releases/download/v0.1/'
                'vgg16_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/'
                       'releases/download/v0.1/'
                       'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5')
def VGG16(include_top=True,
          weights='imagenet',
          input_tensor=None,
          input_shape=None,
          pooling=None,
          classes=1000,
          **kwargs):
    """Instantiates the VGG16 architecture.

    Optionally loads weights pre-trained on ImageNet.
    Note that the data format convention used by the model is
    the one specified in your Keras config at `~/.keras/keras.json`.

    # Arguments
        include_top: whether to include the 3 fully-connected
            layers at the top of the network.
        weights: one of `None` (random initialization),
            'imagenet' (pre-training on ImageNet),
            or the path to the weights file to be loaded.
        input_tensor: optional Keras tensor
            (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is False (otherwise the input shape
            has to be `(224, 224, 3)`
            (with `channels_last` data format)
            or `(3, 224, 224)` (with `channels_first` data format).
            It should have exactly 3 input channels,
            and width and height should be no smaller than 32.
            E.g. `(200, 200, 3)` would be one valid value.
        pooling: Optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional block.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional block, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
        classes: optional number of classes to classify images
            into, only to be specified if `include_top` is True, and
            if no `weights` argument is specified.

    # Returns
        A Keras model instance.

    # Raises
        ValueError: in case of invalid argument for `weights`,
            or invalid input shape.
    """
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    if not (weights in {'imagenet', None} or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=32,
                                      data_format=backend.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    else:
        if not backend.is_keras_tensor(input_tensor):
            img_input = layers.Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    # Block 1
    x = layers.Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block1_conv1')(img_input)
    x = layers.Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block1_conv2')(x)
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # Block 2
    x = layers.Conv2D(128, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block2_conv1')(x)
    x = layers.Conv2D(128, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block2_conv2')(x)
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

    # Block 3
    x = layers.Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block3_conv1')(x)
    x = layers.Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block3_conv2')(x)
    x = layers.Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block3_conv3')(x)
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # Block 4
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block4_conv1')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block4_conv2')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block4_conv3')(x)
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # Block 5
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block5_conv1')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block5_conv2')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block5_conv3')(x)
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

    if include_top:
        # Classification block
        x = layers.Flatten(name='flatten')(x)
        x = layers.Dense(4096, activation='relu', name='fc1')(x)
        x = layers.Dense(4096, activation='relu', name='fc2')(x)
        x = layers.Dense(classes, activation='softmax', name='predictions')(x)
    else:
        if pooling == 'avg':
            x = layers.GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = layers.GlobalMaxPooling2D()(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = keras_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    model = models.Model(inputs, x, name='vgg16')

    # Load weights.
    if weights == 'imagenet':
        if include_top:
            weights_path = keras_utils.get_file(
                'vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                WEIGHTS_PATH,
                cache_subdir='models',
                file_hash='64373286793e3c8b2b4e3219cbf3544b')
        else:
            weights_path = keras_utils.get_file(
                'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                WEIGHTS_PATH_NO_TOP,
                cache_subdir='models',
                file_hash='6d6bbae143d832006294945121d1f1fc')
        model.load_weights(weights_path)
        if backend.backend() == 'theano':
            keras_utils.convert_all_kernels_in_model(model)
    elif weights is not None:
        model.load_weights(weights)

    return model
As the code makes clear, every convolution kernel used is 3×3.
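The paper's argument for stacking small kernels is that two 3×3 convolutions cover a 5×5 receptive field and three cover a 7×7, while adding extra non-linearities and using fewer parameters. A quick back-of-the-envelope check (ignoring biases; C input and C output channels, with C = 256 chosen purely for illustration):

C = 256                            # channel count, illustrative only
three_3x3 = 3 * (3 * 3 * C * C)    # three stacked 3x3 conv layers
one_7x7 = 7 * 7 * C * C            # a single 7x7 conv layer with the same receptive field
print(three_3x3, one_7x7)          # 1769472 vs 3211264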
Making a prediction with Keras is just as simple:
from keras.applications.vgg16 import VGG16
model = VGG16()
print(model.summary())
The code above downloads the weight file to ~/.keras/models/ by default (keras_utils.get_file is called with cache_subdir='models').
Here is a snippet found online:
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
# VGG-16 instance
model = VGG16(weights='imagenet', include_top=True)
image = load_img('C:/Pictures/Pictures/test_imgs/golden.jpg', target_size=(224, 224))
image_data = img_to_array(image)
# reshape it into the specific format
image_data = image_data.reshape((1,) + image_data.shape)
print(image_data.shape)
# prepare the image data for VGG
image_data = preprocess_input(image_data)
# using the pre-trained model to predict
prediction = model.predict(image_data)
# decode the prediction results
results = decode_predictions(prediction, top=3)
print(results)
The workflow is very simple:
- load the model
- load the image and preprocess it
- run the forward pass
- interpret the output tensor
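The same pre-trained weights are also commonly reused for feature extraction rather than classification; as the docstring above explains, include_top=False drops the fully-connected layers and the pooling argument controls the output shape. A minimal sketch (assuming the no-top ImageNet weights can be downloaded):

from keras.applications.vgg16 import VGG16

# Global average pooling over block5_pool turns each image into a 512-d vector.
extractor = VGG16(weights='imagenet', include_top=False, pooling='avg')
# For a (batch, 224, 224, 3) input, extractor.predict(...) returns shape (batch, 512).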
VGG19 has essentially the same structure as VGG16; it simply adds a few more convolutional layers.
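Keras exposes VGG19 with the same interface under keras.applications.vgg19; a minimal sketch, assuming the same environment as above:

from keras.applications.vgg19 import VGG19

# VGG19 adds one extra 3x3 conv layer in each of blocks 3, 4 and 5
# (16 weight layers become 19), but is used exactly like VGG16.
model = VGG19(weights='imagenet', include_top=True)
print(model.summary())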