神经网络高维互信息计算Python实现(MINE)
论文
Belghazi, Mohamed Ishmael, et al. “ Mutual information neural estimation .” International Conference on Machine Learning . 2018.
利用神经网络的梯度下降法可以实现快速高维连续随机变量之间互信息的估计,上述论文提出了Mutual Information Neural Estimator (MINE)。NN在维度和样本量上都是线性可伸缩的,MI的计算可以通过反向传播进行训练。
核心

Python实现
现有github上的代码无法计算和估计高维随机变量,只能计算一维随机变量,下面的代码给出的修改方案能够计算真实和估计高维随机变量的真实互信息。
其中,为了计算理论的真实互信息,我们不直接暴力求解矩阵(耗时,这也是为什么要有MINE的原因),我们采用给定生成随机变量的参数计算理论互信息。
SIGNAL_NOISE = 0.2
SIGNAL_POWER = 3
完整代码基于pytorch
# Name: MINE_simple
# Author: Reacubeth
# Time: 2020/12/15 18:49
# Mail: noverfitting@gmail.com
# Site: www.omegaxyz.com
# *_*coding:utf-8 *_*
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
SIGNAL_NOISE = 0.2
SIGNAL_POWER = 3
data_dim = 3
num_instances = 20000
def gen_x(num, dim):
return np.random.normal(0., np.sqrt(SIGNAL_POWER), [num, dim])
def gen_y(x, num, dim):
return x + np.random.normal(0., np.sqrt(SIGNAL_NOISE), [num, dim])
def true_mi(power, noise, dim):
return dim * 0.5 * np.log2(1 + power/noise)
mi = true_mi(SIGNAL_POWER, SIGNAL_NOISE, data_dim)
print('True MI:', mi)
hidden_size = 10
n_epoch = 500
class MINE(nn.Module):
def __init__(self, hidden_size=10):
super(MINE, self).__init__()
self.layers = nn.Sequential(nn.Linear(2 * data_dim, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 1))
def forward(self, x, y):
batch_size = x.size(0)
tiled_x = torch.cat([x, x, ], dim=0)
idx = torch.randperm(batch_size)
shuffled_y = y[idx]
concat_y = torch.cat([y, shuffled_y], dim=0)
inputs = torch.cat([tiled_x, concat_y], dim=1)
logits = self.layers(inputs)
pred_xy = logits[:batch_size]
pred_x_y = logits[batch_size:]
loss = - np.log2(np.exp(1)) * (torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y))))
# compute loss, you'd better scale exp to bit
return loss
model = MINE(hidden_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
plot_loss = []
all_mi = []
for epoch in tqdm(range(n_epoch)):
x_sample = gen_x(num_instances, data_dim)
y_sample = gen_y(x_sample, num_instances, data_dim)
x_sample = torch.from_numpy(x_sample).float()
y_sample = torch.from_numpy(y_sample).float()
loss = model(x_sample, y_sample)
model.zero_grad()
loss.backward()
optimizer.step()
all_mi.append(-loss.item())
fig, ax = plt.subplots()
ax.plot(range(len(all_mi)), all_mi, label='MINE Estimate')
ax.plot([0, len(all_mi)], [mi, mi], label='True Mutual Information')
ax.set_xlabel('training steps')
ax.legend(loc='best')
plt.show()
结果
变量维度为1

变量维度为3

需要指出的是在计算最终的互信息时需要将基数e转为基数2。如果只是求得一个比较值,在真实使用的过程中可以省略。
本文的文字及图片来源于网络,仅供学习、交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理
想要获取更多Python学习资料可以加
QQ:2955637827私聊
或加Q群630390733
大家一起来学习讨论吧!
神经网络高维互信息计算Python实现(MINE)的更多相关文章
- 基于神经网络的混合计算(DNC)-Hybrid computing using a NN with dynamic external memory
前言: DNC可以称为NTM的进一步发展,希望先看看这篇译文,关于NTM的译文:人工机器-NTM-Neutral Turing Machine 基于神经网络的混合计算 Hybrid computing ...
- 北京地铁月度消费总金额计算(Python版)
最近业余时间在学习Python,这是那天坐地铁时突发奇想,想看看我这一个月的地铁费共多少钱,所以简单的构思了下思路,就直接开写了,没想到用Python来实现还挺简单的. 设计思路: 每次乘车正常消费7 ...
- 函数计算 Python 连接 SQL Server 小结
python 连接数据库通常要安装第三方模块,连接 MS SQL Server 需要安装 pymssql .由于 pymsql 依赖于 FreeTDS,对于先于 2.1.3 版本的 pymssql,需 ...
- GIL计算python 2 和 python 3 计算密集型
首先我画了一张图来表示GIL运行的方式: Python 3执行如下计算代码:#-*-conding:utf-8-*-import threading import timedef add(): n = ...
- 计算Python运行时间
可以调用datetime 或者 time库实现得到Python运行时间 方法1 import datetime start_t = datetime.datetime.now() #运行大型代码 e ...
- 机器学习作业(四)神经网络参数的拟合——Python(numpy)实现
题目下载[传送门] 题目简述:识别图片中的数字,训练该模型,求参数θ. 出现了一个问题:虽然训练的模型能够有很好的预测准确率,但是使用minimize函数时候始终无法成功,无论设计的迭代次数有多大,如 ...
- 相似度与距离计算python代码实现
#定义几种距离计算函数 #更高效的方式为把得分向量化之后使用scipy中定义的distance方法 from math import sqrt def euclidean_dis(rating1, r ...
- 计算Python代码运行时间长度方法
在代码中有时要计算某部分代码运行时间,便于分析. import time start = time.clock() run_function() end = time.clock() print st ...
- 菜鸟之路——机器学习之BP神经网络个人理解及Python实现
关键词: 输入层(Input layer).隐藏层(Hidden layer).输出层(Output layer) 理论上如果有足够多的隐藏层和足够大的训练集,神经网络可以模拟出任何方程.隐藏层多的时 ...
随机推荐
- 如何使用系统清理缓存软件优化MacBook
在我们使用我们的Mac一定的时间后,总是不可避免的出现Mac内存不足的情况,所以清理垃圾软件也就成为了我们电脑里必不可少的软件.苹果软件商店中有很多各有不同的清理垃圾软件,但我们往往很难从这一大堆软件 ...
- JUC并发工具包之CountDownLatch
1.介绍 本文将介绍CountDownLatch并给出实践中的几个例子,通过使用CountDownLatch我们可以让一个线程阻塞直到其他一个或多个线程执行完成. A synchronization ...
- Java中的第三大特性-多态性
一.多态性的概念 多态性是以继承为基础上的,举个例子,人属于动物,狗也属于动物,所以动物就是父类,而人和狗都是动物的子类,都属于动物. 二.多态的使用 (1)多态一般用于方法参数或者方法返回值,特别当 ...
- 基于混沌Logistic加密算法的图片加密与还原
摘要 一种基于混沌Logistic加密算法的图片加密与还原的方法,并利用Lena图和Baboon图来验证这种加密算法的加密效果.为了能够体现该算法在图片信息加密的效果,本文还采用了普通行列置乱加密算法 ...
- 如何测试一个APP
1.是否支持各种手机系统 2.是否会因为分辨率而出错 3.不同机型能否安装 4.老旧机型 能否通用 5.广告时长 6.测试能否登陆注册 7.卸载时是否会发生意外 8.安装时会不会误认为带病毒 9.用户 ...
- vs2019 Com组件初探-简单的COM编写以及实现跨语言调用
前提条件 1.掌握C++基础语法 2.平台安装 vs2019 3.本地平台为 windows 10 1909 X64 4.了解vbs基础语法 本次目标 1.掌握Com组件的概念及原理 2.编写一个简单 ...
- k8s+docker_part2
docker+k8s 目录 docker+k8s 1 简介 1.1 docker是什么 1.2 为什么要用docker 1.2.1 docker容器虚拟化的好处 1.2.2 docker在开发和运维中 ...
- Spring Boot + Elasticsearch 使用示例
本文分别使用 Elasticsearch Repository 和 ElasticsearchTemplate 实现 Elasticsearch 的简单的增删改查 一.Elastic Stack El ...
- 腾讯短信平台ASP接口范例
疫情后一个小项目要用到腾讯短信平台,因为比较老,用ASP写的,平台没有相应的ASP接口,百度不到,无奈之下自己写了一个,也方便需要的朋友们. 主要代码如下: <!--#include file= ...
- 数据库:Flask-SQLAlchemy
一.安装以及使用 1.安装 安装 flask-sqlalchemy pip install flask-sqlalchemy 如果连接的是 mysql 数据库,需要安装 mysqldb pip ins ...