TensorFlow入门示例教程
本部分的代码目前都是基于GitHub大佬非常详细的TensorFlow的教程上,首先给出链接:
https://github.com/aymericdamien/TensorFlow-Examples/
本人对其中部分代码做了注释和中文翻译,会持续更新,目前包括:
1. 传统多层神经网络用语MNIST数据集分类(代码讲解,翻译)
1. 传统多层神经网络用语MNIST数据集分类(代码讲解,翻译)
- 1 """ Neural Network.
- 2
- 3 A 2-Hidden Layers Fully Connected Neural Network (a.k.a Multilayer Perceptron)
- 4 implementation with TensorFlow. This example is using the MNIST database
- 5 of handwritten digits (http://yann.lecun.com/exdb/mnist/).
- 6
- 7 Links:
- 8 [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).
- 9
- 10 Author: Aymeric Damien
- 11 Project: https://github.com/aymericdamien/TensorFlow-Examples/
- 12 """
- 13
- 14 from __future__ import print_function
- 15
- 16 # Import MNIST data
- 17 # 导入mnist数据集
- 18 from tensorflow.examples.tutorials.mnist import input_data
- 19 mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
- 20
- 21 # 导入tf
- 22 import tensorflow as tf
- 23
- 24 # Parameters
- 25 # 设定各种超参数
- 26 learning_rate = 0.1 # 学习率
- 27 num_steps = 500 # 训练500次
- 28 batch_size = 128 # 每批次取128个样本训练
- 29 display_step = 100 # 每训练100步显示一次
- 30
- 31 # Network Parameters
- 32 # 设定网络的超参数
- 33 n_hidden_1 = 256 # 1st layer number of neurons
- 34 n_hidden_2 = 256 # 2nd layer number of neurons
- 35 num_input = 784 # MNIST data input (img shape: 28*28)
- 36 num_classes = 10 # MNIST total classes (0-9 digits)
- 37
- 38 # tf Graph input
- 39 # tf图的输入,因为不知道到底输入大小是多少,因此设定占位符
- 40 X = tf.placeholder("float", [None, num_input])
- 41 Y = tf.placeholder("float", [None, num_classes])
- 42
- 43 # Store layers weight & bias
- 44 # 初始化w和b
- 45 weights = {
- 46 'h1': tf.Variable(tf.random_normal([num_input, n_hidden_1])),
- 47 'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
- 48 'out': tf.Variable(tf.random_normal([n_hidden_2, num_classes]))
- 49 }
- 50 biases = {
- 51 'b1': tf.Variable(tf.random_normal([n_hidden_1])),
- 52 'b2': tf.Variable(tf.random_normal([n_hidden_2])),
- 53 'out': tf.Variable(tf.random_normal([num_classes]))
- 54 }
- 55
- 56
- 57 # Create model
- 58 # 创建模型
- 59 def neural_net(x):
- 60 # Hidden fully connected layer with 256 neurons
- 61 # 隐藏层1,全连接了256个神经元
- 62 layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
- 63 # Hidden fully connected layer with 256 neurons
- 64 # 隐藏层2,全连接了256个神经元
- 65 layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
- 66 # Output fully connected layer with a neuron for each class
- 67 # 最后作为输出的全连接层,对每一分类连接一个神经元
- 68 out_layer = tf.matmul(layer_2, weights['out']) + biases['out']
- 69 return out_layer
- 70
- 71 # Construct model
- 72 # 开启模型
- 73 # 输入数据X,得到得分向量logits
- 74 logits = neural_net(X)
- 75 # 用softmax分类器将得分向量转变成概率向量
- 76 prediction = tf.nn.softmax(logits)
- 77
- 78 # Define loss and optimizer
- 79 # 定义损失和优化器
- 80 # 交叉熵损失, 求均值得到---->loss_op
- 81 loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
- 82 logits=logits, labels=Y))
- 83 # 优化器使用的是Adam算法优化器
- 84 optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
- 85 # 最小化损失得到---->可以训练的train_op
- 86 train_op = optimizer.minimize(loss_op)
- 87
- 88 # Evaluate model
- 89 # 评估模型
- 90 # tf.equal() 逐个元素进行判断,如果相等就是True,不相等,就是False。
- 91 correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
- 92 # tf.cast() 数据类型转换----> tf.reduce_mean() 再求均值
- 93 accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
- 94
- 95 # Initialize the variables (i.e. assign their default value)
- 96 # 初始化这些变量(作用比如说,给他们分配随机默认值)
- 97 init = tf.global_variables_initializer()
- 98
- 99 # Start training
- 100 # 现在开始训练啦!
- 101 with tf.Session() as sess:
- 102
- 103 # Run the initializer
- 104 # 运行初始化器
- 105 sess.run(init)
- 106
- 107 for step in range(1, num_steps+1):
- 108 # 每批次128个训练,取出这128个对应的data:x;标签:y
- 109 batch_x, batch_y = mnist.train.next_batch(batch_size)
- 110 # Run optimization op (backprop)
- 111 # train_op是优化器得到的可以训练的op,通过反向传播优化模型
- 112 sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
- 113 # 每100步打印一次训练的成果
- 114 if step % display_step == 0 or step == 1:
- 115 # Calculate batch loss and accuracy
- 116 # 计算每批次的是损失和准确度
- 117 loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,
- 118 Y: batch_y})
- 119 print("Step " + str(step) + ", Minibatch Loss= " + \
- 120 "{:.4f}".format(loss) + ", Training Accuracy= " + \
- 121 "{:.3f}".format(acc))
- 122
- 123 print("Optimization Finished!")
- 124
- 125 # Calculate accuracy for MNIST test images
- 126 # 看看在测试集上,我们的模型表现如何
- 127 print("Testing Accuracy:", \
- 128 sess.run(accuracy, feed_dict={X: mnist.test.images,
- 129 Y: mnist.test.labels}))
TensorFlow入门示例教程的更多相关文章
- [转] Struts2入门示例教程
原文地址:http://blog.csdn.net/wwwgeyang777/article/details/19078545/ 回顾Struts2的使用过程,网上搜的教程多多少少都会有点问题,重新记 ...
- DWR 3.0 入门示例教程
DWR(Direct Web Remoting) DWR is a Java library that enables Java on the server and JavaScript in a b ...
- Step by Step 真正从零开始,TensorFlow详细安装入门图文教程!帮你完成那个最难的从0到1
摘要: Step by Step 真正从零开始,TensorFlow详细安装入门图文教程!帮你完成那个最难的从0到1 安装遇到问题请文末留言. 悦动智能公众号:aibbtcom AI这个概念好像突然就 ...
- 【转】真正从零开始,TensorFlow详细安装入门图文教程!(帮你完成那个最难的从0到1)
AI这个概念好像突然就火起来了,年初大比分战胜李世石的AlphaGo成功的吸引了大量的关注,但其实看看你的手机上的语音助手,相机上的人脸识别,今日头条上帮你自动筛选出来的新闻,还有各大音乐软件的歌曲& ...
- 真正从零开始,TensorFlow详细安装入门图文教程!
本文转载地址:https://www.leiphone.com/news/201606/ORlQ7uK3TIW8xVGF.html AI这个概念好像突然就火起来了,年初大比分战胜李世石的AlphaGo ...
- TensorFlow入门,基本介绍,基本概念,计算图,pip安装,helloworld示例,实现简单的神经网络
TensorFlow入门,基本介绍,基本概念,计算图,pip安装,helloworld示例,实现简单的神经网络
- TensorFlow入门教程集合
TensorFlow入门教程之0: BigPicture&极速入门 TensorFlow入门教程之1: 基本概念以及理解 TensorFlow入门教程之2: 安装和使用 TensorFlow入 ...
- ASP.NET Aries 入门开发教程7:DataGrid的行操作(主键操作区)
前言: 抓紧勤奋,再接再励,预计共10篇来结束这个系列. 上一篇介绍:ASP.NET Aries 入门开发教程6:列表数据表格的格式化处理及行内编辑 本篇介绍主键操作区相关内容. 1:什么时候有默认的 ...
- ASP.NET Aries 入门开发教程6:列表数据表格的格式化处理及行内编辑
前言: 为了赶进度,周末也写文了! 前几篇讲完查询框和工具栏,这节讲表格数据相关的操作. 先看一下列表: 接下来我们有很多事情可以做. 1:格式化 - 键值的翻译 对于“启用”列,已经配置了格式化 # ...
随机推荐
- java.lang.IllegalArgumentException: Failed to register servlet with name 'dispatcher'.Check if there is another servlet registered under the same name
前言 一年前接手了一个项目,项目始终无法运行,不管咋样,都无法处理,最近,在一次尝试中,终于成功处理了. 含义 意思很明显了,注册了一个相同的dispatcher,可是找了很久,没有相同的Contro ...
- 如何完整删除DISK DRILL
前两天装了DISK DRILL 右上角出现一个温度提示的图标 现在把DISK DRILL卸载了 但右上角的温度提示图标仍然存在 请问如何删除? 打开系统偏好设置----用户与群----管理员(点 ...
- vi TOhtml:复制保持格式和高亮
1. 文本编辑:在vim中编辑好,复制到opera mail中就会格式错乱,比如:行前空格.缩进消失:2. 代码复制到其他地方,无法显示彩色高亮:找到了一个变通方案:使用TOhtml把vim内容转换为 ...
- Android 如何让EditText不自动获取焦点&隐藏软键盘
感谢大佬:https://blog.csdn.net/a18615971648/article/details/72869345 有时候的项目当中进入某个页面edittext会自动获取焦点弹出软键盘, ...
- Linux 内核引导参数简介
概述 内核引导参数大体上可以分为两类:一类与设备无关.另一类与设备有关.与设备有关的引导参数多如牛毛,需要你自己阅读内核中的相应驱动程序源码以获取其能够接受的引导参数.比如,如果你想知道可以向 AHA ...
- 一键部署mysql 无修改直接cp 执行 100% 有效
一键部署mysql 无修改直接cp 执行 100% 有效 将安装包拖至/opt目录下,编一个脚本文件,然后source执行脚本,等脚本执行完成, 即可使用mysql -u root -p点击 ...
- unittest基础篇1
转自http://blog.csdn.net/huilan_same/article/details/52944782 unittest是xUnit系列框架中的一员,如果你了解xUnit的其他成员,那 ...
- python基础语法_闭包详解
https://www.cnblogs.com/Lin-Yi/p/7305364.html 闭包有啥用??!! 很多伙伴很糊涂,闭包有啥用啊??还这么难懂! 3.1装饰器!!!装饰器是做什么的??其 ...
- 神奇小证明之——世界上只有5个正多面体+构造x3=2a3
今天我彻底放飞自我了...作业还没写完...但就是要总结一些好玩的小性质...谁给我的勇气呢?
- Solution -「CodeChef JUMP」Jump Mission
\(\mathcal{Description}\) Link. 有 \(n\) 个编号 \(1\sim n\) 的格子排成一排,并有三个权值序列 \(\{a_n\},\{h_n\},\{p_n ...