TensorFlow初探之简单神经网络训练mnist数据集(TensorFlow2.0代码)
- from __future__ import print_function
- from tensorflow.examples.tutorials.mnist import input_data
- #加载数据集
- mnist = input_data.read_data_sets(r"C:/Users/HPBY/tem/data/",one_hot=True)#加载本地数据 以独热编码形式
- import tensorflow as tf
- #设置超参
- learning_rate = 0.01 #设置学习率
- num_step = #训练次数
- batch_size = #批次
- display_step = #多少次显示一次结果
- #设置网络参数
- n_hidden_1 = #隐含层1 256节点
- n_hidden_2 = #隐含层2 256节点
- num_inputs = #输入一位向量28*
- num_class = #-9的数字一共10个分类
- X = tf.placeholder("float",[None, num_inputs]#占位符784输入 10输出
- Y = tf.placeholder("float",[None, num_class])
- # 储存网络层权重和偏置值
- weights={#随机初始化并权重和偏置值
- 'h1' : tf.Variable(tf.random_normal([num_inputs, n_hidden_1])),
- 'h2' : tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
- 'out' : tf.Variable(tf.random_normal([n_hidden_2, num_class]))
- }
- biases = {
- 'b1' : tf.Variable(tf.random_normal([n_hidden_1])),
- 'b2' : tf.Variable(tf.random_normal([n_hidden_2])),
- 'out' :tf.Variable(tf.random_normal([num_class]))
- }
- #创建模型
- def neural_net(x):
- #全连接隐含层1,2隐含层256个节点
- layer_1 = tf.add(tf.matmul(x,weights['h1']), biases['b1'])#matmul是计算
- layer_2 = tf.add(tf.matmul(layer_1, weights['h2']),biases['b2'])
- out_layer = tf.matmul(layer_2, weights['out'])+biases['out']
- return out_layer
- #构建模型
- logits = neural_net(X)
- #定义损失函数和优化器
- loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
- logits=logits, labels=Y))
- opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
- train_op = opt.minimize(loss_op)
- #评价模型
- correct_pred = tf.equal(tf.argmax(logits,), tf.argmax(Y, ))
- accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
- #初始化变量
- init = tf.global_variables_initializer()
- #开始训练
- with tf.Session() as sess:
- sess.run(init)
- for step in range(,num_step+):
- batch_x, batch_y =mnist.train.next_batch(batch_size)
- sess.run(train_op,feed_dict={X:batch_x, Y:batch_y})
- if step % display_step == or step == :
- loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,
- Y: batch_y})
- print("Step " + str(step) + ", Minibatch Loss= " + \
- "{:.4f}".format(loss) + ", Training Accuracy= " + \
- "{:.3f}".format(acc))
- print("Optimization Finished!")
- # Calculate accuracy for MNIST test images
- print("Testing Accuracy:", \
- sess.run(accuracy, feed_dict={X: mnist.test.images,
- Y: mnist.test.labels}))
数据集来自 http://yann.lecun.com/exdb/mnist/ 以本地加载方式加载数据集
神经网络模型如下:
独热编码参考https://www.cnblogs.com/zongfa/p/9305657.html
很简单的一种编码方式也经常用到
比如我们有“今天刀塔本子出了吗”这个形式的9个不同的词,那么我们独热编码就会形成一个九维的向量,
今是第1个词表示的向量为[1,0,0,0,0,0,0,0,0]
刀是第3个词表示的向量为[0,0,1,0,0,0,0,0,0]
神经网络原理与推导参考程序媛小姐姐的BP神经网络讲解,非常详细:http://www.cnblogs.com/charlotte77/p/5629865.html
TensorFlow初探之简单神经网络训练mnist数据集(TensorFlow2.0代码)的更多相关文章
- TensorFlow——LSTM长短期记忆神经网络处理Mnist数据集
1.RNN(Recurrent Neural Network)循环神经网络模型 详见RNN循环神经网络:https://www.cnblogs.com/pinard/p/6509630.html 2. ...
- 搭建简单模型训练MNIST数据集
# -*- coding = utf-8 -*- # @Time : 2021/3/16 # @Author : pistachio # @File : test1.py # @Software : ...
- Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...
- 使用一层神经网络训练mnist数据集
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...
- mxnet卷积神经网络训练MNIST数据集测试
mxnet框架下超全手写字体识别—从数据预处理到网络的训练—模型及日志的保存 import numpy as np import mxnet as mx import logging logging. ...
- TensorFlow 训练MNIST数据集(2)—— 多层神经网络
在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...
- TensorFlow——CNN卷积神经网络处理Mnist数据集
CNN卷积神经网络处理Mnist数据集 CNN模型结构: 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层池化:池化视野2*2,步长为2 第二层卷积 ...
- 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集
上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...
- 使用caffe训练mnist数据集 - caffe教程实战(一)
个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...
随机推荐
- KiCad 一款强大的 BOM 和 装配图生成插件
KiCad 一款强大的 BOM 和 装配图生成插件 可以生成 BOM 和在线的图形. https://github.com/openscopeproject/InteractiveHtmlBom In ...
- python2和3的区别
一.默认编码 2:ascii 3:utf-8 二.数字 python3无long
- springmvc的dispatchservlet初始化
初始化做的事情,处理下controller的映射关系 https://blog.csdn.net/qq_38410730/article/details/79426673
- Application、QueryString、session、cookie、ViewState、Server.Transfer等
Application: WebForm1.aspx: protected void Button1_Click(object sender, EventArgs e) { ; Response.Re ...
- STM32 USB-三个HID-interface 复合(组合)设备的代码实现-基于固件库(原创)
一.概论: 在STM32_USB-FS-Device_Lib_V4.1.0的Custom_HID工程基础上进行修改: 开发一款设备,有三个HID接口,mouse+pen+自定义HID 其中:0_HID ...
- tensorflow报cudnn错误
E tensorflow/stream_executor/cuda/cuda_dnn.cc:363] Loaded runtime CuDNN library: 7.0.5 but source wa ...
- Azure CosmosDB (13) CosmosDB数据建模
<Windows Azure Platform 系列文章目录> 我们在使用NoSQL的时候,如Azure Cosmos DB,可以非常快速的查询非结构化,或半结构化的数据.我们需要花一些时 ...
- JIRA的邮件通知
提交测试或提交上线申请时发送给相关的开发人员.测试人员.运维人员. 使用插件Notification
- meter命令行模式运行,实时获取压测结果 (没试过 说不定以后要用)
jmeter很小,很快,使用方便,可以在界面运行,可以命令行运行.简单介绍下命令行运行的方式 上面一条命令应该可以满足大部分需求. 使用-R指定节点时,当然要首先在这些节点上启动jmeter-serv ...
- 廖雪峰Java8JUnit单元测试-2使用JUnit-4超时测试
1.超时测试 可以为JUnit的单个测试设置超时: 超时设置1秒:@Test(timeout=1000),单位为毫秒 2.示例 Leibniz定理:PI/4= 1 - 1/3 + 1/5 - 1/7 ...