最近在学习tensorflow serving,但是就这样平淡看代码可能觉得不能真正思考,就想着写个文章看看,自己写给自己的,就像自己对着镜子演讲一样,写个文章也像自己给自己讲课,这样思考的比较深,学到的也比较多,有错欢迎揪出,

minist_saved_model.py 是tensorflow的第一个例子,里面有很多serving的知识,还不了解,现在看。下面是它的入口函数,然后直接跳转到main

if __name__ == '__main__':
tf.app.run()

在main函数里:

首先,是对一些参数取值等的合理性校验:

def main(_):
if len(sys.argv) < 2 or sys.argv[-1].startswith('-'):
print('Usage: mnist_export.py [--training_iteration=x] '
'[--model_version=y] export_dir')
sys.exit(-1)
if FLAGS.training_iteration <= 0:
print 'Please specify a positive value for training iteration.'
sys.exit(-1)
if FLAGS.model_version <= 0:
print 'Please specify a positive value for version number.'
sys.exit(-1)

然后,就开始train model,既然是代码解读加上自己能力还比较弱,简单的我得解读呀,牛人绕道。。。

# Train model
print 'Training model...'
#输入minist数据,这个常见的,里面的源码就是查看有没有数据,没有就在网上
下载下来,然后封装成一个个batch
mnist = mnist_input_data.read_data_sets(FLAGS.work_dir, one_hot=True) #这是创建一个session,Session是Graph和执行者之间的媒介,Session.run()实际
上将graph、fetches、feed_dict序列化到字节数组中进行计算
sess = tf.InteractiveSession() #定义一个占位符,为以后数据等输入留好接口
serialized_tf_example = tf.placeholder(tf.string, name='tf_example') #feature_configs 顾名思义,是特征配置,从形式上看这是一个字典,字典中
初始化key为‘x’,value 是 tf.FixedLenFeature(shape=[784], dtype=tf.float32)的返
回值,而该函数的作用是解析定长的输入特征feature相关配置
feature_configs = {'x': tf.FixedLenFeature(shape=[784], dtype=tf.float32),} #parse_example 常用于稀疏输入数据
tf_example = tf.parse_example(serialized_tf_example, feature_configs) #
x = tf.identity(tf_example['x'], name='x') # use tf.identity() to assign name #因为输出是10类,所以y_设置成None×10
y_ = tf.placeholder('float', shape=[None, 10]) #定义权重变量
w = tf.Variable(tf.zeros([784, 10])) #定义偏置变量
b = tf.Variable(tf.zeros([10])) #对定义的变量进行参数初始化
sess.run(tf.global_variables_initializer()) #对输入的x和权重w,偏置b进行处理
y = tf.nn.softmax(tf.matmul(x, w) + b, name='y') #计算交叉熵
cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) #配置优化函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #这个函数的作用是返回 input 中每行最大的 k 个数,并且返回它们所在位置的索引
values, indices = tf.nn.top_k(y, 10) #这函数返回一个将索引的Tensor映射到字符串的查找表
table = tf.contrib.lookup.index_to_string_table_from_tensor(
tf.constant([str(i) for i in xrange(10)])) #在tabel中查找索引
prediction_classes = table.lookup(tf.to_int64(indices)) #然后开始训练迭代啦
for _ in range(FLAGS.training_iteration):
#获取一个batch数据
batch = mnist.train.next_batch(50)
#计算train_step运算,train_step是优化函数的,这个执行带来的作用就是
根据学习率,最小化cross_entropy,执行一次,就调整参数权重w一次
train_step.run(feed_dict={x: batch[0], y_: batch[1]}) #将得到的y和y_进行对比
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) #对比结果计算准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float')) #运行sess,并使用更新后的最终权重,去做预测,并返回预测结果
print 'training accuracy %g' % sess.run(
accuracy, feed_dict={x: mnist.test.images,
y_: mnist.test.labels})
print 'Done training!'

上面就是训练的过程,就和普通情况下train模型是一样的道理,现在,我们看后面的model export

# Export model
# WARNING(break-tutorial-inline-code): The following code snippet is
# in-lined in tutorials, please update tutorial documents accordingly
# whenever code changes.
#export_path_base基本路径代表你要将model export到哪一个路径下面,
#它的值的获取是传入参数的最后一个,训练命令为:
bazel-bin/tensorflow_serving/example/mnist_saved_model /tmp/mnist_model
那输出的路径就是/tmp/mnist_model
export_path_base = sys.argv[-1] #export_path 真正输出的路径是在基本路径的基础上加上版本号,默认是version=1
export_path = os.path.join(
tf.compat.as_bytes(export_path_base),
tf.compat.as_bytes(str(FLAGS.model_version)))
print 'Exporting trained model to', export_path #官网解释:Builds the SavedModel protocol buffer and saves variables and assets.
builder = tf.saved_model.builder.SavedModelBuilder(export_path) # Build the signature_def_map.
# serialized_tf_example是上面提到的占位的输入,
#其当时定义为tf.placeholder(tf.string, name='tf_example') #tf.saved_model.utils.build_tensor_info 的作用是构建一个TensorInfo proto
#输入参数是张量的名称,类型,大小,这里是string,想应该是名称吧,毕竟
#代码还没全部看完,先暂时这么猜测。输出是,基于提供参数的a tensor protocol
# buffer
classification_inputs = tf.saved_model.utils.build_tensor_info(
serialized_tf_example) #函数功能介绍同上,这里不同的是输入参数是prediction_classes,
#其定义,prediction_classes = table.lookup(tf.to_int64(indices)),是一个查找表
#为查找表构建a tensor protocol buffer
classification_outputs_classes = tf.saved_model.utils.build_tensor_info(
prediction_classes) #函数功能介绍同上,这里不同的是输入参数是values,
#其定义,values, indices = tf.nn.top_k(y, 10),是返回的预测值
#为预测值构建a tensor protocol buffer
classification_outputs_scores = tf.saved_model.utils.build_tensor_info(values) #然后,继续看,下面那么多行都是一个语句,一个个结构慢慢解析
#下面可以直观地看到有三个参数,分别是inputs ,ouputs和method_name
#inputs ,是一个字典,其key是tensorflow serving 固定定义的接口,
#为: tf.saved_model.signature_constants.CLASSIFY_INPUTS,value的话
#就是之前build的a tensor protocol buffer 之 classification_inputs
#同样的,output 和method_name 也是一个意思,好吧,这部分就
#了解完啦。
classification_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={
tf.saved_model.signature_constants.CLASSIFY_INPUTS:
classification_inputs
},
outputs={
tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
classification_outputs_classes,
tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
classification_outputs_scores
},
method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)) #这两句话都和上面一样,都是构建a tensor protocol buffer
tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y) 这个和上面很多行的classification_signature,一样的
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'images': tensor_info_x},
outputs={'scores': tensor_info_y},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) #这个不一样了,tf.group的官网解释挺简洁的
#Create an op that groups multiple operations.
#When this op finishes, all ops in input have finished. This op has no output.
#Returns:An Operation that executes all its inputs.
#我们看下另一个tf.tables_initializer():
#Returns:An Op that initializes all tables. Note that if there are not tables the returned Op is a NoOp
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op') #下面是重点啦,怎么看出来的?因为上面都是定义什么的,下面是最后的操作啦
#就一个函数:builder.add_meta_graph_and_variables,
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
'predict_images':
prediction_signature,
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
classification_signature,
},
legacy_init_op=legacy_init_op) builder.save()
print 'Done exporting!'
这里要从 tf.saved_model.builder.SavedModelBuilder 创建build开始,下面是看官网的,
可以直接参考:https://www.tensorflow.org/api_docs/python/tf/saved_model/builder/SavedModelBuilder

创建builder的是class SaveModelBuilder的功能是用来创建SaverModel

protocol buffer 并保存变量和资源,SaverModelBuilder类提供了创建

SaverModel protocol buffer 的函数方法

tensorflow serving 之minist_saved_model.py解读的更多相关文章

  1. tensorflow serving

    1.安装tensorflow serving 1.1确保当前环境已经安装并可运行tensorflow 从github上下载源码 git clone --recurse-submodules https ...

  2. tensorflow serving 中 No module named tensorflow_serving.apis,找不到predict_pb2问题

    最近在学习tensorflow serving,但是运行官网例子,不使用bazel时,发现运行mnist_client.py的时候出错, 在api文件中也没找到predict_pb2,因此,后面在网上 ...

  3. Tensorflow Serving 模型部署和服务

    http://blog.csdn.net/wangjian1204/article/details/68928656 本文转载自:https://zhuanlan.zhihu.com/p/233614 ...

  4. tensorflow serving 编写配置文件platform_config_file的方法

    1.安装grpc gRPC 的安装: $ pip install grpcio 安装 ProtoBuf 相关的 python 依赖库: $ pip install protobuf 安装 python ...

  5. Tensorflow Serving介绍及部署安装

    TensorFlow Serving 是一个用于机器学习模型 serving 的高性能开源库.它可以将训练好的机器学习模型部署到线上,使用 gRPC 作为接口接受外部调用.更加让人眼前一亮的是,它支持 ...

  6. 如何用 tensorflow serving 部署服务

    第一步,读一读这篇博客 https://www.jb51.net/article/138932.htm (浅谈Tensorflow模型的保存与恢复加载) 第二步: 参考博客: https://blog ...

  7. Tensorflow serving的编译

    Tensorflow serving提供了部署tensorflow生成的模型给线上服务的方法,包括模型的export,load等等. 安装参考这个 https://github.com/tensorf ...

  8. 谷歌发布 TensorFlow Serving

    TensorFlow服务是一个灵活的,高性能的机器学习模型的服务系统,专为生产环境而设计. TensorFlow服务可以轻松部署新的算法和实验,同时保持相同的服务器体系结构和API. TensorFl ...

  9. 学习笔记TF067:TensorFlow Serving、Flod、计算加速,机器学习评测体系,公开数据集

    TensorFlow Serving https://tensorflow.github.io/serving/ . 生产环境灵活.高性能机器学习模型服务系统.适合基于实际数据大规模运行,产生多个模型 ...

随机推荐

  1. pycharm中快捷键的使用

    转载自:https://blog.csdn.net/fighter_yy/article/details/40860949 Alt+Enter 自动添加包 shift+O 自动建议代码补全 Ctrl+ ...

  2. ERROR: duplicate key value violates unique constraint "xxx"

    在postgresql中,由于为表的主键建立了自增序列,且数据是从正式库拷贝到正式库的,所以报错如下: (主要原因:自增序列中的当前序列号小于真实数据中的最大主键值,因此在新增数据时,会报唯一值的错误 ...

  3. ajax,jsonp跨域访问数据

    访问高德aip天气接口 <!DOCTYPE html> <html> <head> <meta charset="utf-8"> & ...

  4. YARN label 特性 & 指定队列及label提交任务

    以下基于 hadoop版本 hadoop-2.8.4 给各个节点打标签 yarn rmadmin -addToClusterNodeLabels fastcpu,normal # 是否独占默认是tru ...

  5. hive设置参数的方法

    1.修改环境变量 ${HIVE_HOME}/conf/hive-site.xml 2.命令行参数 -e : 执行短命令 -f :  执行文件(适合脚本封装) -S : 安静模式,不显示MR的运行过程 ...

  6. spark sql 中的结构化数据

    1. 连接mysql 首先需要把mysql-connector-java-5.1.39.jar 拷贝到 spark 的jars目录里面: scala> import org.apache.spa ...

  7. Flex学习笔记-皮肤

    1文件结构 MXML应用程序 index.mxml 皮肤文件 components.button.skin.btnSkin1.mxml  皮肤文件的组件随便引用了spark.components.Bu ...

  8. python __class__属性

    >>> class a(object): pass >>> o=a() >>> dir(o) ['__class__', '__delattr__ ...

  9. NRF51822之使用外部32Mhz晶振

    硬件平台为微雪BLE400的(将原来的16mhz晶振改为32mhz.两个旁电容改为22pf) 以nRF51_SDK_10.0.0_dc26b5e\examples\ble_peripheral\ble ...

  10. eclipse中无法新建Android工程 出现问题:Plug-in org.eclipse.ajdt.ui was unable to load

    转自:http://www.bubuko.com/infodetail-757338.html eclipse中打开后新建Android项目区仍无法创建,出现下列提示对话框: Plug-in org. ...