在介绍这一节之前,需要你对slim模型库有一些基本了解,具体可以参考第二十二节,TensorFlow中的图片分类模型库slim的使用、数据集处理,这一节我们会详细介绍slim模型库下面的一些函数的使用。

一 简介

slim被放在tensorflow.contrib这个库下面,导入的方法如下:

  1. import tensorflow.contrib.slim as slim

这样我们就可以使用slim了,既然说到了,先来了解tensorflow.contrib这个库,tensorflow官方对它的描述是:此目录中的任何代码未经官方支持,可能会随时更改或删除。每个目录下都有指定的所有者。它旨在包含额外功能和贡献,最终会合并到核心TensorFlow中,但其接口可能仍然会发生变化,或者需要进行一些测试,看是否可以获得更广泛的接受。所以slim依然不属于原生tensorflow。

那么什么是slim?slim到底有什么用?

上一节已经讲到slim是一个使构建,训练,评估神经网络变得简单的库。它可以消除原生tensorflow里面很多重复的模板性的代码,让代码更紧凑,更具备可读性。另外slim提供了很多计算机视觉方面的著名模型(VGG, AlexNet等),我们不仅可以直接使用,甚至能以各种方式进行扩展。

slim的子模块及功能介绍:

  • arg_scope: provides a new scope named arg_scope that allows a user to define default arguments for specific operations within that scope.

除了基本的name_scope,variabel_scope外,又加了arg_scope,它是用来控制每一层的默认超参数的。(后面会详细说)

  • data: contains TF-slim's dataset definition, data providers, parallel_reader, and decoding utilities.

貌似slim里面还有一套自己的数据定义,这个跳过,我们用的不多。

  • evaluation: contains routines for evaluating models.

评估模型的一些方法,用的也不多。

  • layers: contains high level layers for building models using tensorflow.

这个比较重要,slim的核心和精髓,一些复杂层的定义。

  • learning: contains routines for training models.

一些训练规则。

  • losses: contains commonly used loss functions.

一些loss。

  • metrics: contains popular evaluation metrics.

评估模型的度量标准。

  • nets: contains popular network definitions such as VGG and AlexNet models.

包含一些经典网络,VGG等,用的也比较多。

  • queues: provides a context manager for easily and safely starting and closing QueueRunners.

文本队列管理,比较有用。

  • regularizers: contains weight regularizers.

包含一些正则规则。

  • variables: provides convenience wrappers for variable creation and manipulation.

这个比较有用,我很喜欢slim管理变量的机制。

二.slim定义模型

在slim中,组合使用variables, layers和scopes可以简洁的定义模型。

1.variable

定义于模型变量。生成一个weight变量, 用truncated normal初始化它, 并使用l2正则化,并将其放置于CPU上, 只需下面的代码即可:

  1. #定义模型变量
  2. weights = slim.model_variable('weights', shape=[10, 10, 3 , 3],
  3. initializer=tf.truncated_normal_initializer(stddev=0.1),
  4. regularizer=slim.l2_regularizer(0.05),
  5. device='/CPU:0')
  6. model_variables = slim.get_model_variables()

原生tensorflow包含两类变量:普通变量和局部变量。大部分变量都是普通变量,它们一旦生成就可以通过使用saver存入硬盘,局部变量只在session中存在,不会保存。

  • slim进一步的区分了变量类型,定义了model_variables(模型变量),这种变量代表了模型的参数。模型变量通过训练或者微调而得到学习,或者在评测或前向传播中可以从ckpt文件中载入。
  • 非模型参数在实际前向传播中不需要的参数,比如global_step。同样的,移动平均反应了模型参数,但它本身不是模型参数。如下:
  1. #常规变量
  2. my_var = slim.variable('my_var',shape=[20, 1],
  3. initializer=tf.zeros_initializer())
  4. #get_variables()得到模型参数和常规参数
  5. regular_variables_and_model_variables = slim.get_variables()

当我们通过slim的layers或着直接使用slim.model_variable创建变量时,tf会将此变量加入tf.GraphKeys.MODEL_VARIABLES这个集合中,当你需要构建自己的变量时,可以通过以下代码
将其加入模型参数。

  1. #Letting TF-Slim know about the additional variable.
  2. slim.add_model_variable(my_var)

2.layers

抽象并封装了常用的层,并且提供了repeat和stack操作,使得定义网络更加方便。
首先让我们看看tensorflow怎么实现一个层,例如卷积层:

  1. #在tensorflow下实现一个层
  2. input_x = tf.placeholder(dtype=tf.float32,shape=[None,224,224,3])
  3. with tf.name_scope('conv1_1') as scope:
  4. weight = tf.Variable(tf.truncated_normal([3, 3, 3, 64],
  5. dtype=tf.float32,
  6. stddev=1e-1),
  7. name='weights')
  8. conv = tf.nn.conv2d(input_x, weight, [1, 1, 1, 1], padding='SAME')
  9. bias = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),
  10. trainable=True, name='biases')
  11. conv1 = tf.nn.relu(tf.nn.bias_add(conv, bias), name=scope)

然后slim的实现:

  1. #在slim实现一层
  2. net = slim.conv2d(input_x, 64, [3, 3], scope='conv1_1')

但这个不是重要的,因为tenorflow目前也有大部分层的简单实现,这里比较吸引人的是slim中的repeat和stack操作:

假设定义三个相同的卷积层:

  1. net = ...
  2. net = slim.conv2d(net, 256, [3, 3], scope='conv2_1')
  3. net = slim.conv2d(net, 256, [3, 3], scope='conv2_2')
  4. net = slim.conv2d(net, 256, [3, 3], scope='conv2_3')
  5. net = slim.max_pool2d(net, [2, 2], scope='pool2')

在slim中的repeat操作可以减少代码量:

  1. net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv2')
  2. net = slim.max_pool2d(net, [2, 2], scope='pool2')

repeat不仅只实现了相同操作相同参数的重复,它还将scope进行了展开,例子中的scope被展开为 'conv2/conv2_1', 'conv2/conv2_2' and 'conv2/conv2_3'。

而stack是处理卷积核或者输出不一样的情况,假设定义三层FC:

  1. #stack的使用 stack是处理卷积核或者输出不一样的情况,
  2. x = tf.placeholder(dtype=tf.float32,shape=[None,784])
  3. x = slim.fully_connected(x, 32, scope='fc/fc_1')
  4. x = slim.fully_connected(x, 64, scope='fc/fc_2')
  5. x = slim.fully_connected(x, 128, scope='fc/fc_3')
  1. #使用stack操作:
  2. x = slim.stack(x, slim.fully_connected, [32, 64, 128], scope='fc')

同理卷积层也一样:

  1. # 普通方法:
  2. net = slim.conv2d(input_x, 32, [3, 3], scope='core/core_1')
  3. net = slim.conv2d(net, 32, [1, 1], scope='core/core_2')
  4. net = slim.conv2d(net, 64, [3, 3], scope='core/core_3')
  5. net = slim.conv2d(net, 64, [1, 1], scope='core/core_4')
  6.  
  7. # 简便方法:
  8. net = slim.stack(input_x, slim.conv2d, [(32, [3, 3]), (32, [1, 1]), (64, [3, 3]), (64, [1, 1])], scope='core')

3.scope

除了tensorflow中的name_scope和variable_scope, tf.slim新增了arg_scope操作,这一操作符可以让定义在这一scope中的操作共享参数,即如不指定参数的话,则使用默认参数。且参数可以被局部覆盖。

如果你的网络有大量相同的参数,如下:

  1. net = slim.conv2d(input_x, 64, [11, 11], 4, padding='SAME',
  2. weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
  3. weights_regularizer=slim.l2_regularizer(0.0005), scope='conv1')
  4. net = slim.conv2d(net, 128, [11, 11], padding='VALID',
  5. weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
  6. weights_regularizer=slim.l2_regularizer(0.0005), scope='conv2')
  7. net = slim.conv2d(net, 256, [11, 11], padding='SAME',
  8. weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
  9. weights_regularizer=slim.l2_regularizer(0.0005), scope='conv3')

然后我们用arg_scope处理一下:

  1. #使用arg_scope
  2. with slim.arg_scope([slim.conv2d], padding='SAME',
  3. weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
  4. weights_regularizer=slim.l2_regularizer(0.0005)):
  5. net = slim.conv2d(input_x, 64, [11, 11], scope='conv1')
  6. net = slim.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2')
  7. net = slim.conv2d(net, 256, [11, 11], scope='conv3')

如上倒数第二行代码,对padding进行了重新赋值。那如果除了卷积层还有其他层呢?那就要如下定义:

  1. with slim.arg_scope([slim.conv2d, slim.fully_connected],
  2. activation_fn=tf.nn.relu,
  3. weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
  4. weights_regularizer=slim.l2_regularizer(0.0005)):
  5. with slim.arg_scope([slim.conv2d], stride=1, padding='SAME'):
  6. net = slim.conv2d(input_x, 64, [11, 11], 4, padding='VALID', scope='conv1')
  7. net = slim.conv2d(net, 256, [5, 5],
  8. weights_initializer=tf.truncated_normal_initializer(stddev=0.03),
  9. scope='conv2')
  10. net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc')

写两个arg_scope就行了。采用如上方法,定义一个VGG也就十几行代码的事了。

  1. #定义一个vgg16网络
  2. def vgg16(inputs):
  3. with slim.arg_scope([slim.conv2d, slim.fully_connected],
  4. activation_fn=tf.nn.relu,
  5. weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
  6. weights_regularizer=slim.l2_regularizer(0.0005)):
  7. net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
  8. net = slim.max_pool2d(net, [2, 2], scope='pool1')
  9. net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
  10. net = slim.max_pool2d(net, [2, 2], scope='pool2')
  11. net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
  12. net = slim.max_pool2d(net, [2, 2], scope='pool3')
  13. net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
  14. net = slim.max_pool2d(net, [2, 2], scope='pool4')
  15. net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
  16. net = slim.max_pool2d(net, [2, 2], scope='pool5')
  17. net = slim.fully_connected(net, 4096, scope='fc6')
  18. net = slim.dropout(net, 0.5, scope='dropout6')
  19. net = slim.fully_connected(net, 4096, scope='fc7')
  20. net = slim.dropout(net, 0.5, scope='dropout7')
  21. net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')
  22. return net

三.训练模型

这里直接选用经典网络。

  1. import tensorflow as tf
  2. vgg = tf.contrib.slim.nets.vgg
  3.  
  4. # Load the images and labels.
  5. images, labels = ...
  6.  
  7. # Create the model.
  8. predictions, _ = vgg.vgg_16(images)
  9.  
  10. # Define the loss functions and get the total loss.
  11. loss = slim.losses.softmax_cross_entropy(predictions, labels)

关于loss,要说一下定义自己的loss的方法,以及注意不要忘记加入到slim中让slim看到你的loss。

还有正则项也是需要手动添加进loss当中的,不然最后计算的时候就不优化正则目标了。

  1. # Load the images and labels.
  2. images, scene_labels, depth_labels, pose_labels = ...
  3.  
  4. # Create the model.
  5. scene_predictions, depth_predictions, pose_predictions = CreateMultiTaskModel(images)
  6.  
  7. # Define the loss functions and get the total loss.
  8. classification_loss = slim.losses.softmax_cross_entropy(scene_predictions, scene_labels)
  9. sum_of_squares_loss = slim.losses.sum_of_squares(depth_predictions, depth_labels)
  10. pose_loss = MyCustomLossFunction(pose_predictions, pose_labels)
  11. slim.losses.add_loss(pose_loss) # Letting TF-Slim know about the additional loss.
  12.  
  13. # The following two ways to compute the total loss are equivalent:
  14. regularization_loss = tf.add_n(slim.losses.get_regularization_losses())
  15. total_loss1 = classification_loss + sum_of_squares_loss + pose_loss + regularization_loss
  16.  
  17. # (Regularization Loss is included in the total loss by default).
  18. total_loss2 = slim.losses.get_total_loss()

slim在learning.py中提供了一个简单而有用的训练模型的工具。我们只需调用slim.learning.create_train_op 和slim.learning.train就可以完成优化过程。

slim.learning.train函数被用来训练神经网络,函数定义如下:

  1. def slim.learning.train(train_op,
  2. logdir,
  3. train_step_fn=train_step,
  4. train_step_kwargs=_USE_DEFAULT,
  5. log_every_n_steps=1,
  6. graph=None,
  7. master='',
  8. is_chief=True,
  9. global_step=None,
  10. number_of_steps=None,
  11. init_op=_USE_DEFAULT,
  12. init_feed_dict=None,
  13. local_init_op=_USE_DEFAULT,
  14. init_fn=None,
  15. ready_op=_USE_DEFAULT,
  16. summary_op=_USE_DEFAULT,
  17. save_summaries_secs=600,
  18. summary_writer=_USE_DEFAULT,
  19. startup_delay_steps=0,
  20. saver=None,
  21. save_interval_secs=600,
  22. sync_optimizer=None,
  23. session_config=None,
  24. trace_every_n_steps=None):

其中部分参数的说明如下:

  • train_op: A `Tensor` that, when executed, will apply the gradients and return the loss value.
  • logdir: The directory where training logs are written to. If None, model checkpoints and summaries will not be written.检查点文件和日志文件的保存路径。
  • number_of_steps: The max number of gradient steps to take during training,as measured by 'global_step': training will stop if global_step is greater than 'number_of_steps'. If the value is left as None, training proceeds indefinitely.默认是一致循环训练。
  • save_summaries_secs: How often, in seconds, to save summaries.
  • summary_writer: `SummaryWriter` to use. Can be `None` to indicate that no summaries should be written. If unset, we create a SummaryWriter.
  • startup_delay_steps: The number of steps to wait for before beginning. Note that this must be 0 if a sync_optimizer is supplied.
  • saver: Saver to save checkpoints. If None, a default one will be created and used.
  • save_interval_secs: How often, in seconds, to save the model to `logdir`.
  1. g = tf.Graph()
  2.  
  3. # Create the model and specify the losses...
  4. ...
  5.  
  6. total_loss = slim.losses.get_total_loss()
  7. optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  8.  
  9. # create_train_op ensures that each time we ask for the loss, the update_ops
  10. # are run and the gradients being computed are applied too.
  11. train_op = slim.learning.create_train_op(total_loss, optimizer)
  12. logdir = ... # Where checkpoints are stored.
  13.  
  14. slim.learning.train(
  15. train_op,
  16. logdir,
  17. number_of_steps=1000, #迭代次数
  18. save_summaries_secs=300, #存summary间隔秒数
  19. save_interval_secs=600) #存模型间隔秒数

四.读取保存模型变量

在迁移学习中,我们经常会用到别人已经训练好的网络和模型参数,这时候我们可能需要从检查点文件中加载部分变量,下面我就会讲解如何加载指定变量。以及当前图的变量名和检查点文件中变量名不一致时怎么办。

1. 从检查恢复部分变量

通过以下功能我们可以载入模型的部分变量:

  1. # Create some variables.
  2. v1 = tf.Variable(..., name="v1")
  3. v2 = tf.Variable(..., name="v2")
  4. ...
  5. # Add ops to restore all the variables.
  6. restorer = tf.train.Saver()
  7.  
  8. # Add ops to restore some variables.
  9. restorer = tf.train.Saver([v1, v2])
  10.  
  11. # Later, launch the model, use the saver to restore variables from disk, and
  12. # do some work with the model.
  13. with tf.Session() as sess:
  14. # Restore variables from disk.
  15. restorer.restore(sess, "/tmp/model.ckpt")
  16. print("Model restored.")
  17. # Do some work with the model
  18. ...

通过这种方式我们可以加载不同变量名的变量!

2 从从检查点恢复部分变量还可以采用其他方法

  1. # Create some variables.
  2. v1 = slim.variable(name="v1", ...)
  3. v2 = slim.variable(name="nested/v2", ...)
  4. ...
  5.  
  6. # Get list of variables to restore (which contains only 'v2'). These are all
  7. # equivalent methods:
  8. #从检查点文件中恢复name='v2'的变量
  9. variables_to_restore = slim.get_variables_by_name("v2")
  10. # or 从检查点文件中恢复name带有2的所有变量
  11. variables_to_restore = slim.get_variables_by_suffix("")
  12. # or 从检查点文件中恢复命名空间scope='nested'的所有变量
  13. variables_to_restore = slim.get_variables(scope="nested")
  14. # or 恢复命名空间scope='nested'的所有变量
  15. variables_to_restore = slim.get_variables_to_restore(include=["nested"])
  16. # or 除了命名空间scope='v1'的变量
  17. variables_to_restore = slim.get_variables_to_restore(exclude=["v1"])
  18.  
  19. # Create the saver which will be used to restore the variables.
  20. restorer = tf.train.Saver(variables_to_restore)
  21.  
  22. with tf.Session() as sess:
  23. # Restore variables from disk.
  24. restorer.restore(sess, "/tmp/model.ckpt")
  25. print("Model restored.")
  26. # Do some work with the model
  27. ...

3.当图的变量名与checkpoint中的变量名不同时,恢复模型参数

当从checkpoint文件中恢复变量时,Saver在checkpoint文件中定位到变量名,并且把它们映射到当前图中的变量中。之前的例子中,我们创建了Saver,并为其提供了变量列表作为参数。这时,在checkpoint文件中定位的变量名,是隐含地从每个作为参数给出的变量的var.op.name而获得的。这一方式在图与checkpoint文件中变量名字相同时,可以很好的工作。而当名字不同时,必须给Saver提供一个将checkpoint文件中的变量名映射到图中的每个变量的字典。

假设我们定义的网络变量是conv1/weights,而从VGG检查点文件加载的变量名为vgg16/conv1/weights,正常load肯定会报错(找不到变量名),但是可以这样:例子见下:

  1. # Assuming that 'conv1/weights' should be restored from 'vgg16/conv1/weights'
  2. def name_in_checkpoint(var):
  3. return 'vgg16/' + var.op.name
  4.  
  5. # Assuming that 'conv1/weights' and 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2'
  6. def name_in_checkpoint(var):
  7. if "weights" in var.op.name:
  8. return var.op.name.replace("weights", "params1")
  9. if "bias" in var.op.name:
  10. return var.op.name.replace("bias", "params2")
  11.  
  12. variables_to_restore = slim.get_model_variables()
  13. variables_to_restore = {name_in_checkpoint(var):var for var in variables_to_restore}
  14. restorer = tf.train.Saver(variables_to_restore)
  15.  
  16. with tf.Session() as sess:
  17. # Restore variables from disk.
  18. restorer.restore(sess, "/tmp/model.ckpt")

4.在一个不同的任务上对网络进行微调

比如我们要将1000类的imagenet分类任务应用于20类的Pascal VOC分类任务中,我们只导入部分层,见下例:

  1. image, label = MyPascalVocDataLoader(...)
  2. images, labels = tf.train.batch([image, label], batch_size=32)
  3.  
  4. # Create the model,20类
  5. predictions = vgg.vgg_16(images,num_classes=20)
  6.  
  7. train_op = slim.learning.create_train_op(...)
  8.  
  9. # Specify where the Model, trained on ImageNet, was saved.
  10. model_path = '/path/to/pre_trained_on_imagenet.checkpoint'
  11.  
  12. # Specify where the new model will live:
  13. log_dir = '/path/to/my_pascal_model_dir/'
  14.  
  15. # Restore only the convolutional layers: 从检查点载入除了fc6,fc7,fc8层之外的参数
  16. variables_to_restore = slim.get_variables_to_restore(exclude=['fc6', 'fc7', 'fc8'])
  17. init_fn = assign_from_checkpoint_fn(model_path, variables_to_restore)
  18.  
  19. # Start training.
  20. slim.learning.train(train_op, log_dir, init_fn=init_fn)

下面会显示一个具体迁移学习的案例。

五 预训练

如果我们仍然是对1000类的数据集进行分类,此时我们可以利用训练好的模型参数进行初始化,然后继续训练。

文件夹结构如下,不懂得话,可以参考第二十二节,TensorFlow中的图片分类模型库slim的使用、数据集处理,其中vgg预训练模型下载地址:https://github.com/tensorflow/models/tree/master/research/slim/#Pretrained

代码如下:

  1. def retrain():
  2. '''
  3. 演示一个VGG16网络的例子
  4. 从头开始训练
  5. '''
  6. batch_size = 128
  7.  
  8. learning_rate = 1e-4
  9.  
  10. #用于保存微调后的检查点文件和日志文件路径
  11. train_log_dir = './log/vgg16/slim_retrain'
  12.  
  13. #官方下载的检查点文件路径
  14. checkpoint_file = './log/vgg16/vgg_16.ckpt'
  15.  
  16. if not tf.gfile.Exists(train_log_dir):
  17. tf.gfile.MakeDirs(train_log_dir)
  18.  
  19. #创建一个图,作为当前图
  20. with tf.Graph().as_default():
  21.  
  22. #加载数据
  23. train_images, train_labels = ....
  24.  
  25. #创建vgg16网络 如果想冻结所有层,可以指定slim.conv2d中的 trainable=False
  26. logits,end_points = vgg.vgg_16(train_images, is_training=True)
  27.  
  28. #交叉熵代价函数
  29. slim.losses.softmax_cross_entropy(logits, onehot_labels=train_labels)
  30. total_loss = slim.losses.get_total_loss()
  31.  
  32. #设置写入到summary中的变量
  33. tf.summary.scalar('losses/total_loss', total_loss)
  34.  
  35. '''
  36. 设置优化器 这里不能指定成Adam优化器,因为我们的官方模型文件中使用的就是GradientDescentOptimizer优化器,
  37. 因此我们要和官方模型一致,如果想使用AdamOptimizer优化器,我们可以在调用完vgg16()网络后,就执行恢复模型。
  38. 而把执行恢复模型的代码放在后面,会由于我们在当前图中定义了一些检查点中不存在变量,恢复时在检查点文件找不
  39. 到变量,因此会报错。
  40. '''
  41. optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  42. #optimizer = tf.train.AdamOptimizer(learning_rate)
  43. # create_train_op that ensures that when we evaluate it to get the loss,
  44. # the update_ops are done and the gradient updates are computed.
  45. train_tensor = slim.learning.create_train_op(total_loss, optimizer)
  46.  
  47. # Restore only the convolutional layers: 从检查点载入除了fc8层之外的参数到当前图
  48. variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
  49. init_fn = slim.assign_from_checkpoint_fn(checkpoint_file, variables_to_restore)
  50.  
  51. print('开始训练!')
  52. #开始训练网络
  53. slim.learning.train(train_tensor,
  54. train_log_dir,
  55. number_of_steps=100, #迭代次数 一次迭代batch_size个样本
  56. save_summaries_secs=300, #存summary间隔秒数
  57. save_interval_secs=300, #存模模型间隔秒数
  58. init_fn=init_fn)

六 微调

有时候我们数据集比较少的时候,可能使用已经训练的网络模型。比如我们想对flowers数据集进行分类。该数据集分成了两部分,训练集数据有3320张,校验集数据有350张。我们使用slim库下已经写好的vgg16网络,并下载对应的模型参数文件。由于模型参数是针对ImageNet数据集训练的得到的,而我们Flower数据集只有5类,因此需要把vgg16最后一层分类数改为5。

这里我们仍然先使用TensorFlow的网络架构来实现微调功能,后面我们再演示一个使用slim库简化之后的代码。

1.TensorFlow实现代码

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Wed Jun 6 11:56:58 2018
  4.  
  5. @author: zy
  6. """
  7.  
  8. '''
  9. 利用已经训练好的vgg16网络对flowers数据集进行微调
  10. 把最后一层分类由2000->5 然后重新训练,我们也可以冻结其它所有层,只训练最后一层
  11. '''
  12.  
  13. from nets import vgg
  14. import matplotlib.pyplot as plt
  15. import tensorflow as tf
  16. import numpy as np
  17. import input_data
  18. import os
  19.  
  20. slim = tf.contrib.slim
  21.  
  22. DATA_DIR = './datasets/data/flowers'
  23. #输出类别
  24. NUM_CLASSES = 5
  25.  
  26. #获取图片大小
  27. IMAGE_SIZE = vgg.vgg_16.default_image_size
  28.  
  29. def flowers_fine_tuning():
  30. '''
  31. 演示一个VGG16的例子
  32. 微调 这里只调整VGG16最后一层全连接层,把1000类改为5类
  33. 对网络进行训练
  34. '''
  35.  
  36. '''
  37. 1.设置参数,并加载数据
  38. '''
  39. #用于保存微调后的检查点文件和日志文件路径
  40. train_log_dir = './log/vgg16/fine_tune'
  41. train_log_file = 'flowers_fine_tune.ckpt'
  42.  
  43. #官方下载的检查点文件路径
  44. checkpoint_file = './log/vgg16/vgg_16.ckpt'
  45.  
  46. #设置batch_size
  47. batch_size = 256
  48.  
  49. learning_rate = 1e-4
  50.  
  51. #训练集数据长度
  52. n_train = 3320
  53. #测试集数据长度
  54. #n_test = 350
  55. #迭代轮数
  56. training_epochs = 3
  57.  
  58. display_epoch = 1
  59.  
  60. if not tf.gfile.Exists(train_log_dir):
  61. tf.gfile.MakeDirs(train_log_dir)
  62.  
  63. #加载数据
  64. train_images, train_labels = input_data.get_batch_images_and_label(DATA_DIR,batch_size,NUM_CLASSES,True,IMAGE_SIZE,IMAGE_SIZE)
  65. test_images, test_labels = input_data.get_batch_images_and_label(DATA_DIR,batch_size,NUM_CLASSES,False,IMAGE_SIZE,IMAGE_SIZE)
  66.  
  67. #获取模型参数的命名空间
  68. arg_scope = vgg.vgg_arg_scope()
  69.  
  70. #创建网络
  71. with slim.arg_scope(arg_scope):
  72.  
  73. '''
  74. 2.定义占位符和网络结构
  75. '''
  76. #输入图片
  77. input_images = tf.placeholder(dtype=tf.float32,shape = [None,IMAGE_SIZE,IMAGE_SIZE,3])
  78. #图片标签
  79. input_labels = tf.placeholder(dtype=tf.float32,shape = [None,NUM_CLASSES])
  80. #训练还是测试?测试的时候弃权参数会设置为1.0
  81. is_training = tf.placeholder(dtype = tf.bool)
  82.  
  83. #创建vgg16网络 如果想冻结所有层,可以指定slim.conv2d中的 trainable=False
  84. logits,end_points = vgg.vgg_16(input_images, is_training=is_training,num_classes = NUM_CLASSES)
  85. #print(end_points) 每个元素都是以vgg_16/xx命名
  86.  
  87. '''
  88. #从当前图中搜索指定scope的变量,然后从检查点文件中恢复这些变量(即vgg_16网络中定义的部分变量)
  89. #如果指定了恢复检查点文件中不存在的变量,则会报错 如果不知道检查点文件有哪些变量,我们可以打印检查点文件查看变量名
  90. params = []
  91. conv1 = slim.get_variables(scope="vgg_16/conv1")
  92. params.extend(conv1)
  93. conv2 = slim.get_variables(scope="vgg_16/conv2")
  94. params.extend(conv2)
  95. conv3 = slim.get_variables(scope="vgg_16/conv3")
  96. params.extend(conv3)
  97. conv4 = slim.get_variables(scope="vgg_16/conv4")
  98. params.extend(conv4)
  99. conv5 = slim.get_variables(scope="vgg_16/conv5")
  100. params.extend(conv5)
  101. fc6 = slim.get_variables(scope="vgg_16/fc6")
  102. params.extend(fc6)
  103. fc7 = slim.get_variables(scope="vgg_16/fc7")
  104. params.extend(fc7)
  105. '''
  106.  
  107. # Restore only the convolutional layers: 从检查点载入当前图除了fc8层之外所有变量的参数
  108. params = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
  109. #用于恢复模型 如果使用这个保存或者恢复的话,只会保存或者恢复指定的变量
  110. restorer = tf.train.Saver(params)
  111.  
  112. #预测标签
  113. pred = tf.argmax(logits,axis=1)
  114.  
  115. '''
  116. 3 定义代价函数和优化器
  117. '''
  118. #代价函数
  119. cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=input_labels,logits=logits))
  120.  
  121. #设置优化器
  122. optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
  123.  
  124. #预测结果评估
  125. correct = tf.equal(pred,tf.argmax(input_labels,1)) #返回一个数组 表示统计预测正确或者错误
  126. accuracy = tf.reduce_mean(tf.cast(correct,tf.float32)) #求准确率
  127.  
  128. num_batch = int(np.ceil(n_train / batch_size))
  129.  
  130. #用于保存检查点文件
  131. save = tf.train.Saver(max_to_keep=1)
  132.  
  133. #恢复模型
  134. with tf.Session() as sess:
  135. sess.run(tf.global_variables_initializer())
  136.  
  137. #检查最近的检查点文件
  138. ckpt = tf.train.latest_checkpoint(train_log_dir)
  139. if ckpt != None:
  140. save.restore(sess,ckpt)
  141. print('从上次训练保存后的模型继续训练!')
  142. else:
  143. restorer.restore(sess, checkpoint_file)
  144. print('从官方模型加载训练!')
  145.  
  146. #创建一个协调器,管理线程
  147. coord = tf.train.Coordinator()
  148.  
  149. #启动QueueRunner, 此时文件名才开始进队。
  150. threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  151.  
  152. '''
  153. 4 查看预处理之后的图片
  154. '''
  155. imgs, labs = sess.run([train_images, train_labels])
  156. print('原始训练图片信息:',imgs.shape,labs.shape)
  157. show_img = np.array(imgs[0],dtype=np.uint8)
  158. plt.imshow(show_img)
  159. plt.title('Original train image')
  160. plt.show()
  161.  
  162. imgs, labs = sess.run([test_images, test_labels])
  163. print('原始测试图片信息:',imgs.shape,labs.shape)
  164. show_img = np.array(imgs[0],dtype=np.uint8)
  165. plt.imshow(show_img)
  166. plt.title('Original test image')
  167. plt.show()
  168.  
  169. print('开始训练!')
  170. for epoch in range(training_epochs):
  171. total_cost = 0.0
  172. for i in range(num_batch):
  173. imgs, labs = sess.run([train_images, train_labels])
  174. _,loss = sess.run([optimizer,cost],feed_dict={input_images:imgs,input_labels:labs,is_training:True})
  175. total_cost += loss
  176.  
  177. #打印信息
  178. if epoch % display_epoch == 0:
  179. print('Epoch {}/{} average cost {:.9f}'.format(epoch+1,training_epochs,total_cost/num_batch))
  180.  
  181. #进行预测处理
  182. imgs, labs = sess.run([test_images, test_labels])
  183. cost_values,accuracy_value = sess.run([cost,accuracy],feed_dict = {input_images:imgs,input_labels:labs,is_training:False})
  184. print('Epoch {}/{} Test cost {:.9f}'.format(epoch+1,training_epochs,cost_values))
  185. print('准确率:',accuracy_value)
  186.  
  187. #保存模型
  188. save.save(sess,os.path.join(train_log_dir,train_log_file),global_step = epoch)
  189. print('Epoch {}/{} 模型保存成功'.format(epoch+1,training_epochs))
  190.  
  191. print('训练完成')
  192.  
  193. #终止线程
  194. coord.request_stop()
  195. coord.join(threads)
  196.  
  197. def flowers_test():
  198. '''
  199. 使用微调好的网络进行测试
  200. '''
  201. '''
  202. 1.设置参数,并加载数据
  203. '''
  204. #微调后的检查点文件和日志文件路径
  205. save_dir = './log/vgg16/fine_tune'
  206.  
  207. #设置batch_size
  208. batch_size = 128
  209.  
  210. #加载数据
  211. train_images, train_labels = input_data.get_batch_images_and_label(DATA_DIR,batch_size,NUM_CLASSES,True,IMAGE_SIZE,IMAGE_SIZE)
  212. test_images, test_labels = input_data.get_batch_images_and_label(DATA_DIR,batch_size,NUM_CLASSES,False,IMAGE_SIZE,IMAGE_SIZE)
  213.  
  214. #获取模型参数的命名空间
  215. arg_scope = vgg.vgg_arg_scope()
  216.  
  217. #创建网络
  218. with slim.arg_scope(arg_scope):
  219.  
  220. '''
  221. 2.定义占位符和网络结构
  222. '''
  223. #输入图片
  224. input_images = tf.placeholder(dtype=tf.float32,shape = [None,IMAGE_SIZE,IMAGE_SIZE,3])
  225. #训练还是测试?测试的时候弃权参数会设置为1.0
  226. is_training = tf.placeholder(dtype = tf.bool)
  227.  
  228. #创建vgg16网络
  229. logits,end_points = vgg.vgg_16(input_images, is_training=is_training,num_classes = NUM_CLASSES)
  230.  
  231. #预测标签
  232. pred = tf.argmax(logits,axis=1)
  233.  
  234. restorer = tf.train.Saver()
  235.  
  236. #恢复模型
  237. with tf.Session() as sess:
  238. sess.run(tf.global_variables_initializer())
  239. ckpt = tf.train.latest_checkpoint(save_dir)
  240. if ckpt != None:
  241. #恢复模型
  242. restorer.restore(sess,ckpt)
  243. print("Model restored.")
  244.  
  245. #创建一个协调器,管理线程
  246. coord = tf.train.Coordinator()
  247.  
  248. #启动QueueRunner, 此时文件名才开始进队。
  249. threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  250.  
  251. '''
  252. 查看预处理之后的图片
  253. '''
  254. imgs, labs = sess.run([test_images, test_labels])
  255. print('原始测试图片信息:',imgs.shape,labs.shape)
  256. show_img = np.array(imgs[0],dtype=np.uint8)
  257. plt.imshow(show_img)
  258. plt.title('Original test image')
  259. plt.show()
  260.  
  261. pred_value = sess.run(pred,feed_dict = {input_images:imgs,is_training:False})
  262. print('预测结果为:',pred_value)
  263. print('实际结果为:',np.argmax(labs,1))
  264. correct = np.equal(pred_value,np.argmax(labs,1))
  265. print('准确率为:', np.mean(correct))
  266.  
  267. #终止线程
  268. coord.request_stop()
  269. coord.join(threads)
  270.  
  271. if __name__ == '__main__':
  272. tf.reset_default_graph()
  273. flowers_fine_tuning()
  274. flowers_test()

这里我在训练的时候,冻结了出输出层之外的所有层,运行结果如下:

三轮之后,我们可以看到准确率大概在60%。

如果我们不冻结其它层,(训练所有层,速度慢),3轮下来,准确率可以达到90%左右。

2.Slim库实现代码

使用slim库简化上面的代码:

  1. def flowers_simple_fine_tuning():
  2. '''
  3. 演示一个VGG16的例子
  4. 微调 这里只调整VGG16最后一层全连接层,把1000类改为5类
  5. 对网络进行训练 使用slim库简化代码
  6. '''
  7. batch_size = 128
  8.  
  9. learning_rate = 1e-4
  10.  
  11. #用于保存微调后的检查点文件和日志文件路径
  12. train_log_dir = './log/vgg16/slim_fine_tune'
  13.  
  14. #官方下载的检查点文件路径
  15. checkpoint_file = './log/vgg16/vgg_16.ckpt'
  16.  
  17. if not tf.gfile.Exists(train_log_dir):
  18. tf.gfile.MakeDirs(train_log_dir)
  19.  
  20. #创建一个图,作为当前图
  21. with tf.Graph().as_default():
  22.  
  23. #加载数据
  24. train_images, train_labels = input_data.get_batch_images_and_label(DATA_DIR,batch_size,NUM_CLASSES,True,IMAGE_SIZE,IMAGE_SIZE)
  25.  
  26. #创建vgg16网络 如果想冻结所有层,可以指定slim.conv2d中的 trainable=False
  27. logits,end_points = vgg.vgg_16(train_images, is_training=True,num_classes = NUM_CLASSES)
  28.  
  29. #交叉熵代价函数
  30. slim.losses.softmax_cross_entropy(logits, onehot_labels=train_labels)
  31. total_loss = slim.losses.get_total_loss()
  32.  
  33. #设置写入到summary中的变量
  34. tf.summary.scalar('losses/total_loss', total_loss)
  35.  
  36. '''
  37. 设置优化器 这里不能指定成Adam优化器,因为我们的官方模型文件中使用的就是GradientDescentOptimizer优化器,
  38. 因此我们要和官方模型一致,如果想使用AdamOptimizer优化器,我们可以在调用完vgg16()网络后,就执行恢复模型。
  39. 而把执行恢复模型的代码放在后面,会由于我们在当前图中定义了一些检查点中不存在变量,恢复时在检查点文件找不
  40. 到变量,因此会报错。
  41. '''
  42. optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  43. #optimizer = tf.train.AdamOptimizer(learning_rate)
  44. # create_train_op that ensures that when we evaluate it to get the loss,
  45. # the update_ops are done and the gradient updates are computed.
  46. train_tensor = slim.learning.create_train_op(total_loss, optimizer)
  47.  
  48. #检查最近的检查点文件
  49. ckpt = tf.train.latest_checkpoint(train_log_dir)
  50. if ckpt != None:
  51. variables_to_restore = slim.get_model_variables()
  52. init_fn = slim.assign_from_checkpoint_fn(ckpt,variables_to_restore)
  53. print('从上次训练保存后的模型继续训练!')
  54. else:
  55. # Restore only the convolutional layers: 从检查点载入除了fc8层之外的参数到当前图
  56. variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
  57. init_fn = slim.assign_from_checkpoint_fn(checkpoint_file, variables_to_restore)
  58. print('从官方模型加载训练!')
  59.  
  60. print('开始训练!')
  61. #开始训练网络
  62. slim.learning.train(train_tensor,
  63. train_log_dir,
  64. number_of_steps=100, #迭代次数 一次迭代batch_size个样本
  65. save_summaries_secs=300, #存summary间隔秒数
  66. save_interval_secs=300, #存模模型间隔秒数
  67. init_fn=init_fn)

上面的代码中我们用到了input_data.py文件,主要负责加载数据集,程序如下:

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Fri Jun 8 08:52:30 2018
  4.  
  5. @author: zy
  6. """
  7.  
  8. '''
  9. 导入flowers数据集
  10. '''
  11.  
  12. from datasets import download_and_convert_flowers
  13. from preprocessing import vgg_preprocessing
  14. from datasets import flowers
  15. import tensorflow as tf
  16.  
  17. slim = tf.contrib.slim
  18.  
  19. def read_flower_image_and_label(dataset_dir,is_training=False):
  20. '''
  21. 下载flower_photos.tgz数据集
  22. 切分训练集和验证集
  23. 并将数据转换成TFRecord格式 5个训练数据文件(3320),5个验证数据文件(350),还有一个标签文件(存放每个数字标签对应的类名)
  24.  
  25. args:
  26. dataset_dir:数据集所在的目录
  27. is_training:设置为TRue,表示加载训练数据集,否则加载验证集
  28. return:
  29. image,label:返回随机读取的一张图片,和对应的标签
  30. '''
  31. download_and_convert_flowers.run(dataset_dir)
  32. '''
  33. 利用slim读取TFRecord中的数据
  34. '''
  35. #选择数据集train
  36. if is_training:
  37. dataset = flowers.get_split(split_name = 'train',dataset_dir=dataset_dir)
  38. else:
  39. dataset = flowers.get_split(split_name = 'validation',dataset_dir=dataset_dir)
  40.  
  41. #创建一个数据provider
  42. provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
  43.  
  44. #通过provider的get随机获取一条样本数据 返回的是两个张量
  45. [image,label] = provider.get(['image','label'])
  46.  
  47. return image,label
  48.  
  49. def get_batch_images_and_label(dataset_dir,batch_size,num_classes,is_training=False,output_height=224, output_width=224,num_threads=10):
  50. '''
  51. 每次取出batch_size个样本
  52.  
  53. 注意:这里预处理调用的是slim库图片预处理的函数,例如:如果你使用的vgg网络,就调用vgg网络的图像预处理函数
  54. 如果你使用的是自己定义的网络,则可以自己写适合自己图像的预处理函数,比如归一化处理也可以使用其他网络已经写好的预处理函数
  55.  
  56. args:
  57. dataset_dir:数据集所在的目录
  58. batch_size:一次取出的样本数量
  59. num_classes:输出的类别 用于对标签one_hot编码
  60. is_training:设置为TRue,表示加载训练数据集,否则加载验证集
  61. output_height:输出图片高度
  62. output_width:输出图片宽
  63.  
  64. return:
  65. images,labels:返回随机读取的batch_size张图片,和对应的标签one_hot编码
  66. '''
  67. #获取单张图像和标签
  68. image,label = read_flower_image_and_label(dataset_dir,is_training)
  69. # 图像预处理 这里要求图片数据是tf.float32类型的
  70. image = vgg_preprocessing.preprocess_image(image, output_height, output_width,is_training=is_training)
  71.  
  72. #缩放处理
  73. #image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  74. #image = tf.image.resize_image_with_crop_or_pad(image, output_height, output_width)
  75.  
  76. # shuffle_batch 函数会将数据顺序打乱
  77. # bacth 函数不会将数据顺序打乱
  78. images, labels = tf.train.batch(
  79. [image, label],
  80. batch_size = batch_size,
  81. capacity=5 * batch_size,
  82. num_threads = num_threads)
  83.  
  84. #one-hot编码
  85. labels = slim.one_hot_encoding(labels,num_classes)
  86.  
  87. return images,labels

3.CNN网络代码,与vgg16微调效果对比

我们这里使用三层的cnn网络对flower数据集进行分类,测试一下其效果如何:

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Fri Jun 8 08:51:45 2018
  4.  
  5. @author: zy
  6. """
  7.  
  8. '''
  9. 使用卷积神经网络训练flowers数据集
  10. 用来和微调后的VGG网络对比
  11. '''
  12.  
  13. import tensorflow as tf
  14. import input_data
  15. import numpy as np
  16.  
  17. slim = tf.contrib.slim
  18.  
  19. def cnn(inputs,num_classes=5):
  20. '''
  21. 定义一个cnn网络结构
  22.  
  23. args:
  24. inputs:输入形状为[batch_size,in_height,in_width,in_channel]
  25. 输入图片大小为224 x 224 x3
  26. num_classes:类别数
  27.  
  28. '''
  29. with tf.variable_scope('cnn'):
  30. with slim.arg_scope([slim.conv2d,slim.fully_connected,slim.max_pool2d,slim.avg_pool2d],
  31. padding='SAME',
  32. ):
  33. net = slim.conv2d(inputs,64,[5,5],4,weights_initializer=tf.truncated_normal_initializer(stddev=0.01),scope='conv1') #batch_size x 56 x 56 x64
  34. net = slim.max_pool2d(net,[2,2],scope='pool1') #batch_size x 28 x 28 x64
  35. net = slim.conv2d(net,64,[3,3],2,weights_initializer=tf.truncated_normal_initializer(stddev=0.01),scope='conv2') #batch_size x 14 x 14 x64
  36. net = slim.max_pool2d(net,[2,2],scope='pool2') #batch_size x 7 x 7 x64
  37. #net = slim.conv2d(net,num_classes,[7,7],7,weights_initializer=tf.truncated_normal_initializer(stddev=0.01),scope='conv3') #batch_size x 1 x 1 x num_classes
  38. net = slim.conv2d(net,num_classes,[1,1],1,weights_initializer=tf.truncated_normal_initializer(stddev=0.01),scope='conv3') #batch_size x7 x 7 xnum_classes
  39. net = slim.avg_pool2d(net,[7,7],7,scope='pool3') #全局平均池化层
  40. net = tf.squeeze(net,[1,2]) #batch_size x num_classes
  41. return net
  42.  
  43. DATA_DIR = './datasets/data/flowers'
  44. #输出类别
  45. NUM_CLASSES = 5
  46. IMAGE_SIZE = 224
  47.  
  48. def flower_cnn():
  49. '''
  50. 使用CNN网络训练flower数据集
  51. '''
  52. #设置batch_size
  53. batch_size = 128
  54.  
  55. learning_rate = 1e-4
  56.  
  57. #训练集数据长度
  58. n_train = 3320
  59. #测试集数据长度
  60. #n_test = 350
  61. #迭代轮数
  62. training_epochs = 20
  63.  
  64. display_epoch = 1
  65.  
  66. #加载数据
  67. train_images, train_labels = input_data.get_batch_images_and_label(DATA_DIR,batch_size,NUM_CLASSES,True,IMAGE_SIZE,IMAGE_SIZE)
  68. test_images, test_labels = input_data.get_batch_images_and_label(DATA_DIR,batch_size,NUM_CLASSES,True,IMAGE_SIZE,IMAGE_SIZE)
  69.  
  70. #定义占位符
  71. input_images = tf.placeholder(dtype=tf.float32,shape = [None,IMAGE_SIZE,IMAGE_SIZE,3])
  72. input_labels = tf.placeholder(dtype=tf.float32,shape = [None,NUM_CLASSES])
  73. is_training = tf.placeholder(dtype = tf.bool)
  74.  
  75. #创建cnn网络
  76. logits = cnn(input_images,num_classes = NUM_CLASSES)
  77.  
  78. #预测标签
  79. pred = tf.argmax(logits,axis=1)
  80.  
  81. #代价函数
  82. cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=input_labels,logits=logits))
  83.  
  84. #设置优化器
  85. optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
  86.  
  87. #预测结果评估
  88. correct = tf.equal(pred,tf.argmax(input_labels,1)) #返回一个数组 表示统计预测正确或者错误
  89. accuracy = tf.reduce_mean(tf.cast(correct,tf.float32)) #求准确率
  90.  
  91. num_batch = int(np.ceil(n_train / batch_size))
  92.  
  93. '''
  94. 启动会话,开始训练
  95. '''
  96. with tf.Session() as sess:
  97. sess.run(tf.global_variables_initializer())
  98.  
  99. #创建一个协调器,管理线程
  100. coord = tf.train.Coordinator()
  101.  
  102. #启动QueueRunner, 此时文件名才开始进队。
  103. threads=tf.train.start_queue_runners(sess=sess,coord=coord)
  104.  
  105. print('开始训练!')
  106. for epoch in range(training_epochs):
  107. total_cost = 0.0
  108. for i in range(num_batch):
  109. imgs, labs = sess.run([train_images, train_labels])
  110. _,loss = sess.run([optimizer,cost],feed_dict={input_images:imgs,input_labels:labs,is_training:True})
  111. total_cost += loss
  112.  
  113. #打印信息
  114. if epoch % display_epoch == 0:
  115. print('Epoch {}/{} Train average cost {:.9f}'.format(epoch+1,training_epochs,total_cost/num_batch))
  116. #进行预测处理
  117. imgs, labs = sess.run([test_images, test_labels])
  118. cost_values,accuracy_value = sess.run([cost,accuracy],feed_dict = {input_images:imgs,input_labels:labs,is_training:False})
  119. print('Epoch {}/{} Test cost {:.9f}'.format(epoch+1,training_epochs,cost_values))
  120. print('准确率:',accuracy_value)
  121.  
  122. print('训练完成')
  123. #终止线程
  124. coord.request_stop()
  125. coord.join(threads)
  126.  
  127. if __name__ == '__main__':
  128. tf.reset_default_graph()
  129. flower_cnn()

我们可以看到20轮下来准确率大概在55%,效果并不是很好。而使用vgg16微调的效果明显更高。

参考文章

[1]【Tensorflow】辅助工具篇——tensorflow slim(TF-Slim)介绍

[2]TF-Slim简介

第二十四节,TensorFlow下slim库函数的使用以及使用VGG网络进行预训练、迁移学习(附代码)的更多相关文章

  1. 风炫安全WEB安全学习第二十四节课 利用XSS钓鱼攻击

    风炫安全WEB安全学习第二十四节课 利用XSS钓鱼攻击 XSS钓鱼攻击 HTTP Basic Authentication认证 大家在登录网站的时候,大部分时候是通过一个表单提交登录信息. 但是有时候 ...

  2. Scala入门到精通——第二十四节 高级类型 (三)

    作者:摆摆少年梦 视频地址:http://blog.csdn.net/wsscy2004/article/details/38440247 本节主要内容 Type Specialization Man ...

  3. [ExtJS5学习笔记]第二十四节 Extjs5中表格gridpanel或者表单数据后台传输remoteFilter设置

    本文地址:http://blog.csdn.net/sushengmiyan/article/details/39667533 官方文档:http://docs.sencha.com/extjs/5. ...

  4. 【php增删改查实例】第二十四节 - 文件上传在项目中的具体应用

    文件上传在项目中,一般有两个用武之地,分别为设置用户的头像和上传附件.本节我们演示如果进行用户头像的上传. 因为一个用户单独并且唯一对应了一个头像,是一对一的关系,所以我们需要去给tm_users表添 ...

  5. 第二十四节:Java语言基础-讲解数组的综合应用

    数组的综合应用 // 打印数组 public static void printArray(int[] arr) { for(int x=0;x<arr.length;x++) { if(x!= ...

  6. php第二十四节课

    三级联动 <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3 ...

  7. Gradle 1.12用户指南翻译——第二十四章. Groovy 插件

    其他章节的翻译请参见: http://blog.csdn.net/column/details/gradle-translation.html 翻译项目请关注Github上的地址: https://g ...

  8. 大白话5分钟带你走进人工智能-第十四节过拟合解决手段L1和L2正则

                                                                               第十四节过拟合解决手段L1和L2正则 第十三节中, ...

  9. 第三百四十四节,Python分布式爬虫打造搜索引擎Scrapy精讲—craw母版l创建自动爬虫文件—以及 scrapy item loader机制

    第三百四十四节,Python分布式爬虫打造搜索引擎Scrapy精讲—craw母版l创建自动爬虫文件—以及 scrapy item loader机制 用命令创建自动爬虫文件 创建爬虫文件是根据scrap ...

随机推荐

  1. How to sign app

    codesign --display --verbose=4 /applications/qq.app codesign --display --entitlements - /application ...

  2. 查询的model里面 一般都要有一个要返回的model做属性 ;查询前要传入得参数,查询后返回的参数 都要集合在一个model中

    查询的model里面 一般都要有一个要返回的model做属性

  3. 51-node-1649齐头并进(最短路)

    题意:中文题,没啥坑点: 解题思路:这道题一开始以为要跑两个最短路,后来发现不用,因为如果给定了铁路的线路,那么,公路一定是n个节点无向图的补图,所以,铁路和公路之间一定有一个是可以直接从1到n的,我 ...

  4. Bash 5.0 发布及其新功能

    导读 邮件列表证实最近发布了 Bash-5.0.而且,令人兴奋的是它还有新的功能和变量.如果你一直在使用 Bash 4.4.XX,那么你一定会喜欢 Bash 的第五个主要版本. 第五个版本侧重于新的 ...

  5. Qt QTimer

    QTimer类提供了重复和单次触发信号的定时器. QTimer类为定时器提供了一个高级别的编程接口.很容易使用:首先,创建一个QTimer,连接timeout()信号到适当的槽函数,并调用start( ...

  6. Editor markdown编辑器

    代码示例网址:http://pandao.github.io/editor.md/examples/index.html 引入文件 <link rel="stylesheet" ...

  7. linq之group by 的使用

    group by var list = from s in _sysBll.GetList(s => s.ParamID == "TraSchType" && ...

  8. BZOJ4873[Shoi2017]寿司餐厅——最大权闭合子图

    题目描述 Kiana最近喜欢到一家非常美味的寿司餐厅用餐.每天晚上,这家餐厅都会按顺序提供n种寿司,第i种寿司有一个 代号ai和美味度di,i,不同种类的寿司有可能使用相同的代号.每种寿司的份数都是无 ...

  9. 微信小程序——使用vue构建小程序【外传】

    文档 http://mpvue.com/mpvue/ 根据文档构建完成的页面如下 更多的,还要继续看下文档~

  10. BZOJ 2049 洞穴勘测

    LCT判断联通性 没什么特别的..还是一个普通的板子题,把LCT当并查集用了,只不过LCT灵活一些,还可以断边 话说自从昨天被维修数列那题榨干之后我现在写splay都不用动脑子了,,机械式的码spla ...