版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_44633882/article/details/89054159

------------------------------------------------------------------------------------------

PS: 已给出相关的代码Demo,代码地址:

https://gitee.com/devilmaycry812839668/vgg16_inference

-------------------------------------------------------------------

主流的CNN模型基本都会使用VGG16或者ResNet等网络作为预训练模型,正好有个朋友和我说发给他一个VGG16的预训练模型和代码,我就整理了一下。在这里也分享一下,方便大家直接使用。

系统环境

  • Tensorflow-gpu 1.12.0
  • Python 3.5.2

资料来源

官方slim说明

https://github.com/tensorflow/models/tree/1af55e018eebce03fb61bba9959a04672536107d/research/slim

主页里直接可以看到所提供的模型列表和下载链接。

我们选择vgg16来做个示范哈,虽然vgg16的准确率现在已经不算高。

拿到vgg_16.ckpt模型文件!

直接贴上代码

vgg16预训练模型使用代码

import os
import numpy as np
import tensorflow as tf
slim = tf.contrib.slim
PROJECT_PATH = os.path.dirname(os.path.abspath(os.getcwd()))
# 预训练模型位置
tf.app.flags.DEFINE_string('pretrained_model_path', os.path.join(PROJECT_PATH, 'data/vgg_16.ckpt'), '')
FLAGS = tf.app.flags.FLAGS def vgg_arg_scope(weight_decay=0.1):
"""定义 VGG arg scope.
Args:
weight_decay: The l2 regularization coefficient.
Returns:
An arg_scope.
"""
with slim.arg_scope([slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=slim.l2_regularizer(weight_decay),
biases_initializer=tf.zeros_initializer()):
with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
return arg_sc def vgg16(inputs,scope='vgg_16'):
with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
# Collect outputs for conv2d, fully_connected and max_pool2d.
with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],):
# outputs_collections=end_points_collection):
net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
net = slim.max_pool2d(net, [2, 2], scope='pool1')
net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
net = slim.max_pool2d(net, [2, 2], scope='pool2')
net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [2, 2], scope='pool3')
net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
net = slim.max_pool2d(net, [2, 2], scope='pool4')
net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
# net = slim.max_pool2d(net, [2, 2], scope='pool5')
# net = slim.fully_connected(net, 4096, scope='fc6')
# net = slim.dropout(net, 0.5, scope='dropout6')
# net = slim.fully_connected(net, 4096, scope='fc7')
# net = slim.dropout(net, 0.5, scope='dropout7')
# net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')
return net def net():
input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
with slim.arg_scope(vgg_arg_scope()):
conv5_3 = vgg16(input_image) # vgg16网络 init = tf.global_variables_initializer()
# restore预训练模型op
if FLAGS.pretrained_model_path is not None:
variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
slim.get_trainable_variables(),
ignore_missing_vars=True)
with tf.Session() as sess:
sess.run(init)
if FLAGS.pretrained_model_path is not None:
# resotre 预训练模型
variable_restore_op(sess)
a = sess.run([conv5_3],feed_dict={input_image:np.arange(360000).reshape(1,300,400,3)}) if __name__ == '__main__':
net()
print(tf.trainable_variables())

讲一讲,代码里要注意的地方吧,也比较简单易懂。

1.vgg_arg_scope

def vgg_arg_scope(weight_decay=0.1):
with slim.arg_scope([slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=slim.l2_regularizer(weight_decay),
biases_initializer=tf.zeros_initializer()):
with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
return arg_sc

vgg_arg_scope()函数返回了一个scope参数空间,使用起来就是with slim.arg_scope(vgg_arg_scope()):

它规定了[slim.conv2d, slim.fully_connected]都要满足什么变量参数,比如:激活函数,参数初始化。

activation_fn=tf.nn.relu来说,所有在这个变量空间中的conv2d卷积和fully_connected全连接都是指定了relu作为激活函数。

当然,这里存在覆盖是可以的,可以嵌套arg_scope进行设置,内层空间覆盖了外层空间,最内层的就是slim.conv2d()里传入指定的参数了,这是覆盖了所有外层的。变量空间在我看来,非常方便,也使网络定义变得简单。

2.slim.repeat()

VGG16中比如一个conv,其中做了3次相同的卷积,写出来的代码就很长,使用repeat()就简单一句话net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')增强了代码可读性,而有人可能会问,那三层卷积层怎么进行标识呢?
当然没问题,你输出变量会发现是类似conv5/conv5_1,在_后面递增自动标记区分。

3.代码里是每个层是如何拿到自己对应的模型参数呢?

这个应该是有些人的困惑吧,毕竟不知道这个,也只能拿着代码直接用。这个的关键是变量空间
网络定义完成了,你可以通过 print(tf.trainable_variables()) 来获得所有网络中的变量。
我贴出来 vgg16 中的变量,太多了,捡重要的说,就说说 conv1,可以看到变量是这么标识的 vgg_16/conv1/conv1_1/weights,前面有很多前缀,就和龙母报出来自己一堆头衔一样,其实是起到一个定位效果。

    # [<tf.Variable 'vgg_16/conv1/conv1_1/weights:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
# <tf.Variable 'vgg_16/conv1/conv1_1/biases:0' shape=(64,) dtype=float32_ref>,
# <tf.Variable 'vgg_16/conv1/conv1_2/weights:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
# <tf.Variable 'vgg_16/conv1/conv1_2/biases:0' shape=(64,) dtype=float32_ref>,
# <tf.Variable 'vgg_16/conv2/conv2_1/weights:0' shape=(3, 3, 64, 128) dtype=float32_ref>,

在代码里,我们要让每个层在预训练模型里找到自己对应的参数,就必须这么定义变量空间。

    with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d]):
net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')

看到了 scope 和 ‘vgg_16’两个,其实 scope 我们也传入的是 ’vgg_16’,tf.variable_scope() 的参数,前两个是 name_or_scope, default_name。默认名称是当 name_or_scope 为空时,使用的默认名称。

这么整理一下,'vgg_16',  后面的slim.repeat()里的scope='conv1',还有自动标记的 conv1_1

连起来就是 vgg_16/conv1/conv1_1

4.  预训练模型restore。

先准备op,而且若 pretrained_model_path 不为空,才加入和使用 variable_restore_op

if FLAGS.pretrained_model_path is not None:
variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
slim.get_trainable_variables(),
ignore_missing_vars=True)

Session()中使用

if FLAGS.pretrained_model_path is not None:
variable_restore_op(sess)

讲解完毕!哦,还有补充一下,一般vgg16来说,只会拿conv5_3的输出,继续做fine-tune。所以,你只用conv5_3,测试的时候是不用在意输入图片的大小的,因为都是卷积嘛。但是,我测试的时候,传入了个(1,3,6,3)的数组,出现了这么一个错。想了想,嗯,应该是这个数组做不了那么多次卷积的,所以Tensorflow报错了。(这里只是简单记录一下),所以用一个大一些的数组传入就可以啦

2019-04-06 12:20:14.650154: F tensorflow/stream_executor/cuda/cuda_dnn.cc:542] Check failed: cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd, dims.data(), strides.data()) == CUDNN_STATUS_SUCCESS (3 vs. 0)batch_descriptor: {count: 1 feature_map_count: 128 spatial: 0 1  value_min: 0.000000 value_max: 0.000000 layout: BatchDepthYX}
bash: line 1: 2492 Aborted (core dumped) env "PYTHONUNBUFFERED"="1" "PYTHONPATH"="/tmp/pycharm_project_299:/home/benke/.pycharm_helpers/pycharm_matplotlib_backend" "PYCHARM_HOSTED"="1" "JETBRAINS_REMOTE_RUN"="1" "PYCHARM_MATPLOTLIB_PORT"="65407" "PYTHONIOENCODING"="UTF-8" '/opt/anaconda3/bin/python' '-u' '/tmp/pycharm_project_299/data/vgg.py'

---------------------------------------------------------------------------------------------

转者注:

tensorflow官方预训练模型下载链接:

https://github.com/tensorflow/models/tree/master/research/slim

上面的代码一直在自己电脑上无法跑通,后来发现需要报ckpt文件放在data文件里面,并且运行的主文件也必须在一个文件夹下面:

其次,也是最关键的就是如果你是使用GPU在做这个计算那么你的GPU内存应该要大于30G,这也就是我为什么最终只有换到了服务器上才能够在GPU Tesla V100  上跑通这个代码的原因。

更正一下, 后来发现多次跑这个代码不成功是因为windows平台下IDE的问题,最后在CMD里面成功运行,显卡为2070super:

【转载】 Tensorflow如何直接使用预训练模型(vgg16为例)的更多相关文章

  1. 我的Keras使用总结(3)——利用bottleneck features进行微调预训练模型VGG16

    Keras的预训练模型地址:https://github.com/fchollet/deep-learning-models/releases 一个稍微讲究一点的办法是,利用在大规模数据集上预训练好的 ...

  2. 【转载】最强NLP预训练模型!谷歌BERT横扫11项NLP任务记录

    本文介绍了一种新的语言表征模型 BERT--来自 Transformer 的双向编码器表征.与最近的语言表征模型不同,BERT 旨在基于所有层的左.右语境来预训练深度双向表征.BERT 是首个在大批句 ...

  3. 我的Keras使用总结(4)——Application中五款预训练模型学习及其应用

    本节主要学习Keras的应用模块 Application提供的带有预训练权重的模型,这些模型可以用来进行预测,特征提取和 finetune,上一篇文章我们使用了VGG16进行特征提取和微调,下面尝试一 ...

  4. tensorflow利用预训练模型进行目标检测(二):预训练模型的使用

    一.运行样例 官网链接:https://github.com/tensorflow/models/blob/master/research/object_detection/object_detect ...

  5. tensorflow利用预训练模型进行目标检测(一):安装tensorflow detection api

    一.tensorflow安装 首先系统中已经安装了两个版本的tensorflow,一个是通过keras安装的, 一个是按照官网教程https://www.tensorflow.org/install/ ...

  6. tensorflow 预训练模型列表

    tensorflow 预训练模型列表 https://github.com/tensorflow/models/tree/master/research/slim Pre-trained Models ...

  7. pytorch预训练模型的下载地址以及解决下载速度慢的方法

    https://github.com/pytorch/vision/tree/master/torchvision/models 几乎所有的常用预训练模型都在这里面 总结下各种模型的下载地址: 1 R ...

  8. 【tf.keras】tf.keras加载AlexNet预训练模型

    目录 从 PyTorch 中导出模型参数 第 0 步:配置环境 第 1 步:安装 MMdnn 第 2 步:得到 PyTorch 保存完整结构和参数的模型(pth 文件) 第 3 步:导出 PyTorc ...

  9. 文本分类实战(十)—— BERT 预训练模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

  10. 文本分类实战(九)—— ELMO 预训练模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

随机推荐

  1. 比较 SpringSecurity 和 Shiro

    相比 Spring Security, Shiro 在保持强大功能的同时,使用简单性和灵活性. SpringSecurity: 即使是一个一个简单的请求, 最少得经过它的 8 个Filter.Spri ...

  2. H5弹窗底层滑动

    H5弹窗底层滑动 背景 产品提出H5 弹出窗滑动时,底层页面也会跟随滑动,需要调整禁止底层滑动,增加用户体验. 问题产生原因 ios 滑动时有回弹效果 顶层元素滑动默认行为 解决办法 阻止元素的默认( ...

  3. 判断日期是否为周六周日,BigDecimal比较大小

    判断日期是否为周六周日,BigDecimal比较大小 package com.example.core.mydemo.date; import java.math.BigDecimal; import ...

  4. PowerBI_一分钟了解POWERBI计算组_基础运用篇(一)

    在第一篇计算组的文章中,给大家介绍了,POWERBI的计算组功能的基本概念和作用. 本文,旨在通过简单案例,介绍计算组功能的具体应用场景. 没有看过第一篇的同学,可以先简单过一下第一篇,补齐一下概念和 ...

  5. 从Purge机制说起,详解GaussDB(for MySQL)的优化策略

    本文分享自华为云社区<[华为云MySQL技术专栏]GaussDB(for MySQL) Purge优化>,作者:GaussDB 数据库. 在MySQL中,尤其是在使用InnoDB引擎时,P ...

  6. 嵌入式ARM端测试手册——全志T3+Logos FPGA开发板(上)

    前 言 本指导文档适用开发环境: Windows开发环境:Windows 7 64bit.Windows 10 64bit Linux开发环境:Ubuntu18.04.4 64bit 虚拟机:VMwa ...

  7. Nuxt框架中内置组件详解及使用指南(二)

    title: Nuxt框架中内置组件详解及使用指南(二) date: 2024/7/7 updated: 2024/7/7 author: cmdragon excerpt: 摘要:"本文详 ...

  8. Java 面向对象编程之InstanceOf关键词和多态

    InstanceOf关键字使用,什么是多态 InstanceOf关键字 是Java的一个二元操作符(运算符),也是Java的保留关键字 语法 //如果该object 是该class的⼀个实例,那⼀个实 ...

  9. Mybatis Plus 3.X版本的insert填充自增id的IdType.ID_WORKER策略源码分析

    总结/朱季谦 某天同事突然问我,你知道Mybatis Plus的insert方法,插入数据后自增id是如何自增的吗? 我愣了一下,脑海里只想到,当在POJO类的id设置一个自增策略后,例如@Table ...

  10. 可能是全网最适合入门的面向对象编程教程:Python实现-嵌入式爱好者必看!

    前言 对于嵌入式入门的同学来说,往往会遇到设备端处理能力不足.在面对大规模计算情况下需要借助上位机完成进一步的数据处理的情况.此时,Python 语言因其简单易用的特点和丰富多样的库成为了我们做上位机 ...