kaggle竞赛的inception模型已经能够提取图像很好的特征,后续训练出一个针对当前图片数据的全连接层,进行花的识别和分类。这里见书即可,不再赘述。

书中使用google参加Kaggle竞赛的inception模型重新训练一个全连接神经网络,对五种花进行识别,我姑且命名为模型flower_photos_model。我进一步拓展,将lower_photos_model模型进一步保存,然后部署和应用。然后,我们直接调用迁移之后又训练好的模型,对花片进行预测。

这里讨论两种方式:使用import_meta_graph和使用saver()

首先,原书的迁移学习的代码需要做一些改动。

  1. writer = tf.summary.FileWriter('./graphs/flower_photos_model_graph', sess.graph)
  2. saver.save(sess, "Saved_model/flower_photos_model.ckpt")

Saver()方式

我相较于训练flower_photos_model模型时,增添了一个变量的定义:

即label_index=tf.argmax(final_tensor,1)

  1. def main():
  2. #先定义相同的计算图再加载迁移学习的模型
  3. bottleneck_input = tf.placeholder(tf.float32, [None, BOTTLENECK_TENSOR_SIZE], name='BottleneckInputPlaceholder')
  4. with tf.name_scope('final_training_ops'):
  5. weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001))
  6. biases = tf.Variable(tf.zeros([n_classes]))
  7. logits = tf.matmul(bottleneck_input, weights) + biases
  8. final_tensor = tf.nn.softmax(logits)
  9. label_index=tf.argmax(final_tensor,1)
  10. #利用import_meta_graph和import_graph_def加载的变量均不允许与当前定义计算图有冲突。
  11. #saver = tf.train.Saver()则只加载当前计算图中定义的。
  12. saver = tf.train.Saver()
  13.  
  14. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.700)
  15. with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
  16. saver.restore(sess, "Saved_model/flower_photos_model.ckpt")
  17. #还是要加载一下inception模型
  18. MODEL_DIR = './inception_dec_2015'
  19. MODEL_FILE= 'tensorflow_inception_graph.pb'
  20. with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
  21. graph_def = tf.GraphDef()
  22. graph_def.ParseFromString(f.read())
  23. bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])
  24. print (bottleneck_tensor)
  25. print (jpeg_data_tensor)
  26. #为了在tensorboard中观察加载的计算图。
  27. writer = tf.summary.FileWriter('./graphs/flower_photos_model_graph_use', sess.graph)
  28. writer.close()
  29. #image_path='./data/xiaojie_application/xiaojie_rose.jpg'
  30. image_path='./data/xiaojie_application/xiaojie_sunflowers.jpg'
  31. #image_path='./data/xiaojie_application/5547758_eea9edfd54_n.jpg'
  32.  
  33. """测试一张图片,能否获取瓶颈向量。
  34. image_data = gfile.FastGFile(image_path, 'rb').read()
  35. print (sess.run(jpeg_data_tensor,{jpeg_data_tensor:image_data}))
  36. print ("xiaojie1")
  37. print (sess.run(bottleneck_tensor,{jpeg_data_tensor:image_data}))
  38. """
  39. label_index_value=evalution_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor,bottleneck_input,label_index)
  40. #print (label_index_value)
  41. classes=['daisy','dandelion','roses','sunflowers','tulips']
  42. print ("预测的花的类型:",classes[label_index_value[0]])

相关的函数的定义:

  1. evalution_xiaojie输出预测的分类index
  1. def evalution_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor,bottleneck_input,label_index):
  2. #输出一张图片的预测结果 bottleneck_values=get_bottleneck_values_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor)
  3. bottlenecks = []
  4. bottlenecks.append(bottleneck_values)
  5. label_index_value = sess.run(label_index, feed_dict={
  6. bottleneck_input: bottlenecks})
  7. return label_index_value

获取瓶颈向量(关于瓶颈向量,见书)

  1. def get_bottleneck_values_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor):
  2. #瓶颈向量
  3. if not os.path.exists(CACHE_DIR): os.makedirs(CACHE_DIR)
  4. bottleneck_path = get_bottleneck_path_xiaojie(CACHE_DIR,image_path)
  5. print (bottleneck_path)
  6. if not os.path.exists(bottleneck_path):
  7. image_data = gfile.FastGFile(image_path, 'rb').read()
  8. bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
  9. bottleneck_string = ','.join(str(x) for x in bottleneck_values)
  10. with open(bottleneck_path, 'w') as bottleneck_file:
  11. bottleneck_file.write(bottleneck_string)
  12. else:
  13. with open(bottleneck_path, 'r') as bottleneck_file:
  14. bottleneck_string = bottleneck_file.read()
  15. bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  16. return bottleneck_values

使用inception模型计算瓶颈向量

  1. def run_bottleneck_on_image(sess, image_data, image_data_tensor, bottleneck_tensor):
  2. print("yes")
  3. bottleneck_values = sess.run(bottleneck_tensor, {image_data_tensor: image_data})
  4. bottleneck_values = np.squeeze(bottleneck_values)
  5. print("no")
  6. return bottleneck_values

瓶颈向量有一个缓存文件,这也是类似于原书训练迁移学习模型时的做法

  1. def get_bottleneck_path_xiaojie(CACHE_DIR,image_path):
  2. file_name_suffix=image_path.split('/')[-1]
  3. file_name_no_suffix=file_name_suffix.split('.')[0]
  4. bottleneck_file_name=file_name_no_suffix+('_cache.txt')
  5. bottleneck_path=os.path.join(CACHE_DIR, bottleneck_file_name)
  6. return bottleneck_path

定义的全局变量

  1. BOTTLENECK_TENSOR_SIZE = 2048
  2. n_classes = 5
  3. BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
  4. JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
  5. CACHE_DIR='./data/xiaojie_application/cache_bottleneck/'

Saver方式的说明:

Saver只能导出持久化模型中与当前代码定义计算图相匹配的部分。

因此,对于之前inception也需要再一次重新加载。

此外,当前代码定义计算图,比持久化模型flower_photos_model多定义了一个变量,即label_index=tf.argmax(final_tensor,1),即输出预测的分类index。

import_meta_graph方式

import_meta_graph方式与saver方式的不同点在于会导入完整的计算图,因此当前代码不能定义和要加载计算图相互冲突的部分。

相关函数定义的代码均不变。只将main函数的内容和全局变量改为:

  1. def main():
  2. #如果使用tf.train.import_meta_graph的话,就会重复加载计算图。因此,避免重复,当前代码中不能定义重复的。
  3.  
  4. #saver = tf.train.Saver()
  5. saver = tf.train.import_meta_graph("Saved_model/flower_photos_model.ckpt.meta")
  6. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.700)
  7. with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
  8. #with tf.Session() as sess:
  9. #如果直接使用saver = tf.train.Saver()和restore还原一个model.ckpt文件,是不可能将之前迁移学习那个模型利用import_graph_def加载的inception模型加载进来的。
  10. saver.restore(sess, "Saved_model/flower_photos_model.ckpt")
  11.  
  12. bottleneck_tensor= sess.graph.get_tensor_by_name(import_BOTTLENECK_TENSOR_NAME)
  13. jpeg_data_tensor = sess.graph.get_tensor_by_name(import_JPEG_DATA_TENSOR_NAME)
  14. print (bottleneck_tensor)
  15. print (jpeg_data_tensor)
  16.  
  17. writer = tf.summary.FileWriter('./graphs/flower_photos_model_graph_use', sess.graph)
  18. writer.close()
  19. image_path='./data/xiaojie_application/xiaojie_rose.jpg'
  20. #image_path='./data/xiaojie_application/xiaojie_sunflowers.jpg'
  21. #image_path='./data/xiaojie_application/5547758_eea9edfd54_n.jpg'
  22.  
  23. """测试一张图片
  24. image_data = gfile.FastGFile(image_path, 'rb').read()
  25. print (sess.run(jpeg_data_tensor,{jpeg_data_tensor:image_data}))
  26. print ("xiaojie1")
  27. print (sess.run(bottleneck_tensor,{jpeg_data_tensor:image_data}))
  28. """
  29. bottleneck_input= sess.graph.get_tensor_by_name("BottleneckInputPlaceholder:0")
  30. final_tensor = sess.graph.get_tensor_by_name("final_training_ops/Softmax:0")
  31. label_index=tf.argmax(final_tensor,1)
  32. label_index_value=evalution_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor,bottleneck_input,label_index)
  33. print (label_index_value)
  34. classes=['daisy','dandelion','roses','sunflowers','tulips']
  35. print ("预测的花的类型:",classes[label_index_value[0]])

全局变量改为:

import_BOTTLENECK_TENSOR_NAME = 'import/pool_3/_reshape:0'

import_JPEG_DATA_TENSOR_NAME = 'import/DecodeJpeg/contents:0'

这是因为,使用import_meta_graph方式的话,当前代码不能定义任何与持久化模型中计算图冲突的节点。此外,在flower_photos_model模型对全连接层进行训练的过程中,已经利用import_graph_def的方式导入google Inception v3的持久化模型pb文件,因此,已经包括了google的模型。通过在tensorboard中查看,会发现,所有导入的模块节点之前会带上import节点。因此,在训练flower_photos_model模型时,使用的是pool_3/_reshape:0获取张量,而此时,只能使用import/pool_3/_reshape:0'获取张量。

只能使用import/pool_3/_reshape:0'获取张量。

final_tensor = sess.graph.get_tensor_by_name("final_training_ops/Softmax:0")

然后,我们再定义一个label_index

label_index=tf.argmax(final_tensor,1)

因此,同saver模型一样,所有的其它函数接口和实现都不用变。

最后的结果很nice。可以识别五种花朵,可以直接部署应用。

程序附件

链接:https://pan.baidu.com/s/11YtyDEyV84jONPi9tO2TCw 密码:8mfj

2 (自我拓展)部署花的识别模型(学习tensorflow实战google深度学习框架)的更多相关文章

  1. 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)

    学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...

  2. TensorFlow+实战Google深度学习框架学习笔记(11)-----Mnist识别【采用滑动平均,双层神经网络】

    模型:双层神经网络 [一层隐藏层.一层输出层]隐藏层输出用relu函数,输出层输出用softmax函数 过程: 设置参数 滑动平均的辅助函数 训练函数 x,y的占位,w1,b1,w2,b2的初始化 前 ...

  3. Tensorflow 实战Google深度学习框架 第五章 5.2.1Minister数字识别 源代码

    import os import tab import tensorflow as tf print "tensorflow 5.2 " from tensorflow.examp ...

  4. TensorFlow+实战Google深度学习框架学习笔记(12)------Mnist识别和卷积神经网络LeNet

    一.卷积神经网络的简述 卷积神经网络将一个图像变窄变长.原本[长和宽较大,高较小]变成[长和宽较小,高增加] 卷积过程需要用到卷积核[二维的滑动窗口][过滤器],每个卷积核由n*m(长*宽)个小格组成 ...

  5. 实战Google深度学习框架-C5-MNIST数字识别问题

    5.1 MNIST数据处理 MNIST是NIST数据集的一个子集,包含60000张图片作为训练数据,10000张作为测试数据,其中每张图片代表0~9中的一个数字,图片大小为28*28(可以用一个28* ...

  6. 论文阅读:Face Recognition: From Traditional to Deep Learning Methods 《人脸识别综述:从传统方法到深度学习》

     论文阅读:Face Recognition: From Traditional to Deep Learning Methods  <人脸识别综述:从传统方法到深度学习>     一.引 ...

  7. 机器学习如何选择模型 & 机器学习与数据挖掘区别 & 深度学习科普

    今天看到这篇文章里面提到如何选择模型,觉得非常好,单独写在这里. 更多的机器学习实战可以看这篇文章:http://www.cnblogs.com/charlesblc/p/6159187.html 另 ...

  8. 『高性能模型』Roofline Model与深度学习模型的性能分析

    转载自知乎:Roofline Model与深度学习模型的性能分析 在真实世界中,任何模型(例如 VGG / MobileNet 等)都必须依赖于具体的计算平台(例如CPU / GPU / ASIC 等 ...

  9. tensorflow模型持久化保存和加载--深度学习-神经网络

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

随机推荐

  1. django 迁移工程数据库无法创建的问题

    1.今天我遇到一个问题在此做笔记记下来 2.我晚上一般是在家练习的,白天会拷贝工程到公司用 3.因为我在家里创建过一次数据库了,通过命令创建,但是无论我怎么修改models都无法创建表,最后只能通过新 ...

  2. 文件IO(存取.txt文件)

    //存文件方法 public void Save_File_Info(string Save_Path) { //根据路径,创建文件和数据流 FileStream FS = new FileStrea ...

  3. [转]Hive 数据类型

    Hive的内置数据类型可以分为两大类:(1).基础数据类型:(2).复杂数据类型.其中,基础数据类型包括:TINYINT,SMALLINT,INT,BIGINT,BOOLEAN,FLOAT,DOUBL ...

  4. 添加ASP.NET网站资源文件夹

    ASP.NET应用程序包含7个默认文件夹,分别为Bin.APP_Code.App_GlobalResources.App_LocalResources.App_WebReferences.App_Br ...

  5. render函数的使用

    render函数的几种使用方法最近使用element-ui的tree组件时,需要在组件树的右边自定义一些图标,就想到了用render函数去渲染. <el-tree class="p-t ...

  6. WEB下渗透测试经验技巧(全)[转载]

    Nuclear’Atk 整理的: 上传漏洞拿shell: 1.直接上传asp.asa.jsp.cer.php.aspx.htr.cdx….之类的马,拿到shell.2.就是在上传时在后缀后面加空格或者 ...

  7. Storm框架入门

    1 Topology构成 和同样是计算框架的Mapreduce相比,Mapreduce集群上运行的是Job,而Storm集群上运行的是Topology.但是Job在运行结束之后会自行结束,Topolo ...

  8. Azure 项目构建 – 部署高可用的 Python Web 应用

    Python 以其优美,清晰,简单的特性在全世界广泛流行,成为最主流的编程语言之一.Azure 平台针对 Python 提供了非常完备的支持.本项目中,您将了解如何构造和部署基于 Azure Web ...

  9. 朝圣Java(问题集锦)之:The Apache Tomcat installation at this directory is version 8.5.32. A Tomcat 8.0 inst

    最近开始学Java了.有C#底子,但是学起来Java还是很吃力,感觉别人架好了各种包,自己只要调用就行了,结果还有各种bug出现.掩面中. 启动Tomcat的时候,报错The Apache Tomca ...

  10. mysql主从复制测试

    mysql主从复制测试: 1. 配置主服务器:在主库上面添加复制账号GRANT REPLICATION SLAVE on *.* to 'mark'@'%' identified by 'mark' ...