本文目的

在介绍estimator分布式的时候,官方文档由于版本更新导致与接口不一致。具体是:在estimator分布式当中,使用dataset作为数据输入,在1.12版本中,数据训练只是dataset的数据,就是所有设备加起来,跑一遍数据。

而在2.0版本中,训练数据是dataset的数据乘以分

布式的设备数。也就是说,在每个设备当中都会完整地跑一遍dataset的所有数据。

1.12版本读取

1. 在主线程当中创建图

下面这段代码中,在client中调用了input function,得到迭代器。这是属于estimator distribute train调用的代码

with ops.Graph().as_default() as g:
# We want to create the iterations variable outside the distribution scope
# as that is just stored on the host and mainly used to drive the loop
# and doesn't need to be a Mirrored/Device variable.
if is_tpu_strategy:
steps_per_run_variable = training.get_or_create_steps_per_run_variable()
with self._train_distribution.scope():
random_seed.set_random_seed(self._config.tf_random_seed)
iterator, input_hooks = self._get_iterator_from_input_fn(
input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)
  • _get_iterator_from_input_fn * 这个函数会生成迭代器供后续训练读取数据。
  def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
if distribution is not None:
result = distribution.distribute_dataset(
lambda: self._call_input_fn(input_fn, mode))
else:
result = self._call_input_fn(input_fn, mode) iterator = result.make_initializable_iterator()
input_hooks = [estimator_util._DatasetInitializerHook(iterator)] # pylint: disable=protected-access
return iterator, input_hooks

这里会调用distribute_dataset生成dataset。

再点进去看以后可看到会创建这样一个PerDeviceDataset

class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" def __init__(self, dataset, devices, prefetch_on_device=None):
self._devices = devices # Default to using prefetching in graph mode, unless specified.
# TODO(priyag): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently") if self._prefetch_on_device:
self._dataset = dataset.apply(
prefetching_ops_v2.prefetch_to_devices(self._devices))
else:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
self._dataset = dataset.batch(len(devices), drop_remainder=True)

最后一行代码可以看到,在原dataset上又封装了一层batch。将数据根据设备数切分。

后面创建迭代器也是封装为PerDeviceDataIterator,形成一个字典映射,不同设备不同数据,根据batch 的index切分。

分布式训练

在1.12版本中的训练比较简单。对于MirroredStrategy来说,会给每个一个device创建一个线程,

有一个缺点就是,每一次run都会创建线程,在todo里看到,后续会优化掉应该。

下面是在client中从迭代器获取数据,传递给每个device去运算的代码,

self._train_distribution.call_for_each_tower

features, labels = estimator_util.parse_iterator_result(
iterator.get_next())
grouped_estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
features,
labels, # although this will be None it seems
model_fn_lib.ModeKeys.TRAIN,
self.config)
loss = self._train_distribution.unwrap(
self._train_distribution.reduce(
distribute_lib.get_loss_reduction(),
grouped_estimator_spec.loss,
destinations='/device:CPU:0'))[0]
distributed_train_op = grouped_estimator_spec.train_op

call_for_each_tower是每个设备训练的接口

def _call_for_each_tower(distribution, fn, *args, **kwargs):
"""Run `fn` in separate threads, once per tower/worker device.
run_concurrently = kwargs.pop("run_concurrently", True)
if not context.executing_eagerly():
# Lots of TF library code isn't thread-safe in graph mode, and
# there is little to be gained by turning on multithreading when
# constructing a graph.
run_concurrently = False
# Needed for per-thread device, etc. contexts in graph mode.
ops.get_default_graph().switch_to_thread_local()
elif run_concurrently is None:
run_concurrently = True coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) shared_variable_store = {} # TODO(isaprykin): Create these threads once instead of during every run()
# call.
threads = []
for index, d in enumerate(distribution.worker_devices):
variable_creator_fn = shared_variable_creator.make_fn(
shared_variable_store, index)
t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access
distribution, coord, d, variable_creator_fn, fn,
*values.select_device(d, args), **values.select_device(d, kwargs))
threads.append(t) for t in threads:
t.start()

其中,select_device就是取对应设备key对应的值。完成整个分布式训练。

TensorFlow Distribution(分布式中的数据读取和训练)的更多相关文章

  1. DataTable to Excel(使用NPOI、EPPlus将数据表中的数据读取到excel格式内存中)

    /// <summary> /// DataTable to Excel(将数据表中的数据读取到excel格式内存中) /// </summary> /// <param ...

  2. TensorFlow走过的坑之---数据读取和tf中batch的使用方法

    首先介绍数据读取问题,现在TensorFlow官方推荐的数据读取方法是使用tf.data.Dataset,具体的细节不在这里赘述,看官方文档更清楚,这里主要记录一下官方文档没有提到的坑,以示" ...

  3. oracle中的数据读取与查找

    数据读取 首先数据块读入到Buffer Cache中,并将其放在LRU(Last Recently Used)链表的MRU(Most Recently Used)端,当需要再次访问该块时可以直接从bu ...

  4. c#中使用数据读取器读取查询结果

    今天有时间了. 在看<c#数据库入门经典> ,总结数据读取器查询结果. 针对单个结果集使用读取器,有3中方法: String connString =..; String sql =@&q ...

  5. 如何在ADO中使用数据读取器(DataReader)读取数据

    DbDataReader类型(实现IDataReader接口)是从数据源获取信息最简单也最快速的方法. 数据读取器是只读向前的效据流.井且一次返回一条记录.因此.只有当你向数据源提交 Select 查 ...

  6. 《TensorFlow实战》中AlexNet卷积神经网络的训练中

    TensorFlow实战中AlexNet卷积神经网络的训练 01 出错 TypeError: as_default() missing 1 required positional argument: ...

  7. Android中Json数据读取与创建

    一:  Json的特性和在数据交互中的地位就不用说了,直接看案例. 首先在android studio中创建assets文件目录,用于存放Json数据文件,android studio 1.3 默认项 ...

  8. Android中Json数据读取与创建的方法

    转自:http://www.jb51.net/article/70875.htm 首先介绍下JSON的定义,JSON是JavaScript Object Notation的缩写. 一种轻量级的数据交换 ...

  9. TensorFlow实践笔记(一):数据读取

    本文整理了TensorFlow中的数据读取方法,在TensorFlow中主要有三种方法读取数据: Feeding:由Python提供数据. Preloaded data:预加载数据. Reading ...

随机推荐

  1. 【iOS】file not found: .../Build/Products/Debug-iphonesimulator file not found

    今天又遇到了这个问题: ld: file not found: /Users/***/Library/Developer/Xcode/DerivedData/***-dfscappaygvbougtb ...

  2. 调用ffmpeg视频压缩工具类

    package com.example.demo; import com.alibaba.fastjson.JSONObject;import com.aliyun.oss.ClientExcepti ...

  3. kubernetes对接第三方认证

    kubernetes对接第三方认证 kubernetes离线安装包地址 概述 本文介绍如何使用github账户去关联自己kubernetes账户.达到如下效果: 使用github用户email作为ku ...

  4. python 处理json数据

    python 处理 json数据 以下是登录账号后获取的json数据,headers中注意加入cookie值 需要处理的数据如下: 全部代码如下 #!/usr/bin/env python # -*- ...

  5. PID算法 旋转倒立摆与平衡车的区别。此贴后边会更新。

    我做PID算法的背景和经历:本人之前电子信息科学与技术专业,对控制方向颇感兴趣,刚上大学时听到实验室老师说PID算法,那年在暑假集训准备全国电子设计竞赛,我正在练习做一个以前专科的题目,帆板角度控制系 ...

  6. 一文了解:Redis的RDB持久化

    一文了解:Redis的RDB持久化 Redis是内存数据库,为了保证数据不在故障后丢失,Redis需要将数据持久化到硬盘上. Redis持久化有两种方式:一种是快照,全量备份.一种是AOF方式,连续增 ...

  7. Of efficiency and methodology

    There are only too many articles and books which pertains to the discussion of efficiency and method ...

  8. 0x15 字符串

    KMP算法 next数组的求法 void calc_next() { next[]=; , j=; i<=n; ++i) { &&a[i]!=a[j+]) j=next[j]; ...

  9. linux学习总结--linux100day(day2)

    Linux中的哲学--一切皆文件 为了便于操作,我们可以使用secureCRT或Xshell连接到我们的虚拟机. 要用远程工具连接到虚拟机上,我们只需要打开虚拟机上的ssh服务,在xshell中填写主 ...

  10. 从输入URL到浏览器显示页面发生了哪些事情---个人理解

    经典面试题:从输入URL到页面显示发生了哪些事情 以前一直都记不住,这次自己理解了一下 用自己的话总结了一次,不对的地方希望大佬给我指出来 1.主机通过DHCP协议获取客户端的IP地址.子网掩码和DN ...