本文已在公众号机器视觉与算法建模发布,转载请联系我。

使用TensorFlow的基本流程

本篇文章将介绍使用tensorflow的训练模型的基本流程,包括制作读取TFRecord,训练和保存模型,读取模型。

准备

TFRecord

TensorFlow提供了一种统一的格式来存储数据,这个格式就是TFRecord.

  1. message Example {
  2. Features features = 1;
  3. };
  4. message Features{
  5. map<string,Feature> featrue = 1;
  6. };
  7. message Feature{
  8. oneof kind{
  9. BytesList bytes_list = 1;
  10. FloatList float_list = 2;
  11. Int64List int64_list = 3;
  12. }
  13. };

从代码中我们可以看到, tf.train.Example 包含了一个字典,它的键是一个字符串,值为Feature,Feature可以取值为字符串(BytesList)、浮点数列表(FloatList)、整型数列表(Int64List)。

写入一个TFRecord一般分为三步:

  • 读取需要转化的数据
  • 将数据转化为Example Protocol Buffer,并写入这个数据结构
  • 通过将数据转化为字符串后,通过TFRecordWriter写出

方法一

这次我们的数据是分别保存在多个文件夹下的,因此读取数据最直接的方法是遍历目录下所有文件,然后读入写出TFRecord文件。该方法对应文件MakeTFRecord.py,我们来看关键代码

  1. filenameTrain = 'TFRecord/train.tfrecords'
  2. filenameTest = 'TFRecord/test.tfrecords'
  3. writerTrain = tf.python_io.TFRecordWriter(filenameTrain)
  4. writerTest = tf.python_io.TFRecordWriter(filenameTest)
  5. folders = os.listdir(HOME_PATH)
  6. for subFoldersName in folders:
  7. label = transform_label(subFoldersName)
  8. path = os.path.join(HOME_PATH, subFoldersName) # 文件夹路径
  9. subFoldersNameList = os.listdir(path)
  10. i = 0
  11. for imageName in subFoldersNameList:
  12. imagePath = os.path.join(path, imageName)
  13. images = cv2.imread(imagePath)
  14. res = cv2.resize(images, (128, 128), interpolation=cv2.INTER_CUBIC)
  15. image_raw_data = res.tostring()
  16. example = tf.train.Example(features=tf.train.Features(feature={
  17. 'label': _int64_feature(label),
  18. 'image_raw': _bytes_feature(image_raw_data)
  19. }))
  20. if i <= len(subFoldersNameList) * 3 / 4:
  21. writerTrain.write(example.SerializeToString())
  22. else:
  23. writerTest.write(example.SerializeToString())
  24. i += 1

在做数据的时候,我打算将3/4的数据用做训练集,剩下的1/4数据作为测试集,方便起见,将其保存为两个文件。

基本流程就是遍历Fnt目录下的所有文件夹,再进入子文件夹遍历其目录下的图片文件,然后用OpenCV的imread方法将其读入,再将图片数据转化为字符串。在TFRecord提供的数据结构中`_bytes_feature'是存储字符串的。

以上将图片成功读入并写入了TFRecord的数据结构中,那图片对应的标签怎么办呢?

  1. def transform_label(folderName):
  2. label_dict = {
  3. 'Sample001': 0,
  4. 'Sample002': 1,
  5. 'Sample003': 2,
  6. 'Sample004': 3,
  7. 'Sample005': 4,
  8. 'Sample006': 5,
  9. 'Sample007': 6,
  10. 'Sample008': 7,
  11. 'Sample009': 8,
  12. 'Sample010': 9,
  13. 'Sample011': 10,
  14. }
  15. return label_dict[folderName]

我建立了一个字典,由于一个文件下的图片都是同一类的,所以将图片对应的文件夹名字与它所对应的标签,产生映射关系。代码中label = transform_label(subFoldersName)通过该方法获得,图片的标签。

方法二

在使用方法一产生的数据训练模型,会发现非常容易产生过拟合。因为我们在读数据的时候是将它打包成batch读入的,虽然可以使用tf.train.shuffle_batch方法将队列中的数据打乱再读入,但是由于一个类中的数据过多,会导致即便打乱后也是同一个类中的数据。例如:数字0有1000个样本,假设你读取的队列长达1000个,这样即便打乱队列后读取的图片任然是0。这在训练时容易过拟合。为了避免这种情况发生,我的想法是在做数据时将图片打乱后写入。对应文件MakeTFRecord2.py,关键代码如下

  1. folders = os.listdir(HOME_PATH)
  2. for subFoldersName in folders:
  3. path = os.path.join(HOME_PATH, subFoldersName) # 文件夹路径
  4. subFoldersNameList = os.listdir(path)
  5. for imageName in subFoldersNameList:
  6. imagePath = os.path.join(path, imageName)
  7. totalList.append(imagePath)
  8. # 产生一个长度为图片总数的不重复随机数序列
  9. dictlist = random.sample(range(0, len(totalList)), len(totalList))
  10. print(totalList[0].split('\\')[1].split('-')[0]) # 这是图片对应的类别
  11. i = 0
  12. for path in totalList:
  13. images = cv2.imread(totalList[dictlist[i]])
  14. res = cv2.resize(images, (128, 128), interpolation=cv2.INTER_CUBIC)
  15. image_raw_data = res.tostring()
  16. label = transform_label(totalList[dictlist[i]].split('\\')[1].split('-')[0])
  17. print(label)
  18. example = tf.train.Example(features=tf.train.Features(feature={
  19. 'label': _int64_feature(label),
  20. 'image_raw': _bytes_feature(image_raw_data)
  21. }))
  22. if i <= len(totalList) * 3 / 4:
  23. writerTrain.write(example.SerializeToString())
  24. else:
  25. writerTest.write(example.SerializeToString())
  26. i += 1

基本过程:遍历目录下所有的图片,将它的路径加入一个大的列表。通过一个不重复的随机数序列,来控制使用哪张图片。这就达到随机的目的。

怎么获取标签呢?图片文件都是类型-序号这个形式命名的,这里通过获取它的类型名,建立字典产生映射关系。

  1. def transform_label(imgType):
  2. label_dict = {
  3. 'img001': 0,
  4. 'img002': 1,
  5. 'img003': 2,
  6. 'img004': 3,
  7. 'img005': 4,
  8. 'img006': 5,
  9. 'img007': 6,
  10. 'img008': 7,
  11. 'img009': 8,
  12. 'img010': 9,
  13. 'img011': 10,
  14. }
  15. return label_dict[imgType]

原尺寸图片CNN

对应CNN_train.py文件

训练的时候怎么读取TFRecord数据呢,参考以下代码

  1. # 读训练集数据
  2. def read_train_data():
  3. reader = tf.TFRecordReader()
  4. filename_train = tf.train.string_input_producer(["TFRecord128/train.tfrecords"])
  5. _, serialized_example_test = reader.read(filename_train)
  6. features = tf.parse_single_example(
  7. serialized_example_test,
  8. features={
  9. 'label': tf.FixedLenFeature([], tf.int64),
  10. 'image_raw': tf.FixedLenFeature([], tf.string),
  11. }
  12. )
  13. img_train = features['image_raw']
  14. images_train = tf.decode_raw(img_train, tf.uint8)
  15. images_train = tf.reshape(images_train, [128, 128, 3])
  16. labels_train = tf.cast(features['label'], tf.int64)
  17. labels_train = tf.cast(labels_train, tf.int64)
  18. labels_train = tf.one_hot(labels_train, 10)
  19. return images_train, labels_train

通过features[键名]的方式将存入的数据读取出来,键名和数据类型要与写入的保持一致。

关于这里的卷积神经网络,我是参考王学长培训时的代码写的。当然照搬肯定不行,会遇到loss NaN的情况,我解决的方法是仿照AlexNet中,在卷积后加入LRN层,进行局部响应归一化。在设置参数时,加入l2正则项。关键代码如下

  1. def weights_with_loss(shape, stddev, wl):
  2. var = tf.truncated_normal(stddev=stddev, shape=shape)
  3. if wl is not None:
  4. weight_loss = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss')
  5. tf.add_to_collection('losses', weight_loss)
  6. return tf.Variable(var)
  7. def net(image, drop_pro):
  8. W_conv1 = weights_with_loss([5, 5, 3, 32], 5e-2, wl=0.0)
  9. b_conv1 = biasses([32])
  10. conv1 = tf.nn.relu(conv(image, W_conv1) + b_conv1)
  11. pool1 = max_pool_2x2(conv1)
  12. norm1 = tf.nn.lrn(pool1, 4, bias=1, alpha=0.001 / 9.0, beta=0.75)
  13. W_conv2 = weights_with_loss([5, 5, 32, 64], stddev=5e-2, wl=0.0)
  14. b_conv2 = biasses([64])
  15. conv2 = tf.nn.relu(conv(norm1, W_conv2) + b_conv2)
  16. norm2 = tf.nn.lrn(conv2, 4, bias=1, alpha=0.001 / 9.0, beta=0.75)
  17. pool2 = max_pool_2x2(norm2)
  18. W_conv3 = weights_with_loss([5, 5, 64, 128], stddev=0.04, wl=0.004)
  19. b_conv3 = biasses([128])
  20. conv3 = tf.nn.relu(conv(pool2, W_conv3) + b_conv3)
  21. pool3 = max_pool_2x2(conv3)
  22. W_conv4 = weights_with_loss([5, 5, 128, 256], stddev=1 / 128, wl=0.004)
  23. b_conv4 = biasses([256])
  24. conv4 = tf.nn.relu(conv(pool3, W_conv4) + b_conv4)
  25. pool4 = max_pool_2x2(conv4)
  26. image_raw = tf.reshape(pool4, shape=[-1, 8 * 8 * 256])
  27. # 全连接层
  28. fc_w1 = weights_with_loss(shape=[8 * 8 * 256, 1024], stddev=1 / 256, wl=0.0)
  29. fc_b1 = biasses(shape=[1024])
  30. fc_1 = tf.nn.relu(tf.matmul(image_raw, fc_w1) + fc_b1)
  31. # drop-out层
  32. drop_out = tf.nn.dropout(fc_1, drop_pro)
  33. fc_2 = weights_with_loss([1024, 10], stddev=0.01, wl=0.0)
  34. fc_b2 = biasses([10])
  35. return tf.matmul(drop_out, fc_2) + fc_b2

128x128x3原图训练过程



在验证集上的正确率



这里使用的是1281283的图片,图片比较大,所以我产生了一个想法。在做TFRecord数据的时候,将图片尺寸减半。所以就有了第二种方法。

图片尺寸减半CNN

对应文件CNN_train2.py

与上面那种方法唯一的区别是将图片尺寸128*128*3改成了64*64*3所以我这里就不重复说明了。

64x64x3图片训过程



在验证集上的正确率

保存模型

CNN_train.py中,对应保存模型的代码是

  1. def save_model(sess, step):
  2. MODEL_SAVE_PATH = "./model128/"
  3. MODEL_NAME = "model.ckpt"
  4. saver = tf.train.Saver()
  5. saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=step)
  6. save_model(sess, i)

i是迭代的次数,可以不填其对应的参数global_step

在测试集上检验准确率

对应文件AccuracyTest.py

代码基本与训练的代码相同,这里直接讲怎么恢复模型。关键代码

  1. ckpt = tf.train.get_checkpoint_state(MODEL_PATH)
  2. if ckpt and ckpt.model_checkpoint_path:
  3. #加载模型
  4. saver.restore(sess, ckpt.model_checkpoint_path)

值得一提的是tf.train.get_checkpoint_state该方法会自动找到文件夹下迭代次数最多的模型,然后读入。而saver.restore(sess, ckpt.model_checkpoint_path)方法将恢复,模型在训练时最后一次迭代的变量参数。

查看读入的TFRecord图片

对应文件ReadTest.py

如果你想检查下在制作TFRecord时,图片是否处理的正确,最简单的方法就是将图片显示出来。关键代码如下

  1. def plot_images(images, labels):
  2. for i in np.arange(0, 20):
  3. plt.subplot(5, 5, i + 1)
  4. plt.axis('off')
  5. plt.title(labels[i], fontsize=14)
  6. plt.subplots_adjust(top=1.5)
  7. plt.imshow(images[i])
  8. plt.show()
  9. plot_images(image, label

总结

在摸索过程中遇到很多问题,多亏了王学长耐心帮助,也希望这篇文章能帮助更多人吧。

新手上路,如果有错,欢迎指正,谢谢。

代码已上传github:https://github.com/wmpscc/TensorflowBaseDemo

阅读原文

使用TensorFlow训练模型的基本流程的更多相关文章

  1. 使用TensorFlow训练模型的基本流程【转】

    原文地址(https://github.com/wmpscc/TensorflowBaseDemo ) 本篇文章将介绍使用tensorflow的训练模型的基本流程,包括制作读取TFRecord,训练和 ...

  2. tensorflow之神经网络实现流程总结

    tensorflow之神经网络实现流程总结 1.数据预处理preprocess 2.前向传播的神经网络搭建(包括activation_function和层数) 3.指数下降的learning_rate ...

  3. 深度学习入门篇--手把手教你用 TensorFlow 训练模型

    欢迎大家前往腾讯云技术社区,获取更多腾讯海量技术实践干货哦~ 作者:付越 导语 Tensorflow在更新1.0版本之后多了很多新功能,其中放出了很多用tf框架写的深度网络结构(https://git ...

  4. 如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件

    这篇薄荷主要是讲了如何用tensorflow去训练好一个模型,然后生成相应的pb文件.最后会将如何重新加载这个pb文件. 首先先放出PO主的github: https://github.com/ppp ...

  5. TensorFlow——训练模型的保存和载入的方法介绍

    我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入. 1.保存模型 首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起 ...

  6. Window7安装tensorflow整套环境详细流程

    安装tensorflow方式有好多种,为了方便编译环境以及包管理,这里采用Anaconda平台安装tensorflow. tensorflow官网:http://www.tensorflow.org/ ...

  7. tensorflow搭建神经网络基本流程

    定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...

  8. 基于tensorflow训练模型的显存不足解决办法

    import tensorflow as tfimport osos.environ["CUDA_VISIBLE_DEVICES"] = '0' #指定第一块GPU可用config ...

  9. Tensorflow[架构流程]

    1. tensorflow工作流程 如官网所示: 根据整体架构或者代码功能可以分为: 图1.1 tensorflow架构 如图所示,一层C的api接口将底层的核运行时部分与顶层的多语言接口分离开. 而 ...

随机推荐

  1. jsonArray jsonString list<Object> 之间转换

    1.示例: package com.test.demo.pojo; import lombok.Data; import lombok.experimental.Accessors; /** * @p ...

  2. redis深入学习

    Redis持久化 官方文档: https://redis.io/topics/persistence 1.RDB和AOF优缺点 RDB: 可以在指定的时间间隔内生成数据集的时间点快照,把当前内存里的状 ...

  3. CF1230E Kamil and Making a Stream

    题目大意是求 \(\sum_{v,fa,lca(v,fa)=fa}gcd(v \to fa)\) 容易发现 \(\gcd\) 只会变小,所以根据这玩意是从上到下的,每次暴力一下就可以了,\(\gcd\ ...

  4. 视频会议系统MCU服务器视频传输处理模式

    视频会议系统MCU服务器视频传输处理模式 视频会议系统的组成主要包括终端.MCU服务器.网守等,其中的MCU服务器是整个系统的核心,视频会议系统的性能很大程度取决于MCU服务器的性能,因此MCU服务器 ...

  5. Can you answer these queries? HDU - 4027 有点坑

    #include<iostream> #include<cstring> #include<cstdio> #include<math.h> #incl ...

  6. Linux——基础之vi编辑器,编辑器之神!

    VI编辑器是什么? 我们学了怎么多的命令,都是为了我们的linux系统和远程操作的方便,那么我们现在怎么,编辑服务器上的文件和软件呢? 换句话说,就是我们如何通过命令行去完成文本和代码的编写,和系统的 ...

  7. java网页程序采用 spring 防止 csrf 攻击 转

    银行项目开发过程中,基本都会采用 spring 框架,所以完全可以不用自己开发 filter 去拦截 csrf 攻击的请求,而直接采用实现 spring 提供的 HandlerInterceptor ...

  8. Java各种类

    1.Object类 equals方法 2.Date类 构造方法 成员方法 DateFormat类 Calendar类 3.System类 StringBuilder原理 构造方法 toString方法 ...

  9. jQuery Moblie 问题汇总

    1  使用jQuery动态添加html,没有jQuery Moblie的样式 $("body").html(listview);//以上代码只是把结构加上去了,但是却没有加上jqm ...

  10. 【BZOJ 1022】 [SHOI2008]小约翰的游戏John(Anti_SG)

    Description 小约翰经常和他的哥哥玩一个非常有趣的游戏:桌子上有n堆石子,小约翰和他的哥哥轮流取石子,每个人取 的时候,可以随意选择一堆石子,在这堆石子中取走任意多的石子,但不能一粒石子也不 ...