如下样例基于tensorflow实现了一个简单的3层深度学习入门框架程序,程序主要有如下特性:

1、  基于著名的MNIST手写数字集样例数据:http://yann.lecun.com/exdb/mnist/

2、  加入衰减学习率优化,使得学习率可以根据训练步数指数级减少,在训练后期增加模型稳定性

3、  加入L2正则化,减少各个权重值大小,避免过拟合问题

4、  加入滑动平均模型,提高模型在验证数据上的准确性

网络一共3层,第一层输入层784个节点的输入层,第二层隐藏层有500个节点,第三层输出层有10个节点。

  1. # 导入模块库
  2. import tensorflow as tf
  3. import datetime
  4. import numpy as np
  5.  
  6. # 已经被废弃掉了
  7. #from tensorflow.examples.tutorials.mnist import input_data
  8. from tensorflow.contrib.learn.python.learn.datasets import mnist
  9. from tensorflow.contrib.layers import l2_regularizer
  10.  
  11. # 屏蔽AVX2特性告警信息
  12. import os
  13. os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
  14.  
  15. # 屏蔽mnist.read_data_sets被弃用告警
  16. import logging
  17. class WarningFilter(logging.Filter):
  18. def filter(self, record):
  19. msg = record.getMessage()
  20. tf_warning = 'datasets' in msg
  21. return not tf_warning
  22. logger = logging.getLogger('tensorflow')
  23. logger.addFilter(WarningFilter())
  24.  
  25. # 神经网络结构定义:输入784个特征值,包含一个500个节点的隐藏层,10个节点的输出层
  26. INPUT_NODE = 784
  27. OUTPUT_NODE = 10
  28. LAYER1_NODE = 500
  29.  
  30. # 随机梯度下降法数据集大小为100,训练步骤为30000
  31. BATCH_SIZE = 100
  32. TRAINING_STEPS = 30000
  33.  
  34. # 衰减学习率
  35. LEARNING_RATE_BASE = 0.8
  36. LEARNING_RATE_DECAY = 0.99
  37.  
  38. # L2正则化
  39. REGULARIZATION_RATE = 0.0001
  40. MOVING_AVERAGE_DECAY = 0.99
  41.  
  42. validation_accuracy_rate_list = []
  43. test_accuracy_rate_list = []
  44.  
  45. # 定义前向更新过程
  46. def inference(input_tensor,avg_class,weights1,biase1,weights2,biase2):
  47. if avg_class == None:
  48. layer1 = tf.nn.relu(tf.matmul(input_tensor,weights1) + biase1)
  49. return tf.matmul(layer1,weights2) + biase2
  50. else:
  51. layer1 = tf.nn.relu(tf.matmul(input_tensor,avg_class.average(weights1)) + avg_class.average(biase1))
  52. return tf.matmul(layer1,avg_class.average(weights2)) + avg_class.average(biase2)
  53.  
  54. # 定义训练过程
  55. def train(mnist_datasets):
  56. # 定义输入
  57. x = tf.placeholder(dtype=tf.float32,shape=[None,784])
  58. y_ = tf.placeholder(dtype=tf.float32,shape=[None,10])
  59.  
  60. # 定义训练参数
  61. weights1 = tf.Variable(tf.truncated_normal(shape=[INPUT_NODE,LAYER1_NODE],mean=0.0,stddev=0.1))
  62. biase1 = tf.Variable(tf.constant(value=0.1,dtype=tf.float32,shape=[LAYER1_NODE]))
  63. weights2 = tf.Variable(tf.truncated_normal(shape=[LAYER1_NODE,OUTPUT_NODE],mean=0.0,stddev=0.1))
  64. biase2 = tf.Variable(tf.constant(value=0.1,dtype=tf.float32,shape=[OUTPUT_NODE]))
  65.  
  66. # 前向更新
  67. # 训练数据时,不需要使用滑动平均模型,所以avg_class输入为空
  68. y = inference(x,None,weights1,biase1,weights2,biase2)
  69.  
  70. # 该变量记录训练次数,训练模型时常常需要设置为不可训练的变量,即trainable=False
  71. global_step = tf.Variable(initial_value=0,trainable=False)
  72.  
  73. # 生成滑动平均模型,用于验证
  74. variable_averages = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY,num_updates=global_step)
  75. # 在所有代表神经网络的可训练变量上,应用滑动模型,即所有的可训练变量都有一个影子变量
  76. variable_averages_ops = variable_averages.apply(tf.trainable_variables())
  77.  
  78. # 定义数据验证时,前向更新结果
  79. average_y = inference(x,variable_averages,weights1,biase1,weights2,biase2)
  80.  
  81. # 计算交叉熵
  82. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_,1),logits=y)
  83. cross_entropy_mean = tf.reduce_mean(cross_entropy)
  84.  
  85. # 计算L2正则化损失
  86. regularizer = l2_regularizer(REGULARIZATION_RATE)
  87. regularization = regularizer(weights1) + regularizer(weights2)
  88.  
  89. # 计算总损失Loss
  90. loss = cross_entropy_mean + regularization
  91.  
  92. # 定义指数衰减的学习率
  93. learning_rate = tf.train.exponential_decay(learning_rate=LEARNING_RATE_BASE,global_step=global_step,
  94. decay_steps=mnist_datasets.train.num_examples / BATCH_SIZE,
  95. decay_rate=LEARNING_RATE_DECAY)
  96.  
  97. # 定义随机梯度下降算法来优化损失函数
  98. train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\
  99. .minimize(loss = loss,global_step = global_step)
  100.  
  101. # 每次前向更新完以后,既需要反向更新参数值,又需要对滑动平均模型中影子变量进行更新
  102. # 和train_op = tf.group(train_step,variable_averages_ops)是等价的
  103. with tf.control_dependencies([train_step,variable_averages_ops]):
  104. train_op = tf.no_op(name='train')
  105.  
  106. # 定义验证运算,计算准确率
  107. correct_prediction = tf.equal(tf.argmax(average_y,1),tf.argmax(y_,1))
  108. accuracy = tf.reduce_mean(tf.cast(x=correct_prediction,dtype=tf.float32))
  109.  
  110. with tf.Session() as sess:
  111. init = tf.global_variables_initializer()
  112. sess.run(init)
  113.  
  114. validate_feed = {x:mnist_datasets.validation.images,
  115. y_:mnist_datasets.validation.labels}
  116. test_feed = {x:mnist_datasets.test.images,
  117. y_:mnist_datasets.test.labels}
  118.  
  119. for i in range(TRAINING_STEPS):
  120. # 每1000轮,用测试和验证数据分别对模型进行评估
  121. if i % 1000 == 0:
  122. validate_accuracy_rate = sess.run(accuracy,validate_feed)
  123. print("%s: After %d training steps(s),validation accuracy"
  124. "using average model is %g "%(datetime.datetime.now(),i,validate_accuracy_rate))
  125.  
  126. test_accuracy_rate = sess.run(accuracy, test_feed)
  127. print("%s: After %d training steps(s),test accuracy"
  128. "using average model is %g " % (datetime.datetime.now(),i, test_accuracy_rate))
  129.  
  130. validation_accuracy_rate_list.append(validate_accuracy_rate)
  131. test_accuracy_rate_list.append(test_accuracy_rate)
  132.  
  133. # 获得训练数据
  134. xs,ys = mnist_datasets.train.next_batch(BATCH_SIZE)
  135. sess.run(train_op,feed_dict={x:xs,y_:ys})
  136.  
  137. # 主程序入口
  138. def main(argv=None):
  139. mnist_datasets = mnist.read_data_sets(train_dir='MNIST_data/',one_hot=True)
  140. train(mnist_datasets)
  141. print("validation accuracy rate list:",validation_accuracy_rate_list)
  142. print("test accuracy rate list:",test_accuracy_rate_list)
  143.  
  144. # 模块入口
  145. if __name__ == '__main__':
  146. tf.app.run()

每1000轮,使用测试和验证数据分别对模型进行评估,绘制出如下准确率曲线图,其中蓝色曲线表示验证数据准确率,深红色曲线表示测试数据准确率,不难发现,通过引入滑动平均模型,模型在验证数据上有更好的准确率。

进一步,通过如下代码,我们对两个准确率求解相关系数:

  1. import numpy as np
  2. import math
  3.  
  4. x = np.array([0.1748, 0.9764, 0.9816, 0.9834, 0.982, 0.984, 0.9838, 0.9842, 0.9846, 0.985, 0.9848, 0.9854, 0.9854, 0.9838, 0.9846, 0.9838, 0.9848, 0.9844, 0.9846, 0.9858, 0.9846, 0.9848, 0.9852, 0.9844, 0.9846, 0.9848, 0.9852, 0.9846, 0.9852, 0.9854])
  5. y = np.array([0.1839, 0.9751, 0.9796, 0.9807, 0.9813, 0.9825, 0.983, 0.983, 0.983, 0.9829, 0.9836, 0.9831, 0.9828, 0.9832, 0.9828, 0.9829, 0.9836, 0.9835, 0.9838, 0.9833, 0.9833, 0.9833, 0.9833, 0.9838, 0.9835, 0.9838, 0.9829, 0.9836, 0.9834, 0.984])
  6.  
  7. # 计算相关度
  8. def computeCorrelation(x,y):
  9. xBar = np.mean(x)
  10. yBar = np.mean(y)
  11. SSR = 0.0
  12. varX = 0.0
  13. varY = 0.0
  14. for i in range(0,len(x)):
  15. diffXXbar = x[i] - xBar
  16. difYYbar = y[i] - yBar
  17. SSR += (diffXXbar * difYYbar)
  18. varX += diffXXbar**2
  19. varY += difYYbar**2
  20. SST = math.sqrt(varX * varY)
  21. return SSR/SST
  22.  
  23. # 计算R平方
  24. def polyfit(x,y,degree):
  25. results = {}
  26. coeffs = np.polyfit(x,y,degree)
  27. results['polynomial'] = coeffs.tolist()
  28. p = np.poly1d(coeffs)
  29. yhat = p(x)
  30. ybar = np.sum(y)/len(y)
  31. ssreg = np.sum((yhat - ybar)**2)
  32. sstot = np.sum((y - ybar)**2)
  33. results['determination'] = ssreg/sstot
  34. return results
  35.  
  36. result = computeCorrelation(x,y)
  37. r = result
  38. r_2 = result**2
  39. print("r:",r)
  40. print("r^2:",r*r)
  41. print(polyfit(x,y,1)['determination'])

结果显示,二者相关系数大于0.9999,这意味着在MNIST问题上,完全可以模型在验证数据上的表现来判断模型的优劣。当然,这个仅仅是MNIST数据集上,在其它问题上,还需要具体分析。

  1. C:\Users\Administrator\Anaconda3\python.exe D:/tensorflow-study/sample.py
  2. r: 0.9999913306679183
  3. r^2: 0.999982661410994
  4. 0.9999826614109977

day-19 多种优化模型下的简单神经网络tensorflow示例的更多相关文章

  1. 简单神经网络TensorFlow实现

    学习TensorFlow笔记 import tensorflow as tf #定义变量 #Variable 定义张量及shape w1= tf.Variable(tf.random_normal([ ...

  2. Python小白的数学建模课-19.网络流优化问题

    流在生活中十分常见,例如交通系统中的人流.车流.物流,供水管网中的水流,金融系统中的现金流,网络中的信息流.网络流优化问题是基本的网络优化问题,应用非常广泛. 网络流优化问题最重要的指标是边的成本和容 ...

  3. 通过/proc/sys/net/ipv4/优化Linux下网络性能

    通过/proc/sys/net/ipv4/优化Linux下网络性能 /proc/sys/net/ipv4/优化1)      /proc/sys/net/ipv4/ip_forward该文件表示是否打 ...

  4. MySQL数据库的优化(下)MySQL数据库的高可用架构方案

    MySQL数据库的优化(下)MySQL数据库的高可用架构方案 2011-03-09 08:53 抚琴煮酒 51CTO 字号:T | T 在上一篇MySQL数据库的优化中,我们跟随笔者学习了单机MySQ ...

  5. ios下最简单的正则,RegexKitLite

    ios下最简单的正则,RegexKitLite 1.去RegexKitLite下载类库,解压出来会有一个例子包及2个文件,其实用到的就这2个文件,添加到工程中.备用地址:http://www.coco ...

  6. 小型Web页打包优化(下)

    之前我们推送了一篇小型Web项目打包优化文章,(链接),我们使用了一段时间, 在这过程中我们也一直在思考, 怎么能把结构做的更好.于是我们改造了一版, 把可以改进的地方和可能会出现的问题, 在这一版中 ...

  7. ssdb主从及双主模型配置和简单管理

    ssdb主从及双主模型配置和简单管理 levelDB是一个key->value 的数据存储库,其只能在本地保存数据,支持持久化,并且支持保存非常大的数据,单机redis在保存较大数据的时候数十G ...

  8. 19.Mysql优化数据库对象

    19.优化数据库对象19.1 优化表的数据类型应用设计时需要考虑字段的类型和长度,并留有一定长度冗余.procedure analyse()函数可以对表中列的数据类型提出优化建议.procedure ...

  9. Windows下编译TensorFlow1.3 C++ library及创建一个简单的TensorFlow C++程序

    由于最近比较忙,一直到假期才有空,因此将自己学到的知识进行分享.如果有不对的地方,请指出,谢谢!目前深度学习越来越火,学习.使用tensorflow的相关工作者也越来越多.最近在研究tensorflo ...

随机推荐

  1. 对于PHP绘图技术的理解

    要使用PHP绘图,就得在php.ini文件中设置一下 找到这个位置 ;extension=php_gd2.dll,然后把前面的分号去掉,重启下apache就可以了 几乎每行代码我都写了注释,方便看懂 ...

  2. Difftime

    功 能:返回两个time_t型变量之间的时间间隔,即 计算两个时刻之间的时间差. 用 法: double difftime(time_t time2, time_t time1);

  3. C++练习 | 递归判断二叉树是否同构

    #include <iostream> using namespace std; struct Tree { int data; Tree *lchild; Tree *rchild; } ...

  4. Qt绘制动态曲线

    首先*.pro文件中加一句 QT += charts 然后 mainwindow.cpp文件如下: #include "mainwindow.h" #include "u ...

  5. 【Linux】管理文件系统

    文件系统概念: 文件系统是指文件的组织与管理结构,是一个有关于磁盘中各种有用信息的记录——即是保存以下信息的结构记录表 当前所使用磁盘的容量信息 磁盘的可用信息,包括已占用和剩余的空间: 文件与目录的 ...

  6. python学习笔记(一)学习资料记录

    相关资料网站 1. python3简明教程 适合新学者,因为可以在线操作,并且校验结果,同时还有考试系统.比较基础 2. python数据分析数据科学中文英文工具书籍下载 免费的中英文数据的PDF下载 ...

  7. Zeta--S3 Linux抓取一帧YUV图像后使用硬件编码器编码成H.264

    #include <stdio.h> #include <stdlib.h> #include <string.h> #include <getopt.h&g ...

  8. MAC下绕开百度网盘限速下载的方法,三步操作永久生效

    第一步:下载所需工具:(①②步我放在同一个文件夹,可一起下载,链接失效请留言) 工具地址:链接: https://pan.baidu.com/s/1raicYzM 密码: ve3n ①下载Aria2G ...

  9. 改脚本之dbscaner

    默认的DBscaner只是用了ipy模块支持一个段的解析,但是我想让他加载脚本进行检测 所以,直接看 def __init__(self, target, thread): self.target = ...

  10. 关于makefile中自动产生依赖的理解

    本博文是在学习了<GNU Make中文手册>后记录下来的自己的关于自动产生makefile依赖的语句的理解,向大家分享. <GNU make中文手册>中的相关章节见一下链接: ...