基本思路:
1、对数据分块,使用多个worker分别处理一个数据块,每个worker暴露两个接口,分别是损失计算的接口loss和梯度计算的接口grad;
2、同时定义full_loss和full_grad接口对每个worker的loss和grad进行聚合;
3、使用bfgs算法进行参数优化,分别使用full_loss和full_grad作为bfgs的损失函数和梯度函数,即可进行网络参数优化;
注意:在此实现中,每个worker内部每次均计算一个数据块上的损失和梯度,而非一个batch

#0、导入依赖
import numpy as np
import os
import scipy.optimize import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data import ray
import ray.experimental.tf_utils #1、定义模型
class LinearModel(object):
def __init__(self, shape):
"""Creates a LinearModel object."""
x = tf.placeholder(tf.float32, [None, shape[0]])
w = tf.Variable(tf.zeros(shape))
b = tf.Variable(tf.zeros(shape[1]))
self.x = x
self.w = w
self.b = b
y = tf.nn.softmax(tf.matmul(x, w) + b)
y_ = tf.placeholder(tf.float32, [None, shape[1]])
self.y_ = y_
cross_entropy = tf.reduce_mean(
-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
self.cross_entropy = cross_entropy
self.cross_entropy_grads = tf.gradients(cross_entropy, [w, b])
self.sess = tf.Session() self.variables = ray.experimental.tf_utils.TensorFlowVariables(
cross_entropy, self.sess) def loss(self, xs, ys):
"""计算loss"""
return float(
self.sess.run(
self.cross_entropy, feed_dict={
self.x: xs,
self.y_: ys
})) def grad(self, xs, ys):
"""计算梯度"""
return self.sess.run(
self.cross_entropy_grads, feed_dict={
self.x: xs,
self.y_: ys
}) #2、定义远程worker,用于计算模型loss、grads
@ray.remote
class NetActor(object):
def __init__(self, xs, ys):
os.environ["CUDA_VISIBLE_DEVICES"] = ""
with tf.device("/cpu:0"):
self.net = LinearModel([784, 10])
self.xs = xs
self.ys = ys # 计算一个数据块的loss
def loss(self, theta):
net = self.net
net.variables.set_flat(theta)
return net.loss(self.xs, self.ys) # 计算一个数据块的梯度
def grad(self, theta):
net = self.net
net.variables.set_flat(theta)
gradients = net.grad(self.xs, self.ys)
return np.concatenate([g.flatten() for g in gradients]) def get_flat_size(self):
return self.net.variables.get_flat_size() #3、获取远程worker损失的函数
def full_loss(theta):
theta_id = ray.put(theta)
loss_ids = [actor.loss.remote(theta_id) for actor in actors]
return sum(ray.get(loss_ids)) #4、获取远程worker梯度的函数
def full_grad(theta):
theta_id = ray.put(theta)
grad_ids = [actor.grad.remote(theta_id) for actor in actors]
# 使用fmin_l_bfgs_b须转换为float64数据类型
return sum(ray.get(grad_ids)).astype("float64") #5、使用lbfgs进行训练
if __name__ == "__main__":
ray.init() mnist = input_data.read_data_sets("MNIST_data", one_hot=True)   # 数据分块,每个worker跑一个数据块
num_batches = 10
batch_size = mnist.train.num_examples // num_batches
batches = [mnist.train.next_batch(batch_size) for _ in range(num_batches)] actors = [NetActor.remote(xs, ys) for (xs, ys) in batches]   # 参数初始化
dim = ray.get(actors[0].get_flat_size.remote())
theta_init = 1e-2 * np.random.normal(size=dim)   # 优化
result = scipy.optimize.fmin_l_bfgs_b(
full_loss, theta_init, maxiter=10, fprime=full_grad, disp=True)   

基于ray的分布式机器学习(一)的更多相关文章

  1. 基于ray的分布式机器学习(二)

    基本思路:基于parameter server + multiple workers模式.同步方式:parameter server负责网络参数的统一管理,每次迭代均将参数发送给每一个worker,多 ...

  2. Angel 实现FFM 一、对于Angel 和分布式机器学习的简单了解

    Angel是腾讯开源的一个分布式机器学习框架.是一个PS模式的分布式机器学习框架. https://github.com/Angel-ML/angel   这是github地址. 我了解的分布式机器学 ...

  3. 分布式机器学习系统笔记(一)——模型并行,数据并行,参数平均,ASGD

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入. 文章索引::"机器学 ...

  4. Adam:大规模分布式机器学习框架

    引子 转载请注明:http://blog.csdn.net/stdcoutzyx/article/details/46676515 又是好久没写博客,记得有一次看Ng大神的訪谈录,假设每周读三篇论文, ...

  5. 分布式机器学习框架:MxNet 前言

           原文连接:MxNet和Caffe之间有什么优缺点一.前言: Minerva: 高效灵活的并行深度学习引擎 不同于cxxnet追求极致速度和易用性,Minerva则提供了一个高效灵活的平台 ...

  6. [转帖]Greenplum :基于 PostgreSQL 的分布式数据库内核揭秘 (上篇)

    Greenplum :基于 PostgreSQL 的分布式数据库内核揭秘 (上篇) https://www.infoq.cn/article/3IJ7L8HVR2MXhqaqI2RA 学长的文章.. ...

  7. 分布式机器学习:逻辑回归的并行化实现(PySpark)

    1. 梯度计算式导出 我们在博客<统计学习:逻辑回归与交叉熵损失(Pytorch实现)>中提到,设\(w\)为权值(最后一维为偏置),样本总数为\(N\),\(\{(x_i, y_i)\} ...

  8. 分布式机器学习:同步并行SGD算法的实现与复杂度分析(PySpark)

    1 分布式机器学习概述 大规模机器学习训练常面临计算量大.训练数据大(单机存不下).模型规模大的问题,对此分布式机器学习是一个很好的解决方案. 1)对于计算量大的问题,分布式多机并行运算可以基本解决. ...

  9. 分布式机器学习:模型平均MA与弹性平均EASGD(PySpark)

    计算机科学一大定律:许多看似过时的东西可能过一段时间又会以新的形式再次回归. 1 模型平均方法(MA) 1.1 算法描述与实现 我们在博客<分布式机器学习:同步并行SGD算法的实现与复杂度分析( ...

随机推荐

  1. 15款NOSQL数据库

    1.MongoDB 介绍 MongoDB是一个基于分布式文件存储的数据库.由C++语言编写.主要解决的是海量数据的访问效率问题,为WEB应用提供可扩展的高性能数据存储解决方案.当数据量达到50GB以上 ...

  2. 前端学习 node 快速入门 系列 —— 服务端渲染

    其他章节请看: 前端学习 node 快速入门 系列 服务端渲染 在简易版 Apache一文中,我们用 node 做了一个简单的服务器,能提供静态资源访问的能力. 对于真正的网站,页面中的数据应该来自服 ...

  3. 数字转人民币读法-python3

    """ 2 把一个浮点数分解成证书备份和小数部分 3 """ 4 def divide(num): 5 intnum = int(num) ...

  4. java中ReentrantLock核心源码详解

    ReentrantLock简介 ReentrantLock是一个可重入且独占式的锁,它具有与使用synchronized监视器锁相同的基本行为和语义,但与synchronized关键字相比,它更灵活. ...

  5. 【Linux学习笔记0】-虚拟机运行CentOS(VMware12+CentOS)

    目录 一,资源 二,VMware12安装 记录自己学习linux的过程.这将会是一个系列,本文是该系列的第一部分,主要记录虚拟机(VMware12)及对应操作系统(CentOS)的安装过程. 虚拟机( ...

  6. ubuntu修改默认启动内核

    一.序言 新换的笔记本由于太新的主板芯片,驱动还没有完善.每次升级系统内核都要小心谨慎.经常发生部分硬件驱动失败的事情.系统Ubuntu 20.04.2 LTS x86_64 ,我现在使用的两个版本的 ...

  7. vue 快速入门 系列 —— 虚拟 DOM

    其他章节请看: vue 快速入门 系列 虚拟 DOM 什么是虚拟 dom dom 是文档对象模型,以节点树的形式来表现文档. 虚拟 dom 不是真正意义上的 dom.而是一个 javascript 对 ...

  8. jd的艺术

    我看最近的狗东的ldz很火哈.所以我也来凑个热闹发个教程. 准备工作 1.一台openwrt系统设备 2.一个脑子 3.一双手 话不多说,开始吧! 步骤 一.链接N1(你的设备) 这里需要一款ssh工 ...

  9. 201871030140-朱婷婷 实验三 结对项目—《D{0-1}KP 实例数据集算法实验平台》项目报告

    项目 内容 课程班级博客链接 2018级卓越班 这个作业要求链接 实验三 结对项目 我的课程学习目标 1.体验软件项目开发中的两人合作,练习结对编程:2.掌握GitHub协作开发程序的操作方法. 这个 ...

  10. (十三)VMware Harbor 身份验证模式

    VMware Harbor 修改Harbor仓库admin用户 参考:https://blog.csdn.net/qq_40460909 https://blog.csdn.net/qq_404609 ...