Mroueh Y, Sercu T, Goel V, et al. McGan: Mean and Covariance Feature Matching GAN[J]. arXiv: Learning, 2017.

@article{mroueh2017mcgan:,

title={McGan: Mean and Covariance Feature Matching GAN},

author={Mroueh, Youssef and Sercu, Tom and Goel, Vaibhava},

journal={arXiv: Learning},

year={2017}}

利用均值和协方差构建IPM, 获得相应的mean GAN 和 covariance gan.

主要内容

IPM:

\[d_{\mathscr{F}} (\mathbb{P}, \mathbb{Q}) = \sup_{f \in \mathscr{F}} |\mathbb{E}_{x \sim \mathbb{P}} f(x) - \mathbb{E}_{x \sim \mathbb{Q}} f(x)|.
\]

当\(\mathscr{F}\)是对称空间, 即\(f \in \mathscr{F} \rightarrow - f \in \mathscr{F}\),可得

\[d_{\mathscr{F}} (\mathbb{P}, \mathbb{Q}) = \sup_{f \in \mathscr{F}} \big \{\mathbb{E}_{x \sim \mathbb{P}} f(x) - \mathbb{E}_{x \sim \mathbb{Q}} f(x) \big\}.
\]

Mean Matching IPM

\[\mathscr{F}_{v,w,p}:= \{f(x)=\langle v, \Phi_w(x) \rangle | v\in \mathbb{R}^m, \|v\|_p \le 1, \Phi_w:\mathcal{X} \rightarrow \mathbb{R}^m, w \in \Omega\},
\]

其中\(\|\cdot \|_p\)表示\(\ell_p\)范数, \(\Phi_w\)往往用网络来表示, 我们可通过截断\(w\)来使得\(\mathscr{F}_{v,w,p}\)为有界线性函数空间(有界从而使得后面推导中\(\sup\)成为\(\max\)).



其中

\[\mu_w(\mathbb{P})= \mathbb{E}_{x \sim \mathbb{P}} [\Phi_w(x)] \in \mathbb{R}^m.
\]

最后一个等式的成立是因为:

\[\|x\|_* = \max \{\langle v, x \rangle | \|v\| \le 1\},
\]

又\(\| \cdot \|_p\)的对偶范数是\(\|\cdot\|_q, \frac{1}{p}+\frac{1}{q}=1\).

prime

整个GAN的训练过程即为

\[\tag{3}
\min_{g_\theta} \max_{w \in \Omega} \max_{v, \|v\|_p \le 1} \mathscr{L}_{\mu} (v,w,\theta),
\]

其中

\[\mathscr{L}_{\mu} (v,w,\theta) = \langle v, \mathbb{E}_{x \in \mathbb{P}_r} \Phi_w(x) - \mathbb{E}_{z \sim p(z)} \Phi_w(g_{\theta} (z)) \rangle.
\]

估计形式为

dual

也有对应的dual形态

\[\tag{4}
\min_{g_\theta} \max_{w \in \Omega} \|\mu_w(\mathbb{P}_r) - \mu_w (\mathbb{P}_{\theta})\|_q.
\]

Covariance Feature Matching IPM

\[\mathscr{F}_{U, V,w} := \{f(x)= \sum_{j=1}^k \langle u_j, \Phi_w(x) \rangle \langle v_j, \Phi_w(x)\rangle, \langle u_i, u_j \rangle = \langle v_i, v_j \rangle =0, i \not = j, else \:1 \},
\]

等价于

\[\mathscr{F}_{U, V,w} := \{f(x)= \langle U^T \Phi_w(x), V^T\Phi_w(x) \rangle, U^TU=I_k, V^TV=I_k, w \in \Omega \}.
\]

并有

其中\([A]_k\)表示\(A\)的\(k\)阶近似, 如果\(A = \sum_i \sigma_iu_iv_i^T\), \(\sigma_1\ge \sigma_2,\ldots\), 则\([A]_k=\sum_{i=1}^k \sigma_i u_iv_i^T\). \(\mathcal{O}_{m,k} := \{M \in \mathbb{R}^{m \times k} | M^TM = I_k \}\), \(\|A\|_*=\sum_i \sigma_i\)表示算子范数.

prime

\[\tag{6}
\min_{g_\theta} \max_{w \in \Omega} \max_{U,V \in \mathcal{P}_{m, k}} \mathscr{L}_{\sigma} (U, V,w,\theta),
\]

其中

\[\mathscr{L}_{\sigma} (U,V,w,\theta) = \mathbb{E}_{x \sim \mathbb{P}_r} \langle U^T \Phi_w(x), V^T\Phi_w(x) \rangle- \mathbb{E}_{z \sim p_z} \langle U^T \Phi_w(g_{\theta}(z)), V^T\Phi_w(g_{\theta}(z)) \rangle.
\]

采用下式估计

dual

\[\tag{7}
\min_{g_{\theta}} \max_{w \in \Omega} \| [\Sigma_w(\mathbb{P}_r) - \Sigma_w(\mathbb{P}_{\theta})]_k\|_*.
\]

注: 既然\(\Sigma_w(\mathbb{P}_r) - \Sigma_w(\mathbb{P}_{\theta})\)是对称的, 为什么\(U \not =V\)? 因为虽然其对称, 但是并不(半)正定, 所以\(v_i=-u_i\)也是有可能的.

算法



代码

未经测试.

  1. import torch
  2. import torch.nn as nn
  3. from torch.nn.functional import relu
  4. from collections.abc import Callable
  5. def preset(**kwargs):
  6. def decorator(func):
  7. def wrapper(*args, **nkwargs):
  8. nkwargs.update(kwargs)
  9. return func(*args, **nkwargs)
  10. wrapper.__doc__ = func.__doc__
  11. wrapper.__name__ = func.__name__
  12. return wrapper
  13. return decorator
  14. class Meanmatch(nn.Module):
  15. def __init__(self, p, dim, dual=False, prj='l2'):
  16. super(Meanmatch, self).__init__()
  17. self.norm = p
  18. self.dual = dual
  19. if dual:
  20. self.dualnorm = self.norm
  21. else:
  22. self.init_weights(dim)
  23. self.projection = self.proj(prj)
  24. @property
  25. def dualnorm(self):
  26. return self.__dualnorm
  27. @dualnorm.setter
  28. def dualnorm(self, norm):
  29. if norm == 'inf':
  30. norm = float('inf')
  31. elif not isinstance(norm, float):
  32. raise ValueError("Invalid norm")
  33. p = 1 / (1 - 1 / norm)
  34. self.__dualnorm = preset(p=p, dim=1)(torch.norm)
  35. def init_weights(self, dim):
  36. self.weights = nn.Parameter(torch.rand((1, dim)),
  37. requires_grad=True)
  38. @staticmethod
  39. def _proj1(x):
  40. u = x.max()
  41. if u <= 1.:
  42. return x
  43. l = 0.
  44. c = (u + l) / 2
  45. while (u - l) > 1e-4:
  46. r = relu(x - c).sum()
  47. if r > 1.:
  48. l = c
  49. else:
  50. u = c
  51. c = (u + l) / 2
  52. return relu(x - c)
  53. @staticmethod
  54. def _proj2(x):
  55. return x / torch.norm(x)
  56. @staticmethod
  57. def _proj3(x):
  58. return x / torch.max(x)
  59. def proj(self, prj):
  60. if prj == "l1":
  61. return self._proj1
  62. elif prj == "l2":
  63. return self._proj2
  64. elif prj == "linf":
  65. return self._proj3
  66. else:
  67. assert isinstance(prj, Callable), "Invalid prj"
  68. return prj
  69. def forward(self, real, fake):
  70. temp = (real - fake).mean(dim=1)
  71. if self.dual:
  72. return self.dualnorm(temp)
  73. elif not self.training and self.dual:
  74. raise TypeError("just for training...")
  75. else:
  76. self.weights.data = self.projection(self.weights.data) #some diff here!!!!!!!!!!
  77. return self.weights @ temp
  78. class Covmatch(nn.Module):
  79. def __init__(self, dim, k):
  80. super(Covmatch, self).__init__()
  81. self.init_weights(dim, k)
  82. def init_weights(self, dim, k):
  83. temp1 = torch.rand((dim, k))
  84. temp2 = torch.rand((dim, k))
  85. self.U = nn.Parameter(temp1, requires_grad=True)
  86. self.V = nn.Parameter(temp2, requires_grad=True)
  87. def qr(self, w):
  88. q, r = torch.qr(w)
  89. sign = r.diag().sign()
  90. return q * sign
  91. def update_weights(self):
  92. self.U.data = self.qr(self.U.data)
  93. self.V.data = self.qr(self.V.data)
  94. def forward(self, real, fake):
  95. self.update_weights()
  96. temp1 = real @ self.U
  97. temp2 = real @ self.V
  98. temp3 = fake @ self.U
  99. temp4 = fake @ self.V
  100. part1 = torch.trace(temp1 @ temp2.t()).mean()
  101. part2 = torch.trace(temp3 @ temp4.t()).mean()
  102. return part1 - part2

McGan: Mean and Covariance Feature Matching GAN的更多相关文章

  1. Computer Vision_33_SIFT:Robust scale-invariant feature matching for remote sensing image registration——2009

    此部分是计算机视觉部分,主要侧重在底层特征提取,视频分析,跟踪,目标检测和识别方面等方面.对于自己不太熟悉的领域比如摄像机标定和立体视觉,仅仅列出上google上引用次数比较多的文献.有一些刚刚出版的 ...

  2. Computer Vision_33_SIFT:Remote Sensing Image Registration With Modified SIFT and Enhanced Feature Matching——2017

    此部分是计算机视觉部分,主要侧重在底层特征提取,视频分析,跟踪,目标检测和识别方面等方面.对于自己不太熟悉的领域比如摄像机标定和立体视觉,仅仅列出上google上引用次数比较多的文献.有一些刚刚出版的 ...

  3. [OpenCV] Feature Matching

    得到了杂乱无章的特征点后,要筛选出好的特征点,也就是good matches. BruteForceMatcher FlannBasedMatcher 两者的区别:http://yangshen998 ...

  4. [转]GAN论文集

    really-awesome-gan A list of papers and other resources on General Adversarial (Neural) Networks. Th ...

  5. [论文理解] Good Semi-supervised Learning That Requires a Bad GAN

    Good Semi-supervised Learning That Requires a Bad GAN 恢复博客更新,最近没那么忙了,记录一下学习. Intro 本文是一篇稍微偏理论的半监督学习的 ...

  6. Generative Adversarial Nets[Improved GAN]

    0.背景 Tim Salimans等人认为之前的GANs虽然可以生成很好的样本,然而训练GAN本质是找到一个基于连续的,高维参数空间上的非凸游戏上的纳什平衡.然而不幸的是,寻找纳什平衡是一个十分困难的 ...

  7. (转) GAN论文整理

    本文转自:http://www.jianshu.com/p/2acb804dd811 GAN论文整理 作者 FinlayLiu 已关注 2016.11.09 13:21 字数 1551 阅读 1263 ...

  8. 常见GAN的应用

    深入浅出 GAN·原理篇文字版(完整)|干货 from:http://baijiahao.baidu.com/s?id=1568663805038898&wfr=spider&for= ...

  9. AI佳作解读系列(六) - 生成对抗网络(GAN)综述精华

    注:本文来自机器之心的PaperWeekly系列:万字综述之生成对抗网络(GAN),如有侵权,请联系删除,谢谢! 前阵子学习 GAN 的过程发现现在的 GAN 综述文章大都是 2016 年 Ian G ...

随机推荐

  1. A Child's History of England.15

    And indeed it did. For, the great army landing from the great fleet, near Exeter, went forward, layi ...

  2. Oracle—全局变量

    Oracle全局变量 一.数据库程序包全局变量       在程序实现过程中,经常用遇到一些全局变量或常数.在程序开发过程中,往往会将该变量或常数存储于临时表或前台程序的全局变量中,由此带来运行效率降 ...

  3. 使用Rapidxml重建xml树

    需求 : 重建一棵xml树, 在重建过程中对原来的标签进行一定的修改. 具体修改部分就不给出了, 这里只提供重建部分的代码 code : /****************************** ...

  4. Default Constructors

    A constructor without any arguments or with default value for every argument, is said to be default ...

  5. 通过Shell统计PV和UV

    PV.UV是网站分析中最基础.最常见的指标.PV即PageView,网站浏览量,指页面的浏览次数,用以衡量网站用户访问的网页数量.用户没打开一个页面便记录1次PV,多次打开同一页面则浏览量累计:UV即 ...

  6. hash 模式与 history 模式小记

    hash 模式 这里的 hash 就是指 url 后的 # 号以及后面的字符.比如说 "www.baidu.com/#hashhash" ,其中 "#hashhash&q ...

  7. 【Spring Framework】Spring注解设置Bean的初始化、销毁方法的方式

    bean的生命周期:创建---初始化---销毁. Spring中声明的Bean的初始化和销毁方法有3种方式: @Bean的注解的initMethod.DestroyMethod属性 bean实现Ini ...

  8. 【Services】【Web】【Nginx】静态下载页面的安装与配置

    1. 拓扑 F5有自动探活机制,如果一台机器宕机,请求会转发到另外一台,省去了IPVS漂移的麻烦 F5使用轮询算法,向两台服务器转发请求,实现了负载均衡 2. 版本: 2.1 服务器版本:RHEL7. ...

  9. SpringBoot环境下java实现文件的下载

    思路:文件下载,就是给服务器上的文件创建输入流,客户端创建输出流,将文件读出,读入到客户端的输出流中,(流与流的转换) package com.cst.icode.controller; import ...

  10. java 对 final 关键字 深度理解

    基础理解 : 1.修饰类 当用final去修饰一个类的时候,表示这个类不能被继承.处于安全,在JDK中,被设计为final类的有String.System等,这些类不能被继承 .注意:被修饰的类的成员 ...