[Box] Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint
@article{cyr2019robust,
  title={Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint},
  author={Cyr, Eric C. and Gulian, Mamikon and Patel, Ravi G. and Perego, Mauro and Trask, Nathaniel},
  journal={arXiv: Learning},
  year={2019}
}
Overview

This paper introduces an improvement to gradient descent (LSGD) together with the Box parameter initialization method.
Main content

The paper views a network as an adaptive basis: the output is a linear combination \(\sum_i \xi_i^L \Phi_i(x, \xi^H)\), where \(\xi^L\) are the output-layer coefficients and \(\xi^H\) the hidden-layer parameters defining the basis functions \(\Phi_i\). Training solves
\[
\arg \min_{\xi^L, \xi^H} \sum_{k=1}^K \epsilon_k \Big\|\mathcal{L}_k[u] - \sum_i \xi_i^L \mathcal{L}_k [\Phi_i(x, \xi^H)]\Big\|^2_{\ell_2(\mathcal{X}_k)},
\]
where the \(\mathcal{L}_k\) are (e.g. differential) operators acting on the target \(u\), and each \(\mathcal{X}_k\) is the corresponding set of sample points.
LSGD
Fix \(\xi^H\) and \(\mathcal{X}_k\) and set \(\epsilon_k=1\); problem (6) in the paper then reduces to the least squares problem
\[
\arg \min_{\xi^L} \|A \xi^L - b\|_{\ell_2}^2,
\]
where \(b_i = \mathcal{L}[u](x_i)\), \(A_{ij}=\mathcal{L} [\Phi_j (x_i, \xi^H)]\), \(x_i \in \mathcal{X}\), \(i=1,\ldots, N\), \(j=1, \ldots, w\).
LSGD (least squares gradient descent) therefore alternates two steps: solve this least squares problem exactly for the output-layer coefficients \(\xi^L\), then take an ordinary gradient descent step on the hidden parameters \(\xi^H\). (The paper's algorithm listing is not reproduced here.)
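Below is a minimal LSGD sketch of my own for the plain regression case (\(\mathcal{L}\) the identity), assuming the network splits as net(x) = linear(hidden(x)) with linear = nn.Linear(w, 1, bias=False); the names hidden, linear, X, u, opt are mine, not the paper's, and opt should hold only the hidden parameters.

import torch

def lsgd_step(hidden, linear, X, u, opt):
    # Step 1: exact least squares for the output-layer coefficients xi^L.
    with torch.no_grad():
        A = hidden(X)                              # N x w matrix of basis evaluations
        sol = torch.linalg.lstsq(A, u).solution    # w x 1 least squares solution (u: N x 1)
        linear.weight.data = sol.T                 # install xi^L
    # Step 2: one gradient descent step on the hidden parameters xi^H.
    loss = ((linear(hidden(X)) - u) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()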
Box initialization

This algorithm is meant to keep the layers feature-rich, though it is not clear to me where this richness comes from.
Suppose layer \(l\) has input \(x \in \mathbb{R}^{d_1}\) and output \(y \in \mathbb{R}^{d_2}\), so its weight matrix is \(W \in \mathbb{R}^{d_2 \times d_1}\). We define \(W\) row by row:
- sample \(p\), \(p\sim U[0 ,1]^{d_1}\);
- sample \(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), and set \(n=n/\|n\|\);
- find the parameter \(k\) such that
\[
\max_{x \in [0,1]^{d_1}} \sigma(k\, n \cdot (x - p)) = 1;
\]
- row \(i\) of \(W\) is \(w_i=kn^T\), with bias \(b_i=-kp \cdot n\), so that \(w_i x + b_i = k\, n \cdot (x - p)\).
Here \(\sigma\) denotes the activation function, ReLU in the paper.
Solving for \(k\):
- \(p_{max} = \max (0, \mathrm{sign}(n))\) (componentwise);
- \(k=\frac{1}{(p_{max}-p) \cdot n}\).
This \(k\) is the required one: it suffices to show that \(p_{max}\) maximizes
\[
n \cdot x, \quad x \in [0,1]^{d_1}.
\]
The maximization decomposes componentwise into
\[
\max_{x_i \in [0,1]} n_i x_i, \quad i = 1, \ldots, d_1,
\]
so \(x_i = \max(0, \mathrm{sign}(n_i))\).
What does this initialization buy us? If the input satisfies \(x \in[0,1]^{d_1}\), then the output satisfies \(y \in [0, 1]^{d_2}\): the "range" of values is preserved across the layer, and by induction every node in the network stays in roughly the same range.
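The following is my own sanity check, not code from the paper: build one Box-initialized layer from the recipe above and verify that the ReLU outputs stay in \([0,1]\).

import torch

torch.manual_seed(0)
d1, d2 = 4, 8

p = torch.rand(d2, d1)                          # one p ~ U[0,1]^{d1} per row
n = torch.randn(d2, d1)
n = n / n.norm(dim=1, keepdim=True)             # unit normal per row
pmax = torch.relu(torch.sign(n))                # componentwise maximizer of n . x
k = 1.0 / ((pmax - p) * n).sum(dim=1, keepdim=True)
W = k * n
b = -(W * p).sum(dim=1)

x = torch.rand(100000, d1)                      # inputs in [0,1]^{d1}
y = torch.relu(x @ W.T + b)
print(y.min().item(), y.max().item())           # expect 0 <= y <= 1, max near 1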
As an illustration, the authors build a 2-2-2-2-2-2-2-2 network and observe that with Xavier or Kaiming initialization the layer outputs collapse to a single point after a few layers, whereas Box initialization mitigates this collapse (the test code below reproduces the experiment).
The paper also lists the algorithm explicitly (its notation differs slightly from the one used here; note the authors appear to have omitted the minus sign in \(b\)).
Box for ResNet

Because of the ResNet block's residual structure
\[
y = x + \sigma(Wx + b),
\]
the value range grows from layer to layer, so the initialization is adjusted as follows.
Suppose \(x \in [0,m]^{d_1}\); then:
- sample \(p\), \(p\sim U[0 ,m]^{d_1}\);
- sample \(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), and set \(n=n/\|n\|\);
- find the parameter \(k\) such that
\[
\max_{x \in [0,m]^{d_1}} \sigma(k\, n \cdot (x - p)) = m\delta;
\]
- row \(i\) of \(W\) is \(w_i=kn^T\), with bias \(b_i=-kp \cdot n\).
With this choice,
\[
y = x + \sigma(Wx+b) \in [0, (1+\delta)m]^{d_2}.
\]
If the first-layer inputs satisfy \(x_i \in [0,1]\) and we take \(\delta=1/L\), where \(L\) is the total number of layers, then after \(l\) layers the values lie in \([0, (1+1/L)^l]\), and for the whole network
\[
\left(1+\tfrac{1}{L}\right)^L \le e,
\]
so the range stays bounded regardless of depth.
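A one-line check of this bound (my own snippet): \((1+1/L)^L\) increases toward \(e\) as \(L\) grows, so the final range is uniform in depth.

import math

for L in (2, 6, 32, 1024):
    print(L, (1 + 1 / L) ** L, "<=", math.e)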
Code
'''
initialization.py
'''
import torch
import torch.nn as nn
import warnings


def generate(size, m, delta):
    # Sample one hyperplane (p, n) per output row, then scale it so that the
    # maximum activation over the input box [0, m]^{d1} equals m * delta.
    p = torch.rand(size) * m                            # p ~ U[0, m]^{d1}, one per row
    n = torch.randn(size)
    n = n / torch.norm(n, p=2, dim=1, keepdim=True)     # unit normal per row
    pmax = nn.functional.relu(torch.sign(n)) * m        # maximizer of n . x over [0, m]^{d1}
    temp = (pmax - p) * n
    k = (m * delta) / temp.sum(dim=1, keepdim=True)
    w = k * n
    b = -(w * p).sum(dim=1)
    return w, b


def box_init(module, m=1, delta=1):
    if isinstance(module, nn.Linear):
        w, b = generate(module.weight.shape, m, delta)
        try:
            module.weight.data = w
            module.bias.data = b
        except AttributeError as e:
            s = "Error: \n" + str(e) + "\n stops the initialization" \
                " for this module: {}".format(module)
            warnings.warn(s)
    elif isinstance(module, nn.Conv2d):
        # Treat each filter as one row acting on its flattened receptive field.
        outc, inc, h, w = module.weight.size()
        w, b = generate((outc, inc * h * w), m, delta)
        try:
            module.weight.data = w.reshape(module.weight.size())
            module.bias.data = b
        except AttributeError as e:
            s = "Error: \n" + str(e) + "\n stops the initialization" \
                " for this module: {}".format(module)
            warnings.warn(s)
    else:
        pass
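To sanity-check box_init on a single layer (my own usage snippet, assuming initialization.py is on the path): the maximum ReLU activation over random inputs in \([0,1]^{d_1}\) should approach \(m\delta = 1\).

import torch
import torch.nn as nn
from initialization import box_init

layer = nn.Linear(4, 8)
box_init(layer, m=1, delta=1)
x = torch.rand(100000, 4)
y = torch.relu(layer(x))
print(y.max().item())   # <= 1, and close to 1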
"""config.py"""
nums = 10
layers = 6
method = "kaiming" #box/xavier/kaiming
net = "Net" #Net/ResNet
"""
测试
"""
import torch
import torch.nn as nn
import config
from initialization import box_init
class Net(nn.Module):
def __init__(self, l):
super(Net, self).__init__()
self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init()
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init()
def box_init(self):
for module in self.modules():
box_init(module)
def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight)
def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight)
def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp)
out.append(temp)
return out
class ResNet(nn.Module):
    def __init__(self, l):
        super(ResNet, self).__init__()
        self.linears = []
        for i in range(l):
            name = "linear" + str(i)
            self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
                                                 nn.ReLU()))
            self.linears.append(self.__getattr__(name))
        if config.method == 'box':
            self.box_init(l)
        elif config.method == "xavier":
            self.xavier_init()
        else:
            self.kaiming_init()

    def box_init(self, layers):
        # Layer l sees inputs in [0, (1 + delta)^l], so pass m = (1 + delta)^l.
        delta = 1 / layers
        m = 1. + delta
        l = 0
        for module in self.modules():
            if isinstance(module, (nn.Linear)):
                if l == 0:
                    box_init(module, 1, 1)
                else:
                    box_init(module, m ** l, delta)
                l += 1

    def xavier_init(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)

    def kaiming_init(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, x):
        out = []
        temp = x
        for linear in self.linears:
            temp = linear(temp) + temp   # residual connection
            out.append(temp)
        return out
if config.net == "Net":
net = Net(config.layers)
else:
net = ResNet(config.layers)
x = torch.linspace(0, 1, config.nums)
y = torch.linspace(0, 1, config.nums)
grid_x, grid_y = torch.meshgrid(x, y)
x = grid_x.flatten()
y = grid_y.flatten()
data = torch.stack((x, y), dim=1)
outs = net(data)
import matplotlib.pyplot as plt
def axplot(x, y, ax):
x = x.detach().numpy()
y = y.detach().numpy()
ax.scatter(x, y)
def plot(x, y, outs):
fig, axs = plt.subplots(1, config.layers+1, sharey=True, figsize=(12, 2))
axs[0].scatter(x, y)
axs[0].set(title="layer0")
for i in range(config.layers):
ax = axs[i+1]
out = outs[i]
x = out[:, 0]
y = out[:, 1]
axplot(x, y, ax)
ax.set(title="layer"+str(i+1))
plt.tight_layout()
plt.savefig("C:/Users/pkavs/Desktop/fig.png")
#plt.show()
plot(x, y, outs)
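To reproduce the collapse comparison, run the script once per setting, editing config.py by hand: method in {"box", "xavier", "kaiming"} and net in {"Net", "ResNet"} (set layers = 8 to match the 2-2-2-2-2-2-2-2 example above). With "xavier" or "kaiming" the scattered points should shrink toward a single point as the layer index grows; with "box" they should remain spread out.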