一、单机编程框架

单机程序是指启动和运行都在一台机器的一个进程中完成,因为没有网络开销,非常适合参数不多、计算量小的模型。

步骤,创建单机数据流图,创建并运行单机会话。

  1. saver = tf.train.Saver()
  2. sess = tf.InteractiveSession()
  3. tf.global_variables_initializer().run()
  4.  
  5. for i in range(1000):
  6. batch_xs,batch_ys = mnist.train.next_batch(100)
  7. sess.run(train_step,feed_dict={x:batch_xs,y_=batch_ys})
  8. if i%100 = 0:
  9. saver.save(sess,'mnist.ckpt')
  10. correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
  11. accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  12. sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})

如果想指定机器上的设备如cpu,gpu

可以使用

with tf.device('/cpu:0'):

  ……

二、分布式程序编程框架

PS-worker是一种经典的分布式架构,它在大规模分布式机器学习和深度学习中有广泛的应用,tensorFlow提供了对PS-worker的支持。

步骤

(1).pull,各worker根据数据流图的拓扑结构,从PS拉取最新的模型参数

(2).feed,各worker按照一定的规则填充不同批次的批数据

(3).compute,各worker使用相同的模型参数和不同的批数据计算梯度,得出不同的梯度值

(4).push,各worker将上一步计算得到的梯度值推送到PS

(5).update,PS汇总数据,求出梯度平均值后更新模型参数

分布式程序运行步骤 创建集群,创建分布式数据流图,创建分布式会话

集群创建, tf.train.Server(host,job_name,task_index)

将操作放置在目标设备上

  1. with tf.device('/job:PS/task:0'):
  2. weights_1 = tf.Variable()
  3. with tf.device('/job:PS/task:1'):
  4. weights_2 = tf.Variable()
  5. with tf.device('/job:worker/task:1'):
  6. tf.nn.relu()

3.训练机制

同步训练机制

每个worker独立训练,直到所有worker计算出梯度值后进行模型参数的汇总计算,并更新当前训练步的模型参数,计算较快的worker需要阻塞等待计算较慢的worker

  1. y = tf.nn.softmax(tf.nn.xw_plus_b(hid,sm_w,sm_b))
  2. cross_entropy = -tf.reduce_sum(FLAGS.learning_rate)
  3. if FLAGS.sync_replicas:
  4. opt = tf.train.SyncReplicasOptimizer(opt,replicas_to_aggregate=10,total_num_replicas=100,name='mnist_sync')
  5. opt.minimize(cross_entropy,global_step=1)

异步训练机制

每个worker独立训练,计算出梯度值后立即进行模型参数计算,每个worker无阻塞等待其他所有worker的梯度计算完成。

tensorflow(五)的更多相关文章

  1. TensorFlow(五):手写数字识别加强版

    # 该版本的最终识别准确率达到98%以上 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_d ...

  2. Anaconda安装和卸载+虚拟环境Tensorflow安装以及末尾问题大全(附Anaconda安装包),这一篇就够了!!!

    前言 实话说,在自己亲手捣鼓了一下午加一晚上后,本人深深地感受到了对于"Anaconda安装+虚拟环境Tensorflow安装"里面的坑点之多,再加上目前一些博主的资料有点久远,尤 ...

  3. centos7 手把手从零搭建深度学习环境 (以TensorFlow2.0为例)

    目录 一. 搭建一套自己的深度学习平台 二. 安装系统 三. 安装NVIDA组件 四. 安装深度学习框架 TensorFlow 五. 配置远程访问 六. 验收 七. 福利(救命稻草

  4. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

  5. tensorflow入门笔记(五) name_scope和variable_scope

    一.上下文管理器(context manager) 上下文管理器是实现了上下文协议的对象,主要用于资源的获取与释放.上下文协议包括__enter__.__exit__,简单说就是,具备__enter_ ...

  6. 深度学习(五)基于tensorflow实现简单卷积神经网络Lenet5

    原文作者:aircraft 原文地址:https://www.cnblogs.com/DOMLX/p/8954892.html 参考博客:https://blog.csdn.net/u01287127 ...

  7. tensorflow学习笔记五----------逻辑回归

    在逻辑回归中使用mnist数据集.导入相应的包以及数据集. import numpy as np import tensorflow as tf import matplotlib.pyplot as ...

  8. tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)

    mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...

  9. TF Boys (TensorFlow Boys ) 养成记(五)

    有了数据,有了网络结构,下面我们就来写 cifar10 的代码. 首先处理输入,在 /home/your_name/TensorFlow/cifar10/ 下建立 cifar10_input.py,输 ...

随机推荐

  1. 下载jQuery

    下载jQuery :https://jquery.com/download/ . 将下载好的文件放到项目中 引入到代码中 <script type="text/javascript&q ...

  2. Django中使用ORM

    一.ORM概念 对象关系映射(Object Relational Mapping,简称ORM)模式是一种为了解决面向对象与关系数据库存在的互不匹配的现象的技术. 简单的说,ORM是通过使用描述对象和数 ...

  3. Assignment写作需要掌握的两种表达方式

    在正式开始写Assignment之前都会进行文献检索和整理,选择适合Assignment选题的文献资料进行阅读和引用.对于文献中与自己的观点高度相关的参考资料要如何具体引用,而不造成抄袭或者增加文章的 ...

  4. Spring Boot without the web server

    https://stackoverflow.com/questions/26105061/spring-boot-without-the-web-server/28565277 1. spring.m ...

  5. 关于Java中内省的总结

    内省基于JavaBean规范对反射进行了封装,提供了更加便捷的通过getter/setter方法来访问字段的方式 Java内省的知识结构图 JavaBean的规范 JavaBean在现在可以认为就是普 ...

  6. Linux简介和环境的搭建

    Linux的学习方向 网络服务器 嵌入式程序开发 Linux的设计哲学:一切皆文件 常用命令:cd 切换目录sudo shutdown -h now 关机命令sudo reboot 重启sudo ro ...

  7. POJ 1745:Divisibility 枚举某一状态的DP

    Divisibility Time Limit: 1000MS   Memory Limit: 10000K Total Submissions: 11001   Accepted: 3933 Des ...

  8. C++ 一个exe的两个运行实例之间共享数据

    #pragma data_seg("Shared") volatile int iNum = 0; #pragma data_seg() #pragma comment(linke ...

  9. 遍历数组提取List[Int]

    def toFlatMap(input:List[Any],result:List[Int]):List[Int]=input match{ case h::t=>h match {case e ...

  10. 18 12 26 css 学习 选择器

    1.标签选择器 标签选择器,此种选择器影响范围大,建议尽量应用在层级选择器中.举例: *{margin:0;padding:0} div{color:red} <div>....</ ...