使用之前那个格式写法到后面层数多的话会很乱,所以编写了一个函数创建层,这样看起来可读性高点也更方便整理后期修改维护

  1. #全连接层函数
  2.  
  3. def fcn_layer(
  4. inputs, #输入数据
  5. input_dim, #输入层神经元数量
  6. output_dim,#输出层神经元数量
  7. activation =None): #激活函数
  8.  
  9. W = tf.Variable(tf.truncated_normal([input_dim,output_dim],stddev = 0.1))
  10. #以截断正态分布的随机初始化W
  11. b = tf.Variable(tf.zeros([output_dim]))
  12. #以0初始化b
  13. XWb = tf.matmul(inputs,W)+b # Y=WX+B
  14.  
  15. if(activation==None): #默认不使用激活函数
  16. outputs =XWb
  17. else:
  18. outputs = activation(XWb) #代入参数选择的激活函数
  19. return outputs #返回
  1. #各层神经元数量设置
  2. H1_NN = 256
  3. H2_NN = 64
  4. H3_NN = 32
  5.  
  6. #构建输入层
  7. x = tf.placeholder(tf.float32,[None,784],name='X')
  8. y = tf.placeholder(tf.float32,[None,10],name='Y')
  9. #构建隐藏层
  10. h1 = fcn_layer(x,784,H1_NN,tf.nn.relu)
  11. h2 = fcn_layer(h1,H1_NN,H2_NN,tf.nn.relu)
  12. h3 = fcn_layer(h2,H2_NN,H3_NN,tf.nn.relu)
  13. #构建输出层
  14. forward = fcn_layer(h3,H3_NN,10,None)
  15. pred = tf.nn.softmax(forward)#输出层分类应用使用softmax当作激活函数

这样写方便后期维护 不必对着一群 W1 W2..... Wn

接下来记录一下保存模型的方法

  1. #保存模型
  2. save_step = 5 #储存模型力度
  3. import os
  4. ckpt_dir = '.ckpt_dir/'
  5. if not os.path.exists(ckpt_dir):
  6. os.makedirs(ckpt_dir)

  5轮训练保存一次,以后大模型可以调高点,接下来需要在模型整合处修改一下

  1. saver = tf.train.Saver() #声明完所有变量以后,调用tf.train.Saver开始记录

  2. if(epochs+1) % save_step == 0:
      saver.save(sess, os.path.join(ckpt_dir,"mnist_h256_model_{:06d}.ckpt".format(epochs+1)))#储存模型
      print("mnist_h256_model_{:06d}.ckpt saved".format(epochs+1))#输出情况

至此储存模型结束

接下来是还原模型,要注意还原的模型层数和神经元数量大小需要和之前储存模型的大小一致。

第一步设置保存模型文件的路径

  1. #必须指定存储位置
  2. ckpt_dir = "/ckpt_dir/"

存盘只会保存最近的5次,恢复会恢复最新那一份

  1. #恢复模型,创建会话
  2.  
  3. saver = tf.train.Saver()
  4.  
  5. sess = tf.Session()
  6. init = tf.global_variables_initializer()
  7. sess.run(init)
  8.  
  9. ckpt = tf.train.get_checkpoint_state(ckpt_dir)#选择模型保存路径
  10. if ckpt and ckpt.model_checkpoint_path:
  11. saver.restore(sess ,ckpt.model_checkpoint_path)#从已保存模型中读取参数
  12. print("Restore model from"+ckpt.model_checkpoint_path)

 至此模型恢复完成 下面可以选择继续训练或者评估使用

最后附上完整代码

  1. import tensorflow as tf
  2. import tensorflow.examples.tutorials.mnist.input_data as input_data
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from time import time
  6. mnist = input_data.read_data_sets("data/",one_hot = True)
  7. #导入Tensorflwo和mnist数据集等 常用库
  8. #全连接层函数
  9.  
  10. def fcn_layer(
  11. inputs, #输入数据
  12. input_dim, #输入层神经元数量
  13. output_dim,#输出层神经元数量
  14. activation =None): #激活函数
  15.  
  16. W = tf.Variable(tf.truncated_normal([input_dim,output_dim],stddev = 0.1))
  17. #以截断正态分布的随机初始化W
  18. b = tf.Variable(tf.zeros([output_dim]))
  19. #以0初始化b
  20. XWb = tf.matmul(inputs,W)+b # Y=WX+B
  21.  
  22. if(activation==None): #默认不使用激活函数
  23. outputs =XWb
  24. else:
  25. outputs = activation(XWb) #代入参数选择的激活函数
  26. return outputs #返回
  27. #各层神经元数量设置
  28. H1_NN = 256
  29. H2_NN = 64
  30. H3_NN = 32
  31.  
  32. #构建输入层
  33. x = tf.placeholder(tf.float32,[None,784],name='X')
  34. y = tf.placeholder(tf.float32,[None,10],name='Y')
  35. #构建隐藏层
  36. h1 = fcn_layer(x,784,H1_NN,tf.nn.relu)
  37. h2 = fcn_layer(h1,H1_NN,H2_NN,tf.nn.relu)
  38. h3 = fcn_layer(h2,H2_NN,H3_NN,tf.nn.relu)
  39. #构建输出层
  40. forward = fcn_layer(h3,H3_NN,10,None)
  41. pred = tf.nn.softmax(forward)#输出层分类应用使用softmax当作激活函数
  42. #损失函数使用交叉熵
  43. loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = forward,labels = y))
  44. #设置训练参数
  45. train_epochs = 50
  46. batch_size = 50
  47. total_batch = int(mnist.train.num_examples/batch_size) #随机抽取样本
  48. learning_rate = 0.01
  49. display_step = 1
  50. #优化器
  51. opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)
  52. #定义准确率
  53. correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
  54. accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  55. #保存模型
  56. save_step = 5 #储存模型力度
  57. import os
  58. ckpt_dir = '.ckpt_dir/'
  59. if not os.path.exists(ckpt_dir):
  60. os.makedirs(ckpt_dir)
  61. #开始训练
  62. sess = tf.Session()
  63. init = tf.global_variables_initializer()
  64. saver = tf.train.Saver() #声明完所有变量以后,调用tf.train.Saver开始记录
  65. startTime = time()
  66. sess.run(init)
  67. for epochs in range(train_epochs):
  68. for batch in range(total_batch):
  69. xs,ys = mnist.train.next_batch(batch_size)#读取批次数据
  70. sess.run(opimizer,feed_dict={x:xs,y:ys})#执行批次数据训练
  71.  
  72. #total_batch个批次训练完成后,使用验证数据计算误差与准确率
  73. loss,acc = sess.run([loss_function,accuracy],
  74. feed_dict={
  75. x:mnist.validation.images,
  76. y:mnist.validation.labels})
  77. #输出训练情况
  78. if(epochs+1) % display_step == 0:
  79. epochs += 1
  80. print("Train Epoch:",epochs,
  81. "Loss=",loss,"Accuracy=",acc)
  82. if(epochs+1) % save_step == 0:
  83. saver.save(sess, os.path.join(ckpt_dir,"mnist_h256_model_{:06d}.ckpt".format(epochs+1)))
  84. print("mnist_h256_model_{:06d}.ckpt saved".format(epochs+1))
  85. duration = time()-startTime
  86. print("Trian Finshed takes:","{:.2f}".format(duration))#显示预测耗时
  87. #评估模型
  88. accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
  89. print("model accuracy:",accu_test)
  90. #恢复模型,创建会话
  91.  
  92. saver = tf.train.Saver()
  93.  
  94. sess = tf.Session()
  95. init = tf.global_variables_initializer()
  96. sess.run(init)
  97.  
  98. ckpt = tf.train.get_checkpoint_state(ckpt_dir)#选择模型保存路径
  99. if ckpt and ckpt.model_checkpoint_path:
  100. saver.restore(sess ,ckpt.model_checkpoint_path)#从已保存模型中读取参数
  101. print("Restore model from"+ckpt.model_checkpoint_path)

完整代码

  

基于tensorflow使用全连接层函数实现多层神经网络并保存和读取模型的更多相关文章

  1. 深度学习原理与框架-卷积网络细节-图像分类与图像位置回归任务 1.模型加载 2.串接新的全连接层 3.使用SGD梯度对参数更新 4.模型结果测试 5.各个模型效果对比

    对于图像的目标检测任务:通常分为目标的类别检测和目标的位置检测 目标的类别检测使用的指标:准确率, 预测的结果是类别值,即cat 目标的位置检测使用的指标:欧式距离,预测的结果是(x, y, w, h ...

  2. 基于tensorflow实现mnist手写识别 (多层神经网络)

    标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...

  3. tensorflow 1.0 学习:池化层(pooling)和全连接层(dense)

    池化层定义在 tensorflow/python/layers/pooling.py. 有最大值池化和均值池化. 1.tf.layers.max_pooling2d max_pooling2d( in ...

  4. tensorflow 添加一个全连接层

    对于一个全连接层,tensorflow都为我们封装好了. 使用:tf.layers.dense() tf.layers.dense( inputs, units, activation=None, u ...

  5. keras channels_last、preprocess_input、全连接层Dense、SGD优化器、模型及编译

    channels_last 和 channels_first keras中 channels_last 和 channels_first 用来设定数据的维度顺序(image_data_format). ...

  6. resnet18全连接层改成卷积层

    想要尝试一下将resnet18最后一层的全连接层改成卷积层看会不会对网络效果和网络大小有什么影响 1.首先先对train.py中的更改是: train.py代码可见:pytorch实现性别检测 # m ...

  7. Caffe源码阅读(1) 全连接层

    Caffe源码阅读(1) 全连接层 发表于 2014-09-15   |   今天看全连接层的实现.主要看的是https://github.com/BVLC/caffe/blob/master/src ...

  8. 深度学习基础系列(十)| Global Average Pooling是否可以替代全连接层?

    Global Average Pooling(简称GAP,全局池化层)技术最早提出是在这篇论文(第3.2节)中,被认为是可以替代全连接层的一种新技术.在keras发布的经典模型中,可以看到不少模型甚至 ...

  9. TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例

    TensorFlow之单层(全连接层)实现手写数字识别训练及测试实例: import tensorflow as tf from tensorflow.examples.tutorials.mnist ...

随机推荐

  1. 网站如何接入第三方登录,微信登录和QQ登录:注册认证篇

    第三方登录平台接入 (QQ\微信登录) QQ登录接入 第一步成为QQ应用开发者,审核期限七天 一.所需材料 1.公司注册相关信息 2.营业执照扫描件 微信登录接入 第一步成为微信开发平台开发者,认证费 ...

  2. 4、url控制系统

    第1节:简单配置 参考代码: from django.conf.urls import url from . import views urlpatterns = [ url(r'^articles/ ...

  3. Node.js实战(三)之第一个Web服务器

    这次的示例同样也可以说是HelloWorld,只不过不同的是这是web服务器示例. (1)编写web.js,内容如下: var http = require("http") fun ...

  4. 青岛大学开源OJ平台搭建

    源码地址为:https://github.com/QingdaoU/OnlineJudge 可参考的文档为:https://github.com/QingdaoU/OnlineJudgeDeploy/ ...

  5. Leetcode——198. 打家劫舍

    题目描述:题目链接 这道题目也是一道动态规划的题目: 分析一道动态规划的题目可以将解决问题的思路分为下面三个部分: 1:问题的描述.可以定义数组d[ i ] 用于表示第i -1家可以获得的最大金额. ...

  6. js 自己项目中几种打开或弹出页面的方法

    自己项目中,几种打开或弹出页面的方法(部分需要特定环境下) var blnTop = false;//是否在顶层显示 ///动态生成模态窗体(通过字符串生成) ///strModalId:模态窗体ID ...

  7. mysql中查看一个字段中,有几个逗号

    利用replace.length的内置函数

  8. 【js】AddFavorite/SetHome提醒用户自行操作加入收藏/设置主页

    除了老版本的ie, 就已经没有浏览器能支持js添加收藏夹和设置首页, 浏览器没有开放这个权限了,external.addFavorite这个给禁了. 不过AddFavorite可以起到提醒用户自行操作 ...

  9. jQuery对底部导航进行跳转并高亮显示

    这两天弄一个mui的底部菜单,有点费时了,尝试了用vue写,纯js写,还有根据mui的写,还是有些问题和麻烦.直到看了网上的一些例子,才想明白,之前一直是一种点击触发事件才高亮的思维去做,这个虽然可以 ...

  10. UWP 下载文件显示下载进度

    <Page x:Class="WgscdProject.TestDownloadPage" xmlns="http://schemas.microsoft.com/ ...