1. import glob
  2. import os.path
  3. import numpy as np
  4. import tensorflow as tf
  5. from tensorflow.python.platform import gfile
  6. import tensorflow.contrib.slim as slim
  7.  
  8. # 加载通过TensorFlow-Slim定义好的inception_v3模型。
  9. import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
  10.  
  11. # 处理好之后的数据文件。
  12. INPUT_DATA = '../../datasets/flower_processed_data.npy'
  13. # 保存训练好的模型的路径。
  14. TRAIN_FILE = 'train_dir/model'
  15. # 谷歌提供的训练好的模型文件地址。因为GitHub无法保存大于100M的文件,所以
  16. # 在运行时需要先自行从Google下载inception_v3.ckpt文件。
  17. CKPT_FILE = '../../datasets/inception_v3.ckpt'
  18.  
  19. # 定义训练中使用的参数。
  20. LEARNING_RATE = 0.0001
  21. STEPS = 300
  22. BATCH = 32
  23. N_CLASSES = 5
  24.  
  25. # 不需要从谷歌训练好的模型中加载的参数。
  26. CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
  27. # 需要训练的网络层参数明层,在fine-tuning的过程中就是最后的全联接层。
  28. TRAINABLE_SCOPES='InceptionV3/Logits,InceptionV3/AuxLogit'
  1. def get_tuned_variables():
  2. exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]
  3.  
  4. variables_to_restore = []
  5. # 枚举inception-v3模型中所有的参数,然后判断是否需要从加载列表中移除。
  6. for var in slim.get_model_variables():
  7. excluded = False
  8. for exclusion in exclusions:
  9. if var.op.name.startswith(exclusion):
  10. excluded = True
  11. break
  12. if not excluded:
  13. variables_to_restore.append(var)
  14. return variables_to_restore
  1. def get_trainable_variables():
  2. scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(',')]
  3. variables_to_train = []
  4.  
  5. # 枚举所有需要训练的参数前缀,并通过这些前缀找到所有需要训练的参数。
  6. for scope in scopes:
  7. variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
  8. variables_to_train.extend(variables)
  9. return variables_to_train
  1. def main():
  2. # 加载预处理好的数据。
  3. processed_data = np.load(INPUT_DATA)
  4. training_images = processed_data[0]
  5. n_training_example = len(training_images)
  6. training_labels = processed_data[1]
  7.  
  8. validation_images = processed_data[2]
  9. validation_labels = processed_data[3]
  10.  
  11. testing_images = processed_data[4]
  12. testing_labels = processed_data[5]
  13. print("%d training examples, %d validation examples and %d testing examples." % (
  14. n_training_example, len(validation_labels), len(testing_labels)))
  15.  
  16. # 定义inception-v3的输入,images为输入图片,labels为每一张图片对应的标签。
  17. images = tf.placeholder(tf.float32, [None, 299, 299, 3], name='input_images')
  18. labels = tf.placeholder(tf.int64, [None], name='labels')
  19.  
  20. # 定义inception-v3模型。因为谷歌给出的只有模型参数取值,所以这里
  21. # 需要在这个代码中定义inception-v3的模型结构。虽然理论上需要区分训练和
  22. # 测试中使用到的模型,也就是说在测试时应该使用is_training=False,但是
  23. # 因为预先训练好的inception-v3模型中使用的batch normalization参数与
  24. # 新的数据会有出入,所以这里直接使用同一个模型来做测试。
  25. with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
  26. logits, _ = inception_v3.inception_v3(
  27. images, num_classes=N_CLASSES, is_training=True)
  28.  
  29. trainable_variables = get_trainable_variables()
  30. # 定义损失函数和训练过程。
  31. tf.losses.softmax_cross_entropy(
  32. tf.one_hot(labels, N_CLASSES), logits, weights=1.0)
  33. total_loss = tf.losses.get_total_loss()
  34. train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(total_loss)
  35.  
  36. # 计算正确率。
  37. with tf.name_scope('evaluation'):
  38. correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
  39. evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  40.  
  41. # 定义加载Google训练好的Inception-v3模型的Saver。
  42. load_fn = slim.assign_from_checkpoint_fn(
  43. CKPT_FILE,
  44. get_tuned_variables(),
  45. ignore_missing_vars=True)
  46.  
  47. # 定义保存新模型的Saver。
  48. saver = tf.train.Saver()
  49.  
  50. with tf.Session() as sess:
  51. # 初始化没有加载进来的变量。
  52. init = tf.global_variables_initializer()
  53. sess.run(init)
  54.  
  55. # 加载谷歌已经训练好的模型。
  56. print('Loading tuned variables from %s' % CKPT_FILE)
  57. load_fn(sess)
  58.  
  59. start = 0
  60. end = BATCH
  61. for i in range(STEPS):
  62. _, loss = sess.run([train_step, total_loss], feed_dict={
  63. images: training_images[start:end],
  64. labels: training_labels[start:end]})
  65.  
  66. if i % 30 == 0 or i + 1 == STEPS:
  67. saver.save(sess, TRAIN_FILE, global_step=i)
  68.  
  69. validation_accuracy = sess.run(evaluation_step, feed_dict={
  70. images: validation_images, labels: validation_labels})
  71. print('Step %d: Training loss is %.1f Validation accuracy = %.1f%%' % (
  72. i, loss, validation_accuracy * 100.0))
  73.  
  74. start = end
  75. if start == n_training_example:
  76. start = 0
  77.  
  78. end = start + BATCH
  79. if end > n_training_example:
  80. end = n_training_example
  81.  
  82. # 在最后的测试数据上测试正确率。
  83. test_accuracy = sess.run(evaluation_step, feed_dict={
  84. images: testing_images, labels: testing_labels})
  85. print('Final test accuracy = %.1f%%' % (test_accuracy * 100))

吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习的更多相关文章

  1. 吴裕雄--天生自然python Google深度学习框架:经典卷积神经网络模型

    import tensorflow as tf INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE_SIZE = 28 NUM_CHANNELS = 1 NUM_LABEL ...

  2. 吴裕雄--天生自然python Google深度学习框架:图像识别与卷积神经网络

  3. 吴裕雄--天生自然python Google深度学习框架:MNIST数字识别问题

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  4. 吴裕雄--天生自然python Google深度学习框架:深度学习与深层神经网络

  5. 吴裕雄--天生自然python Google深度学习框架:TensorFlow实现神经网络

    http://playground.tensorflow.org/

  6. 吴裕雄--天生自然python Google深度学习框架:Tensorflow基础应用

    import tensorflow as tf a = tf.constant([1.0, 2.0], name="a") b = tf.constant([2.0, 3.0], ...

  7. 吴裕雄--天生自然python Google深度学习框架:人工智能、深度学习与机器学习相互关系介绍

  8. 吴裕雄--天生自然神经网络与深度学习实战Python+Keras+TensorFlow:Bellman函数、贪心算法与增强性学习网络开发实践

    !pip install gym import random import numpy as np import matplotlib.pyplot as plt from keras.layers ...

  9. 吴裕雄--天生自然神经网络与深度学习实战Python+Keras+TensorFlow:使用TensorFlow和Keras开发高级自然语言处理系统——LSTM网络原理以及使用LSTM实现人机问答系统

    !mkdir '/content/gdrive/My Drive/conversation' ''' 将文本句子分解成单词,并构建词库 ''' path = '/content/gdrive/My D ...

随机推荐

  1. eclipse导入maven工程,右键没有build path和工程不能自动编译解决方法

    原文链接:https://blog.csdn.net/wusunshine/article/details/52506389 eclipse导入maven工程,右键没有build path解决方法: ...

  2. POJ-3629 模拟

    A - Card Stacking Time Limit:1000MS     Memory Limit:65536KB     64bit IO Format:%I64d & %I64u S ...

  3. i春秋-web-upload(文件内容读取)(“百度杯”九月场)

    提示很明显,flag在flag.php中,所以,任务就是获取flag.php的内容. 方法一:一句话+菜刀(不再叙述) 方法二:上传脚本,使脚本拥有一定权限,再输出flag 先造一个php脚本 < ...

  4. 了解OOM

    1)什么是OOM? OOM,全称“Out Of Memory”,翻译成中文就是“内存用完了”,来源于java.lang.OutOfMemoryError.看下关于的官方说明: Thrown when ...

  5. 实验4&5

    [实验任务四]: 在上网时,我们经常会看到以下这种对话框,要用户输入一个验证码. 1.程序设计思想 先利用Math.random()得到一个整数,然后将其类型转换为字符类型,连接起来生成六位验证字符串 ...

  6. ubuntu16+caffe fast-rcnnCPU运行步骤

    //////////////////////////////////////////////////////////////////////////////////////////////////// ...

  7. 进度4_家庭记账本App

    在上一个博客中,我学习了用Fragment进行数据的传值,但是出现了好多问题,我通过百度查阅资料发现fregment在进行数值传输的时候有的语法不能使用,并且不方便的进行数据库的使用,所以我在原来的家 ...

  8. handler method 参数绑定常用注解

    handler method 参数绑定常用的注解,我们根据他们处理的Request的不同内容部分分为四类: A.处理requet uri 部分(这里指uri template中variable,不含q ...

  9. re模块2

    # 元字符+,*遇到?后就会变为贪婪匹配 print(re.findall('abc+?','abcccccc')) #['abc'] print(re.findall('abc*?','abcccc ...

  10. Okhttp 多次调用同一个方法出现错误java.net.SocketException: Socket closed

    Okhttp 多次调用同一个方法出现错误java.net.SocketException: Socket closed https://blog.csdn.net/QQiqq1314/article/ ...