我们在《Python中的随机采样和概率分布(二)》介绍了如何用Python现有的库对一个概率分布进行采样,其中的dirichlet分布大家一定不会感到陌生。该分布的概率密度函数为

\[P(\bm{x}; \bm{\alpha}) \propto \prod_{i=1}^{k} x_{i}^{\alpha_{i}-1} \\
\bm{x}=(x_1,x_2,...,x_k),\quad x_i > 0 , \quad \sum_{i=1}^k x_i = 1\\
\bm{\alpha} = (\alpha_1,\alpha_2,..., \alpha_k). \quad \alpha_i > 0
\]

其中\(\bm{\alpha}\)为参数。

我们在联邦学习中,经常会假设不同client间的数据集不满足独立同分布(non-iid)。那么我们如何将一个现有的数据集按照non-iid划分呢?我们知道带标签样本的生成分布看可以表示为\(p(\bm{x}, y)\),我们进一步将其写作\(p(\bm{x}, y)=p(\bm{x}|y)p(y)\)。其中如果要估计\(p(\bm{x}|y)\)的计算开销非常大,但估计\(p(y)\)的计算开销就很小。所有我们按照样本的标签分布来对样本进行non-iid划分是一个非常高效、简便的做法。

总而言之,我们采取的算法思路是尽量让每个client上的样本标签分布不同。我们设有\(K\)个类别标签,\(N\)个client,每个类别标签的样本需要按照不同的比例划分在不同的client上。我们设矩阵\(\bm{X}\in \mathbb{R}^{K*N}\)为类别标签分布矩阵,其行向量\(\bm{x}_k\in \mathbb{R}^N\)表示类别\(k\)在不同client上的概率分布向量(每一维表示\(k\)类别的样本划分到不同client上的比例),该随机向量就采样自dirichlet分布。

据此,我们可以写出以下的划分算法:

  1. import numpy as np
  2. np.random.seed(42)
  3. def split_noniid(train_labels, alpha, n_clients):
  4. '''
  5. 参数为alpha的dirichlet分布将数据索引划分为n_clients个子集
  6. '''
  7. n_classes = train_labels.max()+1
  8. label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
  9. # (K, N)的类别标签分布矩阵X,记录每个client占有每个类别的多少
  10. class_idcs = [np.argwhere(train_labels==y).flatten()
  11. for y in range(n_classes)]
  12. # 记录每个K个类别对应的样本下标
  13. client_idcs = [[] for _ in range(n_clients)]
  14. # 记录N个client分别对应样本集合的索引
  15. for c, fracs in zip(class_idcs, label_distribution):
  16. # np.split按照比例将类别为k的样本划分为了N个子集
  17. # for i, idcs 为遍历第i个client对应样本集合的索引
  18. for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1]*len(c)).astype(int))):
  19. client_idcs[i] += [idcs]
  20. client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
  21. return client_idcs

加下来我们在EMNIST数据集上调用该函数进行测试,并进行可视化呈现。我们设client数量\(N=10\),dirichlet概率分布的参数向量\(\bm{\alpha}\)满足\(\alpha_i=1.0,\space i=1,2,...N\):

  1. import torch
  2. from torchvision import datasets
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. torch.manual_seed(42)
  6. if __name__ == "__main__":
  7. N_CLIENTS = 10
  8. DIRICHLET_ALPHA = 1.0
  9. train_data = datasets.EMNIST(root=".", split="byclass", download=True, train=True)
  10. test_data = datasets.EMNIST(root=".", split="byclass", download=True, train=False)
  11. n_channels = 1
  12. input_sz, num_cls = train_data.data[0].shape[0], len(train_data.classes)
  13. train_labels = np.array(train_data.targets)
  14. # 我们让每个client不同label的样本数量不同,以此做到non-iid划分
  15. client_idcs = split_noniid(train_labels, alpha=DIRICHLET_ALPHA, n_clients=N_CLIENTS)
  16. # 展示不同client的不同label的数据分布
  17. plt.figure(figsize=(20,3))
  18. plt.hist([train_labels[idc]for idc in client_idcs], stacked=True,
  19. bins=np.arange(min(train_labels)-0.5, max(train_labels) + 1.5, 1),
  20. label=["Client {}".format(i) for i in range(N_CLIENTS)], rwidth=0.5)
  21. plt.xticks(np.arange(num_cls), train_data.classes)
  22. plt.legend()
  23. plt.show()

最终的可视化结果如下:



可以看到,62个类别标签在不同client上的分布确实不同,证明我们的样本划分算法是有效的。

联邦学习:按Dirichlet分布划分Non-IID样本的更多相关文章

  1. 联邦学习:按混合分布划分Non-IID样本

    我们在博文<联邦学习:按病态独立同分布划分Non-IID样本>中学习了联邦学习开山论文[1]中按照病态独立同分布(Pathological Non-IID)划分样本. 在上一篇博文< ...

  2. LDA学习之beta分布和Dirichlet分布

    ---恢复内容开始--- 今天学习LDA主题模型,看到Beta分布和Dirichlet分布一脸的茫然,这俩玩意怎么来的,再网上查阅了很多资料,当做读书笔记记下来: 先来几个名词: 共轭先验: 在贝叶斯 ...

  3. Apache Pulsar 在腾讯 Angel PowerFL 联邦学习平台上的实践

    腾讯 Angel PowerFL 联邦学习平台 联邦学习作为新一代人工智能基础技术,通过解决数据隐私与数据孤岛问题,重塑金融.医疗.城市安防等领域. 腾讯 Angel PowerFL 联邦学习平台构建 ...

  4. 【一周聚焦】 联邦学习 arxiv 2.16-3.10

    这是一个新开的每周六定期更新栏目,将本周arxiv上新出的联邦学习等感兴趣方向的文章进行总结.与之前精读文章不同,本栏目只会简要总结其研究内容.解决方法与效果.这篇作为栏目首发,可能不止本周内容(毕竟 ...

  5. 关于Beta分布、二项分布与Dirichlet分布、多项分布的关系

    在机器学习领域中,概率模型是一个常用的利器.用它来对问题进行建模,有几点好处:1)当给定参数分布的假设空间后,可以通过很严格的数学推导,得到模型的似然分布,这样模型可以有很好的概率解释:2)可以利用现 ...

  6. 【论文考古】联邦学习开山之作 Communication-Efficient Learning of Deep Networks from Decentralized Data

    B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, "Communication-Efficient Learni ...

  7. Beta分布和Dirichlet分布

    在<Gamma函数是如何被发现的?>里证明了\begin{align*} B(m, n) = \int_0^1 x^{m-1} (1-x)^{n-1} \text{d} x = \frac ...

  8. LDA-math-认识Beta/Dirichlet分布

    http://cos.name/2013/01/lda-math-beta-dirichlet/#more-6953 2. 认识Beta/Dirichlet分布2.1 魔鬼的游戏—认识Beta 分布 ...

  9. 机器学习的数学基础(1)--Dirichlet分布

    机器学习的数学基础(1)--Dirichlet分布 这一系列(机器学习的数学基础)主要包括目前学习过程中回过头复习的基础数学知识的总结. 基础知识:conjugate priors共轭先验 共轭先验是 ...

随机推荐

  1. [flask] jinja自定义filter来过滤html标签

    问题描述 数据库存储了html格式的博客文章,在主页(index)显示的时候带有html标签,如何过滤掉呢? 解决方案 用jinja自定义filter过滤掉html标签 我是用的工厂函数,因此在工厂函 ...

  2. 51 Nod 1083 矩阵取数问题(动态规划)

    原题链接:https://www.51nod.com/onlineJudge/questionCode.html#!problemId=1083 题目分析:通过读题发现我们只能往右边或者下边走,意味着 ...

  3. STM32 EXTI(外部中断)

    一.EXTI 简介 EXTI(External interrupt/event controller)-外部中断/事件控制器,管理了控制器的 20个中断/事件线.每个中断/事件线都对应有一个边沿检测器 ...

  4. C 库函数 - pow()

    1.C 标准库 - <math.h> 2.C 库函数 double pow(double x, double y) 返回 x 的 y 次幂,即 xy. 3.pow() 函数的声明. dou ...

  5. 【视频解码性能对比】opencv + cuvid + gpu vs. ffmpeg + cpu

    视频大小:1168856 字节画面尺寸:480*848帧数:275opencv + cuvid + tesla P4, 解码性能:1426.84 fps ffmpeg 4.0 API + [Intel ...

  6. 事务与一致性:刚性or柔性

    转发自 https://cloud.tencent.com/developer/article/1038871 在高并发场景下,分布式储存和处理已经是常用手段.但分布式的结构势必会带来"不一 ...

  7. vuecli学习01 - 环境搭建

    到这个链接下载nvm的安装包:https://github.com/coreybutler/nvm-windows/releases. 然后点击一顿下一步,安装即可! 安装完成后,还需要配置环境变量. ...

  8. WEB前端基础之SCC(字体颜色背景-盒子模型)

    目录 一:伪元素选择器 1.首字调整>>>:也是一种文档布局的方式 2.在文本的前面通过css动态渲染文本>>>:特殊文本无法选中 3.在文本的后面通过css动态渲 ...

  9. 使用Hot Chocolate和.NET 6构建GraphQL应用(3) —— 实现Query基础功能

    系列导航 使用Hot Chocolate和.NET 6构建GraphQL应用文章索引 需求 在本文中,我们通过一个简单的例子来看一下如何实现一个最简单的GraphQL的接口. 实现 引入Hot Cho ...

  10. Android开发-页面绘制

    今天主要绘制了记账页面 记账页面用到的布局是TableLayout加Viewpager联动的方式,通过设置一个标题头可以实现页面的左右滑动,viewpager中添加两个fragment. 需要制作两个 ...