Cyr E C, Gulian M, Patel R G, et al. Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.[J]. arXiv: Learning, 2019.

@article{cyr2019robust,

title={Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.},

author={Cyr, Eric C and Gulian, Mamikon and Patel, Ravi G and Perego, Mauro and Trask, Nathaniel},

journal={arXiv: Learning},

year={2019}}

这篇文章介绍了一种梯度下降的改进, 以及Box参数初始化方法.

主要内容

\[\tag{6}
\arg \min_{\xi^L \xi^H} \sum_{k=1}^K \epsilon_k \|\mathcal{L}_k[u] - \sum_i \xi_i^L \mathcal{L}_k [\Phi_i(x, \xi^H)]\|^2_{\ell_2(\mathcal{X}_k)}.
\]

LSGD

固定\(\xi^H, \mathcal{X}_k\), 并令\(\epsilon_k=1\), 则问题(6)退化为一个最小二乘问题

\[\arg \min_{\xi^L} \|A\xi^L -b\|^2_{\ell_2 (\mathcal{X})},
\]

其中\(b_i = \mathcal{L}[u](x_i)\), \(A_{ij}=\mathcal{L} [\Phi_j (x_i, \xi^H)]\), \(x_i \in \mathcal{X}, i=1,\ldots, N, j=1, \ldots, w\).

所以算法如下

Box 初始化

该算法期望使得feature-rich,但是我不知道这个rich从何而来.

假设第\(l\)层的输入为\(x \in \mathbb{R}^{d_1}\), 输出为\(y \in \mathbb{R}^{d_2}\), 则该层的权重矩阵\(W \in \mathbb{R}^{d_2 \times d_1}\). 我们逐行地定义\(W\):

  1. 采样\(p\), \(p\sim U[0 ,1]^{d_1}\);
  2. 采样\(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), 并令\(n=n/\|n\|\);
  3. 求参数\(k\)使得
\[\max_{x \in [0, 1]^{d_1}} \sigma(k(x-p) \cdot n)=1.
\]
  1. \(W\)第\(i\)行\(w_i=kn^T\), \(b_i=-kp \cdot n\).

其中\(\sigma\)表示激活函数, 文中指的是ReLU.

求解参数\(k\):

  1. \(p_{max} = \max (0, \mathrm{sign}(n))\);
  2. \(k=\frac{1}{(p_{max}-p) \cdot n}\)

此\(k\)即为所需\(k\), 只需证明\(p_{max}\)是最大化

\[(x - p)\cdot n, \quad x \in [0,1]^{d_1}
\]

的解. 最大化上式, 可以分解为

\[\max_{x_i \in [0, 1]} x_in_i,
\]

故\(x_i = \max(0, \mathrm{sign}(n_i))\).

这个初始化有什么好处呢, 可以发现, 输入\(x \in[0,1]^{d_1}\)满足, 则输出\(y \in [0, 1]^{d_2}\), 保证二者的"值域"范围一致, 以此类推整个网络节点值范围近似.



如果, 作者构建了一个2-2-2-2-2-2-2-2的网络, 可以发现, Xavier 和 Kaiming的初始化方法经过一定层数后, 就会塌缩在某个点, 而Box初始化方法能够缓解这一现象.

下面是文中列出的算法(与这里的符号有一点点不同, 另外\(b\)作者应该是遗漏了负号).

Box for Resnet

因为Resnet特殊的结构,

\[y=(W+I)x+b.
\]

假设\(x \in [0,m]^{d_1}\), 则:

  1. 采样\(p\), \(p\sim U[0 ,m]^{d_1}\);
  2. 采样\(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), 并令\(n=n/\|n\|\);
  3. 求参数\(k\)使得
\[\max_{x \in [0, m]^{d_1}} \sigma(k(x-p) \cdot n)=\delta m.
\]
  1. \(W\)第\(i\)行\(w_i=kn^T\), \(b_i=-kp \cdot n\).
\[k=\frac{\delta m}{(mp_{max}-p) \cdot n}.
\]

若第一层输入\(x_i \in [0,1]\), 去\(\delta=1/L\), 其中\(L\)为总的层数, 则

\[[0,1] \rightarrow [0,1+\frac{1}{L}] \rightarrow [0,(1+\frac{1}{L})^2] \rightarrow \cdots
\]

代码



'''
initialization.py
'''
import torch
import torch.nn as nn
import warnings def generate(size, m, delta):
p = torch.rand(size) * m
n = torch.randn(size)
temp = 1 / torch.norm(n, p=2, dim=1, keepdim=True)
n = temp * n
pmax = nn.functional.relu(torch.sign(n)) * m
temp = (pmax - p) * n
k = (m * delta) / temp.sum(dim=1, keepdim=True)
w = k * n
b = -(w * p).sum(dim=1)
return w, b def box_init(module, m=1, delta=1):
if isinstance(module, nn.Linear):
w, b = generate(module.weight.shape, m, delta)
try:
module.weight.data = w
module.bias.data = b
except AttributeError as e:
s = "Error: \n" + str(e) + "\n stops the initialization" \
" for this module: {}".format(module)
warnings.warn(s) elif isinstance(module, nn.Conv2d):
outc, inc, h, w = module.weight.size()
w, b = generate((outc, inc * h * w), m, delta)
try:
module.weight.data = w.reshape(module.weight.size())
module.bias.data = b
except AttributeError as e:
s = "Error: \n" + str(e) + "\n stops the initialization" \
" for this module: {}".format(module)
warnings.warn(s) else:
pass


"""config.py"""

nums = 10
layers = 6
method = "kaiming" #box/xavier/kaiming
net = "Net" #Net/ResNet


"""
测试
""" import torch
import torch.nn as nn
import config
from initialization import box_init class Net(nn.Module): def __init__(self, l):
super(Net, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init()
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self):
for module in self.modules():
box_init(module) def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp)
out.append(temp)
return out class ResNet(nn.Module): def __init__(self, l):
super(ResNet, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init(l)
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self, layers):
delta = 1 / layers
m = 1. + delta
l = 0
for module in self.modules():
if isinstance(module, (nn.Linear)):
if l == 0:
box_init(module, 1, 1)
else:
box_init(module, m ** l, delta)
l += 1 def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp) + temp
out.append(temp)
return out if config.net == "Net":
net = Net(config.layers)
else:
net = ResNet(config.layers) x = torch.linspace(0, 1, config.nums)
y = torch.linspace(0, 1, config.nums) grid_x, grid_y = torch.meshgrid(x, y) x = grid_x.flatten()
y = grid_y.flatten()
data = torch.stack((x, y), dim=1)
outs = net(data) import matplotlib.pyplot as plt def axplot(x, y, ax):
x = x.detach().numpy()
y = y.detach().numpy()
ax.scatter(x, y) def plot(x, y, outs):
fig, axs = plt.subplots(1, config.layers+1, sharey=True, figsize=(12, 2))
axs[0].scatter(x, y)
axs[0].set(title="layer0")
for i in range(config.layers):
ax = axs[i+1]
out = outs[i]
x = out[:, 0]
y = out[:, 1]
axplot(x, y, ax)
ax.set(title="layer"+str(i+1))
plt.tight_layout()
plt.savefig("C:/Users/pkavs/Desktop/fig.png")
#plt.show()
plot(x, y, outs)

[Box] Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint的更多相关文章

  1. Training Deep Neural Networks

    http://handong1587.github.io/deep_learning/2015/10/09/training-dnn.html  //转载于 Training Deep Neural ...

  2. Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week1, Assignment(Initialization)

    声明:所有内容来自coursera,作为个人学习笔记记录在这里. Initialization Welcome to the first assignment of "Improving D ...

  3. Training (deep) Neural Networks Part: 1

    Training (deep) Neural Networks Part: 1 Nowadays training deep learning models have become extremely ...

  4. Exploring Architectural Ingredients of Adversarially Robust Deep Neural Networks

    目录 概 主要内容 深度 宽度 代码 Huang H., Wang Y., Erfani S., Gu Q., Bailey J. and Ma X. Exploring architectural ...

  5. [C4] Andrew Ng - Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization

    About this Course This course will teach you the "magic" of getting deep learning to work ...

  6. (转)Understanding, generalisation, and transfer learning in deep neural networks

    Understanding, generalisation, and transfer learning in deep neural networks FEBRUARY 27, 2017   Thi ...

  7. On Explainability of Deep Neural Networks

    On Explainability of Deep Neural Networks « Learning F# Functional Data Structures and Algorithms is ...

  8. Classifying plankton with deep neural networks

    Classifying plankton with deep neural networks The National Data Science Bowl, a data science compet ...

  9. Must Know Tips/Tricks in Deep Neural Networks

    Must Know Tips/Tricks in Deep Neural Networks (by Xiu-Shen Wei)   Deep Neural Networks, especially C ...

随机推荐

  1. 12-gauge/bore shotgun

    12-gauge/bore shotgun不是弹夹(magazine)容量为12发的霰(xian)弹枪.[LDOCE]gauge - a measurement of the width or thi ...

  2. Hive(八)【行转列、列转行】

    目录 一.行转列 相关函数 concat concat_ws collect_set collect_list 需求 需求分析 数据准备 写SQL 二.列转行 相关函数 split explode l ...

  3. 大数据学习day31------spark11-------1. Redis的安装和启动,2 redis客户端 3.Redis的数据类型 4. kafka(安装和常用命令)5.kafka java客户端

    1. Redis Redis是目前一个非常优秀的key-value存储系统(内存的NoSQL数据库).和Memcached类似,它支持存储的value类型相对更多,包括string(字符串).list ...

  4. Template Metaprogramming in C++

    说实话,学习C++以来,第一次听说"Metaprogramming"这个名词. Predict the output of following C++ program. 1 #in ...

  5. Mysql的表级锁

    我们首先需要知道的一个大前提是:mysql的锁是由具体的存储引擎实现的.所以像Mysql的默认引擎MyISAM和第三方插件引擎 InnoDB的锁实现机制是有区别的.可根据不同的场景选用不同的锁定机制. ...

  6. Oracle存储过程游标for循环怎么写

    一.不带参数的游标for循环 首先编写存储过程的整体结构,如下: create or replace procedure test_proc is v_date date; --变量定义 begin ...

  7. 【Linux】【Commands】文本查看类

    分屏查看命令:more和less more命令: more FILE 特点:翻屏至文件尾部后自动退出: less命令: less FILE head命令: 查看文件的前n行: head [option ...

  8. window安装ab压力测试

    ab是Apache HTTP server benchmarking tool的缩写,可以用以测试HTTP请求的服务器性能,也是业界比较流行和简单易用的一种压力测试工具包 ## 下载 下载地址:(ht ...

  9. 部署应用程序到Tomcat的webapps目录

    一.方法如下 1.通过MyEclipse上方工具栏Manage Deployments,依次选择项目和服务器: 2.通过右击项目Export,生成war包到webapps中: 3.复制项目WebRoo ...

  10. 【C/C++】例题3-5 生成元/算法竞赛入门经典/数组与字符串

    [题目] x+x的各位数之和为y,x为y的生成元. 求10万以内的n的最小生成元,无解输出0. [解答] 这是我根据自己的想法最初写的代码: #include<cstdio> #inclu ...