目录

Wang Z., Zhang W., Liu N. and Wang J. Scalable rule-based representation learning for interpretable classification. In Advances in Neural Information Processing Systems (NIPS), 2021.

传统的诸如决策树之类的机器学习方法具有很强的结构性, 也因此具有很好的可解释性. 和深度学习方法相比, 这类方法比较难以推广到大规模的问题上, 很重要的一个原因便是, 其离散的参数和结构导致无法利用梯度进行优化. 本文是对利用梯度来优化这些模型的一个尝试.

主要内容

本文考虑的是上图(a)中的离散模型, 其接受连续变量\(C_i\)和离散变量\(B_i\):

  1. 通过Binarization Layer 将连续变量\(C_i\)离散化并与\(B_i\)拼接得到输入\(\bm{u}^{(0)}\);
  2. 对于Logical Layer, 其以\(\bm{u}^{l-1}\)为输入, 输出\(\bm{u}^l\), 其包含且\(\bm{r}\)和或\(\bm{s}\)两个部分:
\[r_i^{(l)} = \bigwedge_{W_{ij}^{(l, 0)} = 1} u_j^{(l-1)}, \\
s_i^{(l)} = \bigvee_{W_{ij}^{(l, 1)} = 1} u_j^{(l-1)}. \\
\]

其中\(W^{(l, 0)}\)表示\(\bm{r}\)与\(\bm{u}\)的邻接矩阵, 而\(W^{(l, 1)}\)表示\(\bm{s}\)与\(\bm{u}\)的邻接矩阵. 可以发现, Logical Layer中的输入输出和权重都是二元的.

3. 最后通过一个线性层进行分类, 需要说明的是, 线性层的权重是连续的.

显然由于logical layer是离散的, 直接通过梯度更新是办不到的. 一个自然的想法是用一个连续的版本\(\hat{\mathcal{F}}(X; \theta)\)进行替换, 更新连续的参数\(\theta\)然后获得下列的离散的版本:

\[\mathcal{F}(X; q(\theta)), \quad q(x) = \mathbb{I}_{x > 0.5}.
\]

显然直接套用这个方法是低效的, 因为训练过程和离散没有任何关系, 我们没法保证离散后的模型依旧是有效的, 此外还有一个问题, 上述离散模型如何匹配到一个连续的版本.

下面是一个有趣的解决方案, 假设\(\hat{W}_{i,j} \in [0, 1]\), 则

\[Conj (\bm{u}, W_i) = \prod_{j=1}^n \bigg\{1 - W_{i,j}(1 - u_j) \bigg\}, \\
Disj (\bm{u}, W_i) = 1 - \prod_{j=1}^n \bigg\{1 - W_{i,j}u_j \bigg\}, \\
\]

便为且和或操作的连续版本.

试想:

\[\begin{array}{ll}
& r_i = 1 \\
\Leftrightarrow & \bigwedge_j [u_j^{(l-1)} \vee (1 - W_{ij})] = 1\\
\Leftrightarrow & \prod_j \bigg\{1 - W_{i,j}(1 - u_j) \bigg\} = 1.\\
\end{array}
\]

其它情况可以类似推导, 实在是有趣.

但是上述式子在实际中会有一些梯度消失的问题(因为连乘号, 且内部是[0, 1]之间的), 所示在实际使用中, 作者加了一个投影算子

\[Conj_+ = \mathbb{P}(Conj (\bm{u}, W_i)),
\]

其中(这设计都是为了避免梯度消失, 怎么想到的? 怎么会往这个方向去想的?)

\[\mathbb{P}(v) = \frac{-1}{-1 + \log (v)}.
\]

解决了连续版本的问题, 现在剩下的难啃的地方是如何更新\(\theta\)以保证\(q(\theta)\)也是有意义的.

作者采用如下的梯度更新公式:

\[\theta^{t+1} = \theta^t - \eta \frac{\partial \mathcal{L}(\bar{Y})}{\partial \bar{Y}} \cdot \frac{\partial \hat{Y}}{\partial \theta^t},
\]

其中\(\hat{Y} = \hat{\mathcal{F}}(X; \theta)\), \(\bar{Y} = \mathcal{F}(X; \bar{\theta})\).

作者用了一个嫁接的例子来说明该思想, 即损失关于预测的导数用离散的, 内部的导数用连续的.

我惊讶的是, 这些改动居然work? 太不可思议了.

Scalable Rule-Based Representation Learning for Interpretable Classification的更多相关文章

  1. 论文解读(SUBG-CON)《Sub-graph Contrast for Scalable Self-Supervised Graph Representation Learning》

    论文信息 论文标题:Sub-graph Contrast for Scalable Self-Supervised Graph Representation Learning论文作者:Yizhu Ji ...

  2. Hierarchical Attention Based Semi-supervised Network Representation Learning

    Hierarchical Attention Based Semi-supervised Network Representation Learning 1. 任务 给定:节点信息网络 目标:为每个节 ...

  3. [论文阅读笔记] metapath2vec: Scalable Representation Learning for Heterogeneous Networks

    [论文阅读笔记] metapath2vec: Scalable Representation Learning for Heterogeneous Networks 本文结构 解决问题 主要贡献 算法 ...

  4. (zhuan) Evolution Strategies as a Scalable Alternative to Reinforcement Learning

    Evolution Strategies as a Scalable Alternative to Reinforcement Learning this blog from: https://blo ...

  5. (zhuan) Notes on Representation Learning

    this blog from: https://opendatascience.com/blog/notes-on-representation-learning-1/   Notes on Repr ...

  6. Self-Supervised Representation Learning

    Self-Supervised Representation Learning 2019-11-11 21:12:14  This blog is copied from: https://lilia ...

  7. 【PSMA】Progressive Sample Mining and Representation Learning for One-Shot Re-ID

    目录 主要挑战 主要的贡献和创新点 提出的方法 总体框架与算法 Vanilla pseudo label sampling (PLS) PLS with adversarial learning Tr ...

  8. 论文解读GALA《Symmetric Graph Convolutional Autoencoder for Unsupervised Graph Representation Learning》

    论文信息 Title:<Symmetric Graph Convolutional Autoencoder for Unsupervised Graph Representation Learn ...

  9. (转)Predictive learning vs. representation learning 预测学习 与 表示学习

    Predictive learning vs. representation learning  预测学习 与 表示学习 When you take a machine learning class, ...

随机推荐

  1. day07 Nginx入门

    day07 Nginx入门 Nginx简介 Nginx是一个开源且高性能.可靠的http web服务.代理服务 开源:直接获取源代码 高性能:支持海量开发 可靠:服务稳定 特点: 1.高性能.高并发: ...

  2. Spark基础:(五)Spark编程进阶

    共享变量 (1)累加器:是用来对信息进行聚合的,同时也是Spark中提供的一种分布式的变量机制,其原理类似于mapreduce,即分布式的改变,然后聚合这些改变.累加器的一个常见用途是在调试时对作业执 ...

  3. 『学了就忘』Linux启动引导与修复 — 69、启动引导程序(grub)

    目录 1.启动引导程序(Boot Loader)简介 2.启动引导程序grub的作用 3.启动引导程序grub的位置 4./grub目录中其他的文件简单介绍 提示: 简单地说,Boot Loader就 ...

  4. 【leetcode】797. All Paths From Source to Target

    Given a directed acyclic graph (DAG) of n nodes labeled from 0 to n - 1, find all possible paths fro ...

  5. 调试器gdb

    1.启动和退出gdb gdb调试的对象是可执行文件,而不是程序源代码.如果要使一个可执行文件可以被gdb调试,那么在使用编译器gcc编译程序时加入-g选项.-g选项告诉gcc在编译程序时加入调试信息, ...

  6. Oracle中的DBMS_LOCK包的使用

    一.DBMS_LOCK相关知识介绍 锁模式: 名字 描述 数据类型 值 nl_mode Null INTEGER 1 ss_mode Sub Shared: used on an aggregate ...

  7. spring生成EntityManagerFactory的三种方式

    spring生成EntityManagerFactory的三种方式 1.LocalEntityManagerFactoryBean只是简单环境中使用.它使用JPA PersistenceProvide ...

  8. windows 查看端口被占用,解除占用

    查看 (列举端口为2688) netstat -ano | findstr "2688" 解除 原文地址

  9. Centos 7 安装redis,修改配置文件不生效、外网不能访问。

    前提: 在用Centos 7 安装 redis 时,遇上一下几个问题 ,记录下 . 1.修改配置文件,按官网步骤启动,不生效. 2.外网无法访问redis. 步骤: 1.打开centos 虚拟机 ,按 ...

  10. Jenkins监控

    目录 一.Monitoring插件 二.Prometheus监控 一.Monitoring插件 Monitoring插件(monitoring)使用JavaMelody,对Jenkins进行监控.插件 ...