pytorch faster_rcnn
代码地址:https://github.com/jwyang/faster-rcnn.pytorch
1.fasterRCNN.train():这个不是让网络进行训练,而是让module in training mode,有些module在traing model和testing model下不同,比如bn
即self.training这个成员变量为true(这个成员变量属于nn.Module,fasterRCNN继承了这个成员变量),以下是train成员函数的源码
2.bn的train和test不同,train的时候应该是要学习参数的,test的时候关闭,pytorch的用法如下:
pytorch的batchnorm使用时需要小心,training和track_running_stats可以组合出三种behavior,很容易掉坑里(我刚发现我对track_running_stats的理解错了)。
- training=True, track_running_stats=True, 这是常用的training时期待的行为,running_mean 和running_var会跟踪不同batch数据的mean和variance。
- training=True, track_running_stats=False, 这时候batchnorm不跟踪跨batch数据的statistics了,而是用每个batch的mean和variance做normalization。
- training=False, track_running_stats=True, 这是我们期待的test时候的行为,即使用training阶段估计的running_mean 和running_var.
- training=False, track_running_stats=False,同2(!!!).
https://www.zhihu.com/question/282672547/answer/529154567李韶华的回答
if self.class_agnostic:
self.RCNN_bbox_pred = nn.Linear(4096, 4)
else:
self.RCNN_bbox_pred = nn.Linear(4096, 4 * self.n_classes)
4.真正开始训练的代码不是fasterRCNN.train(),而是下面这段代码:
rois, cls_prob, bbox_pred, \
rpn_loss_cls, rpn_loss_box, \
RCNN_loss_cls, RCNN_loss_bbox, \
rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
fasterRCNN是一个实例,应该是没办法进行调用的,但实际上这段代码执行的是forward函数。为什么?其实就是python的括号重载。fasterRCNN这个实例继承于nn.Module类,这个类定义了forward成员函数,nn.Module类使用了__call__进行了重载,让实例能够调用,并且调用的函数是forward函数,具体代码见下面的源码:
python中__call__函数的作用是使实例能够像函数一样被调用https://blog.csdn.net/Yaokai_AssultMaster/article/details/70256621,也称之为括号重载,即‘()’
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
raise RuntimeError(
"forward hooks should never return any values, but '{}'"
"didn't return None".format(hook))
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
nn.Module定义了一个forward的成员函数,这个函数在基类中没有实现,而是在各个子类自己实现的,每个子类都必须实现forward函数:
def forward(self, *input):
r"""Defines the computation performed at every call.
Should be overridden by all subclasses.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
"""
raise NotImplementedError
子类调用forward函数不能直接用calss.forward(),而是用实例的函数调用,具体的原因好像是hook,这个在上面__call__函数中也看到调用forward使用了跟hook有关的input
pytorch faster_rcnn的更多相关文章
- Faster_RCNN 4.训练模型
总结自论文:Faster_RCNN,与Pytorch代码: 本文主要介绍代码最后部分:trainer.py .train.py , 首先分析一些主要理论操作,然后在代码分析里详细介绍其具体实现.首先 ...
- Faster_RCNN 3.模型准备(下)
总结自论文:Faster_RCNN,与Pytorch代码: 本文主要介绍代码第二部分:model/ , 首先分析一些主要理论操作,然后在代码分析里详细介绍其具体实现. 首先在参考文章的基础上进一步详细 ...
- Faster_RCNN 2.模型准备(上)
总结自论文:Faster_RCNN,与Pytorch代码: 本文主要介绍代码第二部分:model/utils , 首先分析一些主要理论操作,然后在代码分析里详细介绍其具体实现. 一. 主要操作 1. ...
- Faster_RCNN 1.准备工作
总结自论文:Faster_RCNN,与Pytorch代码: 代码结构: simple-faster-rcnn-pytorch.py data __init__.py dataset.py util. ...
- 目标检测之Faster-RCNN的pytorch代码详解(模型训练篇)
本文所用代码gayhub的地址:https://github.com/chenyuntc/simple-faster-rcnn-pytorch (非本人所写,博文只是解释代码) 好长时间没有发博客了 ...
- PyTorch专栏(八):微调基于torchvision 0.3的目标检测模型
专栏目录: 第一章:PyTorch之简介与下载 PyTorch简介 PyTorch环境搭建 第二章:PyTorch之60分钟入门 PyTorch入门 PyTorch自动微分 PyTorch神经网络 P ...
- faster_rcnn c++版本的 caffe 封装,动态库(2)
摘要: 转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/ github上的代码链接,求给星星:) https:// ...
- Ubutnu16.04安装pytorch
1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...
- 解决运行pytorch程序多线程问题
当我使用pycharm运行 (https://github.com/Joyce94/cnn-text-classification-pytorch ) pytorch程序的时候,在Linux服务器 ...
随机推荐
- mybatis-generator-maven逆向工程
在idea 中使用 mybatis的 mybatis-generator-maven-plugin 可以根据数据库 生成 dao层,pojo类,Mapper文件. 一: 在 pom.xml ...
- linux系统编程:自己动手写一个pwd命令
pwd命令:打印当前的工作目录 我们都知道每个目录下面都有两个特殊的目录( . 和 .. ), .: 当前目录, ..: 上层目录, 每个目录都有一个i节点与之相关联 ghostwu@ubuntu: ...
- 01-Javascript简介(了解)
[转]01-Javascript简介(了解) Web前端有三层: HTML:从语义的角度,描述页面结构 CSS:从审美的角度,描述样式(美化页面) JavaScript:从交互的角度,描述行为(提升用 ...
- Python并发编程(守护进程,进程锁,进程队列)
进程的其他方法 P = Process(target=f,) P.Pid 查看进程号 查看进程的名字p.name P.is_alive() 返回一个true或者False P.terminate( ...
- Java 基础知识总结1
作者QQ:1095737364 QQ群:123300273 欢迎加入! 1.数据类型: 数据类型:1>.基本数据类型:1).数值型: 1}.整型类型(byte 8位 (by ...
- js-权威指南学习笔记20
第二十章 客户端存储 1.客户端存储有一下几种形式:Web存储.cookie.IE userData.离线Web应用.Web数据库.文件系统API. 2.Web存储标准所描述的API包含localSt ...
- 分布式配置中心 携程(apollo)
1.传统配置文件与分布式配置文件区别 传统配置文件:如果修改了配置文件,需要重新打包发布,重新发布服务,而且每个环境的变更配置文件,比较繁琐. 分布式配置文件:将配置文件注册到配置中心上去,可以使用分 ...
- SD从零开始51-54 信用控制范围, 信用范围数据维护, 自动信用控制, 信用控制-阻止后续功能
[原创] SD从零开始51 信用控制范围 分散的组织结构Decentralized Organization 信用控制范围是一个为客户指定和控制信用限额的组织单元: 依赖于你公司的需求,应收款可以使用 ...
- SQLServer 学习笔记之超详细基础SQL语句 Part 11
Sqlserver 学习笔记 by:授客 QQ:1033553122 -----------------------接Part 10------------------- DECLARE @myavg ...
- AOP编程 - 淘宝京东网络处理
现象描述 当我们打开京东 app 进入首页,如果当前是没有网络的状态,里面的按钮点击是没有反应的.只有当我们打开网络的情况下,点击按钮才能跳转页面,按照我们一般人写代码的逻辑应该是这个样子: /** ...