我自己对mnist官方例程进行了部分注解,希望分享出来有助于入门选手更好理解tensorflow的运行机制,可以拷贝到IDE再调试看看,看看具体数据流向还有一部分tensorflow里面用到的库。
我用的是pip安装的tensorflow-GPU-1.13,这段源码原始位置在https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py

代码:

 from __future__ import absolute_import
from __future__ import division
from __future__ import print_function #absl是python标准库内的
from absl import app as absl_app
from absl import flags import tensorflow as tf # pylint: disable=g-bad-import-order from official.mnist import dataset
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers LEARNING_RATE = 1e-4 #参数默认data_format = 'channels_first'
def create_model(data_format):
"""Model to recognize digits in the MNIST dataset. Network structure is equivalent to:
https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
and
https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py But uses the tf.keras API. Args:
data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
typically faster on GPUs while 'channels_last' is typically faster on
CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats Returns:
A tf.keras.Model.
""" #data_format:一个字符串,可以是channels_last(默认)或channels_first,\
# 表示输入中维度的顺序,channels_last对应于具有形状(batch, height, width, channels)\
# 的输入,而channels_first对应于具有形状(batch, channels, height, width)的输入.
#这里感觉输入只有三个维度,默认是单通道图?
if data_format == 'channels_first':
input_shape = [1, 28, 28]
else:
assert data_format == 'channels_last'
input_shape = [28, 28, 1] #将tf.keras.layers.MaxPooling2D传递给max_pool
l = tf.keras.layers
max_pool = l.MaxPooling2D(
(2, 2), (2, 2), padding='same', data_format=data_format)
# The model consists of a sequential chain of layers, so tf.keras.Sequential
# (a subclass of tf.keras.Model) makes for a compact description.
return tf.keras.Sequential(
[
#输入层确保输入的大小符合网络需要[28, 28]->[1, 28, 28]
l.Reshape(
target_shape=input_shape,
input_shape=(28 * 28,)),
#卷积
l.Conv2D(
32,#filters:整数, 输出空间的维数(即卷积中的滤波器数),就是卷积核个数
5,#卷积核大小,这里是5x5
padding='same',
data_format=data_format,
activation=tf.nn.relu),
#最大pooling
max_pool,
#卷积
l.Conv2D(
64,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu),
# 最大pooling
max_pool,
#在保留第0轴的情况下对输入的张量进行Flatten(扁平化),拉直?
l.Flatten(),
#fc 1024 -> units: 该层的神经单元结点数。
l.Dense(1024, activation=tf.nn.relu),
l.Dropout(0.4),
#fc输出
l.Dense(10)
]) #添加了很多参数,指定了一部分的值,数据url,模型url,batch_size等等
def define_mnist_flags():
flags_core.define_base()
flags_core.define_performance(num_parallel_calls=False)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
#自定义项参数都在这里设置了
flags_core.set_defaults(data_dir='./tmp/mnist_data',
model_dir='./tmp/mnist_model',
batch_size=100,
train_epochs=40,
stop_threshold=0.998) def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator."""
# 翻译成中文,注释的意思就是添加一个data_format的参数,下面的Estimator类需要用到
model = create_model(params['data_format'])
image = features
# 来判断一个对象是否是一个已知的类型。
if isinstance(image, dict):
image = features['image'] #测试模式
if mode == tf.estimator.ModeKeys.PREDICT:
logits = model(image, training=False)
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
}
#如果只是测试到这里就返回了
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
}) #训练模式
if mode == tf.estimator.ModeKeys.TRAIN:
#设置LEARNING_RATE
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE) logits = model(image, training=True)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
accuracy = tf.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1)) # Name tensors to be logged with LoggingTensorHook.
tf.identity(LEARNING_RATE, 'learning_rate')
tf.identity(loss, 'cross_entropy')
tf.identity(accuracy[1], name='train_accuracy') # Save accuracy scalar to Tensorboard output.
tf.summary.scalar('train_accuracy', accuracy[1]) return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
if mode == tf.estimator.ModeKeys.EVAL:
logits = model(image, training=False)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=loss,
eval_metric_ops={
'accuracy':
tf.metrics.accuracy(
labels=labels, predictions=tf.argmax(logits, axis=1)),
}) def run_mnist(flags_obj):
"""Run MNIST training and eval loop. Args:
flags_obj: An object containing parsed flag values.
""" #apply_clean是官方例程里面提供的用来清理现存model的方法,\
# 取决于flags_obj.clean(True则清理flags_obj.model_dir内的文件)
model_helpers.apply_clean(flags_obj) #把自定义的实现传给tf.estimator.Estimator
model_function = model_fn #tf.ConfigProto()主要的作用是配置tf.Session的运算方式,比如gpu运算或者cpu运算
session_config = tf.ConfigProto(
#设置线程一个操作内部并行运算的线程数,比如矩阵乘法,如果设置为0,则表示以最优的线程数处理
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
#设置多个操作并行运算的线程数,比如 c = a + b,d = e + f . 可以并行运算
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
#有时候,不同的设备,它的cpu和gpu是不同的,如果将这个选项设置成True,\
# 那么当运行设备不满足要求时,会自动分配GPU或者CPU
allow_soft_placement=True) #获取gpu数目,优化算法等,用于优化
distribution_strategy = distribution_utils.get_distribution_strategy(
flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg) #所有输出(检查点,事件文件等)都被写入model_dir或其子目录.如果model_dir未设置,则使用临时目录.
#可以通过RunConfig对象(包含了有关执行环境的信息)传递config参数.它被传递给model_fn,\
# 如果model_fn有一个名为“config”的参数(和输入函数以相同的方式).如果该config参数未被传递,\
# 则由Estimator进行实例化.不传递配置意味着使用对本地执行有用的默认值.Estimator使配置对模型\
# 可用(例如,允许根据可用的工作人员数量进行专业化),并且还使用其一些字段来控制内部,特别是关于检查点
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy, session_config=session_config) data_format = flags_obj.data_format
#channels_first,即(3,128,128,128)通道数在最前面
#channels_last,即(128,128,128,3)通道数在最后面
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')#判断安装的TF是否支持GPU #estimator类对TensorFlow模型进行训练和计算.
#Estimator对象包装由model_fn指定的模型,其中,给定输入和其他一些参数,返回需要进行训练、计算,或预测的操作.
mnist_classifier = tf.estimator.Estimator(
#这个model_fn是参数名而已
model_fn=model_function,#模型对象
model_dir=flags_obj.model_dir,#模型目录,如果为空会创建一个临时目录
#猜测会去model_dir中寻找数据
config=run_config,#运行的一些参数
params={
'data_format': data_format,#数据类型
}) #这里定义了两个内部函数,只能被这个语句块的内部调用
# Set up training and evaluation input functions.
def train_input_fn():
"""Prepare data for training.""" # When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
ds = dataset.train(flags_obj.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size) # Iterate through the dataset a set number (`epochs_between_evals`) of times
# during each training session.
ds = ds.repeat(flags_obj.epochs_between_evals)
return ds def eval_input_fn():
return dataset.test(flags_obj.data_dir).batch(
flags_obj.batch_size).make_one_shot_iterator().get_next() # Set up hook that outputs training logs every 100 steps.
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks, model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size) # Train and evaluate model.
for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
#训练一次,验证一次
mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results) #如果eval_results['accuracy'] >= flags_obj.stop_threshold 说明模型训练好了
if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
eval_results['accuracy']):
break # Export the model
if flags_obj.export_dir is not None:
#预分配内存,等待数据进入
image = tf.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': image,
})
#输出模型
mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn) def main(_):
run_mnist(flags.FLAGS) if __name__ == '__main__':
#日志
tf.logging.set_verbosity(tf.logging.INFO)
#给flags.FLAGS添加了很多参数项目
define_mnist_flags()
#带参数的启动
absl_app.run(main)

tensorflow--mnist注解的更多相关文章

  1. TensorFlow MNIST(手写识别 softmax)实例运行

    TensorFlow MNIST(手写识别 softmax)实例运行 首先要有编译环境,并且已经正确的编译安装,关于环境配置参考:http://www.cnblogs.com/dyufei/p/802 ...

  2. 学习笔记TF056:TensorFlow MNIST,数据集、分类、可视化

    MNIST(Mixed National Institute of Standards and Technology)http://yann.lecun.com/exdb/mnist/ ,入门级计算机 ...

  3. TensorFlow MNIST 问题解决

    TensorFlow MNIST 问题解决 一.数据集下载错误 错误:IOError: [Errno socket error] [Errno 101] Network is unreachable ...

  4. Mac tensorflow mnist实例

    Mac tensorflow mnist实例 前期主要需要安装好tensorflow的环境,Mac 如果只涉及到CPU的版本,推荐使用pip3,傻瓜式安装,一行命令!代码使用python3. 在此附上 ...

  5. tensorflow MNIST Convolutional Neural Network

    tensorflow MNIST Convolutional Neural Network MNIST CNN 包含的几个部分: Weight Initialization Convolution a ...

  6. tensorflow MNIST新手教程

    官方教程代码如下: import gzip import os import tempfile import numpy from six.moves import urllib from six.m ...

  7. TensorFlow MNIST初级学习

    MNIST MNIST 是一个入门级计算机视觉数据集,包含了很多手写数字图片,如图所示: 数据集中包含了图片和对应的标注,在 TensorFlow 中提供了这个数据集,我们可以用如下方法进行导入: f ...

  8. 学习笔记TF057:TensorFlow MNIST,卷积神经网络、循环神经网络、无监督学习

    MNIST 卷积神经网络.https://github.com/nlintz/TensorFlow-Tutorials/blob/master/05_convolutional_net.py .Ten ...

  9. AI tensorflow MNIST

    MNIST 数据 train-images-idx3-ubyte.gz:训练集图片 train-labels-idx1-ubyte.gz:训练集图片类别 t10k-images-idx3-ubyte. ...

  10. tensorflow——MNIST机器学习入门

    将这里的代码在项目中执行下载并安装数据集. 执行下面代码,训练.并评估模型: # _*_coding:utf-8_*_ import inputdata mnist = inputdata.read_ ...

随机推荐

  1. 2019-04-29 EasyWeb下配置Atomikos+SQLServer分布式数据源

    初次尝试: 配置Mysql时候使用的是Atomikos+DruidXADataSource,所以觉得配置SQLServer应该也是仅仅配置配置就够了,于是引入JDBC驱动依赖后,配置了文件 sprin ...

  2. 【HTML+CSS】在书写代码时的便捷应用

    创建多个相同元素: <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> < ...

  3. bugku 逆向 love

    可以看到将输入先经过 sub_4110BE 这个函数进行加密 然后每一位加上下标本身 再和str2比较 正确就是right 点开加密函数: 关键语句就在这里 我们可以看到是算输入的字符按三个一组能分为 ...

  4. Spring Boot 2.x 编写 RESTful API (二) 校验

    用Spring Boot编写RESTful API 学习笔记 约束规则对子类依旧有效 groups 参数 每个约束用注解都有一个 groups 参数 可接收多个 class 类型 (必须是接口) 不声 ...

  5. 资源预加载preload和资源预读取prefetch简明学习

    前面的话 基于VUE的前端小站改造成SSR服务器端渲染后,HTML文档会自动使用preload和prefetch来预加载所需资源,本文将详细介绍preload和prefetch的使用 资源优先级 在介 ...

  6. [ffmpeg] 解码API

    版本迭代 ffmpeg解码API经过了好几个版本的迭代,上一个版本的API是 解码视频:avcodec_decode_video2 解码音频:avcodec_decode_audio4 我们现在能看到 ...

  7. 使用Python操作MongoDB

    MongoDB简介(摘自:http://www.runoob.com/mongodb/mongodb-intro.html) MongoDB 由C++语言编写,是一个基于分布式文件存储的开源数据库系统 ...

  8. LIS ZOJ - 4028

    http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=4028 memset超时 这题竟然是一个差分约束 好吧呢 对于每一个a[i] ...

  9. v-for 循环element-ui菜单

    vue 使用了element-ui的菜单组件, 这个组件的el-menu-item项上,有一个属性index,值是字符串类型, 在使用v-for的index时,它是一个数值型,所以如果直接写index ...

  10. Python3开发过程常见的异常(最近更新:2019-04-26)

    持续更新中... 常见异常解决方案 1.Base Python3.7环境相关:https://www.cnblogs.com/dotnetcrazy/p/9095793.html 1.1.Indent ...