pytorch固定BN层参数
背景:基于PyTorch
的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch
相同的测试数据经过主分支输出的结果不同。
原因:未固定主分支BN
层中的running_mean
和running_var
。
解决方法:将需要固定的BN
层状态设置为eval
。
问题示例:
环境:torch
:1.7.0
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6, 16, 3)
self.bn2 = nn.BatchNorm2d(16)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 5)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
def print_parameter_grad_info(net):
print('-------parameters requires grad info--------')
for name, p in net.named_parameters():
print(f'{name}:\t{p.requires_grad}')
def print_net_state_dict(net):
for key, v in net.state_dict().items():
print(f'{key}')
if __name__ == "__main__":
net = Net()
print_parameter_grad_info(net)
net.requires_grad_(False)
print_parameter_grad_info(net)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
运行结果:
-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])
可以看到:
net.requires_grad_(False)
已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data
在不同epoch
中前向之后出现了不同的结果。
调用print_net_state_dict
可以看到BN
层中的参数running_mean
和running_var
并没在可优化参数net.parameters
中
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
但在training pahse
的前向过程中,这两个参数被更新了。导致整个网络在freeze
的情况下,同样的测试数据出现了不同的结果
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default
momentum
of 0.1. source
因此在training phase时对BN层显式设置eval
状态:
if __name__ == "__main__":
net = Net()
net.requires_grad_(False)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
net.bn1.eval()
net.bn2.eval()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
可以看到结果正常了:
epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
交流基地:630390733
pytorch固定BN层参数的更多相关文章
- 【转载】 【caffe转向pytorch】caffe的BN层+scale层=pytorch的BN层
原文地址: https://blog.csdn.net/u011668104/article/details/81532592 ------------------------------------ ...
- 【转载】 Caffe BN+Scale层和Pytorch BN层的对比
原文地址: https://blog.csdn.net/elysion122/article/details/79628587 ------------------------------------ ...
- (原)torch中微调某层参数
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...
- Tensorflow训练和预测中的BN层的坑
以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...
- 【转载】 Pytorch(1) pytorch中的BN层的注意事项
原文地址: https://blog.csdn.net/weixin_40100431/article/details/84349470 ------------------------------- ...
- Batch Normalization的算法本质是在网络每一层的输入前增加一层BN层(也即归一化层),对数据进行归一化处理,然后再进入网络下一层,但是BN并不是简单的对数据进行求归一化,而是引入了两个参数λ和β去进行数据重构
Batch Normalization Batch Normalization是深度学习领域在2015年非常热门的一个算法,许多网络应用该方法进行训练,并且取得了非常好的效果. 众所周知,深度学习是应 ...
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- BN层
论文名字:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 论 ...
- 【卷积神经网络】对BN层的解释
前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...
随机推荐
- web自动化 下拉框、切换到新窗口
一.下拉框 相信大家在手动测试web页面时,遇到过下拉框吧,那进行web自动化测试时,如何操作下拉框,且看下文 1.selenium中提供了方法,先导入Select方法 from selenium.w ...
- C语言模拟实现先来先服务(FCFS)和短作业优先(SJF)调度算法
说明 该并非实现真正的处理机调度,只是通过算法模拟这两种调度算法的过程. 运行过程如下: 输入进程个数 输入各个进程的到达事件 输入各个进程的要求服务事件 选择一种调度算法 程序给出调度结果:各进程的 ...
- sqli-labs-master 闯关前知识点学习
1).前期准备.知识点 开始之前,为了方便查看sql注入语句,我在sqli-labs-master网页源码php部分加了两行代码,第一行意思是输出数据库语句,第二行是换行符 一.Mysql 登录 1. ...
- PADS生成贴片文件
PADS生成贴片文件 VIEW-BOTTOM VIEW能够使Bottom层正常显示. 1. pastmask_top->Output Devices->Device Setup- 2. 进 ...
- Python中可迭代对象是什么?
Python中可迭代对象(Iterable)并不是指某种具体的数据类型,它是指存储了元素的一个容器对象,且容器中的元素可以通过__iter__( )方法或__getitem__( )方法访问. __i ...
- Error.name 六种值对应的信息
1 EvalErroe:eval() 的使用与定义不一致 2 RangrError: 数值越界 3 ReferenceError:非法或不能识别的引用数值 4 SyntaxError:发生语法解析错 ...
- 1、pytorch写的第一个Linear模型(原始版,不调用nn.Modules模块)
参考: https://github.com/Iallen520/lhy_DL_Hw/blob/master/PyTorch_Introduction.ipynb 模拟一个回归模型,y = X * w ...
- 记一次MacPro风扇一直转的问题排查
1.查看CPU占用最高的进程 借助活动监视器,查看CPU占用最高的进程,可以观察到是Chrome浏览器 2.打开Chrome的任务管理器 2.1.查看CPU占用最高的chrome进程 3.分析和结束进 ...
- [SQL Server]多次为 '派生表' 指定了列 'id'
问题: 原因: 因为派生表oo中出现了两个同样的'ID'属性,所以会报[多次为 'o' 指定了列 'ID']的错误. 只需要把第二个星号替换成所需要的列名并把重复字段重命名就好了 解决方案:
- js中的(function(){})()立即执行
( function(){-} )() 和 ( function (){-} () ) 是两种javascript立即执行函数的常见写法,要理解立即执行函数,需要先理解一些函数的基本概念. 函数声明. ...