NiftyNet 项目了解
1. NiftyNet项目概述
NiftyNet项目对tensorflow进行了比较好的封装,实现了一整套的DeepLearning流程。将数据加载、模型加载,网络结构定义等进行了很好的分离,抽象封装成了各自独立的模块。虽然抽象的概念比较多,使得整个项目更为复杂,但是整体结构清晰,支持的模块多。可扩展性还没有进行试验,暂时不是很清楚。 该项目能够实现:
- 图像分割
- 图像分类
- gan
- Autoencoder
- 回归
项目支持医学图像的读取,提供的读取器有:
- nibabel 支持.nii医学文件格式
- simpleitk 支持.dcm和.mhd格式的医疗图像
- opencv 支持.jpg等常见图像读取,读取后通道顺序为BGR
- skimage 支持.jpg等常见图像读取
- pillow 支持.jpg等常见图像读取
在使用中遇到了一些问题,其训练的速度非常慢。最开始单个iter的平均训练时间估计在40秒以上,有的iter时间会有200秒。现在主要在查找性能瓶颈。
一、 项目结构
niftynet.engine.application_driver(ApplicationDriver)定义并驱动着整个Application的生命周期,将配置数据进行解析后,实例化Application并启动流程。
i.
Application
Application 作为核心概念,承担整个train或inference的主要功能。所有Application继承于niftynet.application.base_application(简称为BaseApplication)。BaseApplication使用单例模式。
在Application类中,构建了Tensorflow的图结构和创建Session用于驱动计算。
BaseApplication单例模式的具体实现有一点小问题。
Application所完成的工作具体可以划分成以下4个环节
- 输入数据相关 数据加载,数据增强,数据取样等,抽象在这两个接口中在SegmentationApplication中,sampler支持:uniform, weighted, resized, balanced4种方式
initialise_dataset_loader()
initialise_sampler()
- 网络结构相关 网络结构的定义,参数的管理,自定义操作等,抽象在此接口中
initialise_network()
- 模型共享相关 完成由网络的输入到网络的输出,计算loss、gradient,创建optimizer等,抽象在此接口中
connect_data_and_network()
- 输出解码相关 inference将网络输出解码操作,抽象在此接口中
interpret_output()
ii.
Config
配置文件需要必须包含的模块:
- [SYSTEM]
- [NETWORK]
- 如果action为train,那么config中需要包含[TRAINING]模块
- 如果action为inference,那么config中需要包含[INFERENCE]模块
- 额外的,根据特定的application,会需要包含指定名称的模块。如:
–
[GAN]
–
[SEGMENTATION]
–
[REGRESSION]
–
[AUTOENCODER]
- 除了以上的配置外,其他的数据会处理为input data source specifications【数据声明模块】
l 数据声明模块
Name |
解释 |
例子 |
默认值 |
csv_file |
包含输入图像文件的列表 |
csvfile=filelist.csv |
'' |
pathtosearch |
如果没有配置csv_file,则从此路径下去搜索输入图像 |
pathtosearch=~/ct_data |
NiftyNet home folder |
filename_contains |
搜索输入图像时用于匹配的关键词 |
filename_contains=foo, bar |
'' |
filenamenotcontains |
搜索输入图像时用于排除的关键词 |
filenamenotcontains=ti, |
'' |
filename_removefromid |
正则表达式,用于从输入图像的文件名中,解析出id |
filename_removefromid=foo |
'' |
interp_order |
插值法 |
interp_order=1 |
3 |
pixdim |
如果指定了,输入的3D图像会重新采样到指定大小再送入网络 |
pixdim=1.2, 1.2, 1.2 |
'' |
axcodes |
如果指定了,输入的3D图像会重新设定到指定的axcodes顺序再送入网络 参考文章 |
axcodes=L, P, S |
'' |
spatialwindowsize |
3个整数,指定输入window的大小[能被8整除] |
spatialwindowsize=64, |
'' |
loader |
指定图像读取loader类型 |
loder=simpleitk |
None |
[interp_order] 当设定采样方法为resize时,需要这个参数对图片上采样或下采样 1表示双线性插值
0表示最近邻插值
3表示三次样条插值
l [SYSTEM]
Name |
解释 |
例子 |
默认值 |
cude_devices |
指定GPU |
cuda_devices=0,1 |
'' |
num_threads |
预处理线程的数量 |
num_threads=8 |
2 |
num_gpus |
训练时使用GPU数量 |
num_gpus=2 |
1 |
model_dir |
保存或读取模型权重和Log的位置 |
model_dir=~/niftynet/xxx |
config文件所在目录 |
datasetsplitfile |
用于将数据划分成training/validation/inferenct字集 |
datasetsplitfile=~/nifnet/xxx |
./datasetsplitfile.csv |
event_handler |
注册事件处理 |
eventhandler=modelrestorer |
modelsaver, |
l [NETWORK]
Names |
解释 |
例子 |
默认值 |
name |
所使用的网络结构 |
name=niftynet.network.toynet.ToyNet |
‘’ |
activation_function |
设置网络中使用的激活函数 |
activation_function=prelu |
Relu |
batch_size |
批大小 |
batch_size=10 |
2 |
smaller_final_batch_mode |
当总数据量不能被batch_size整除时,最后一个batch_size的方式 |
smaller_final_batch_mode=drop smaller_final_batch_mode=pad smaller_final_batch_mode=dynamic |
pad |
decay |
正则化参数 |
decay=1e-5 |
0.0 |
reg_type |
正则化类型 |
reg_type=L1 |
L2 |
volume_padding_size |
volume_padding_size=4, |
0, 0, 0 |
|
volume_padding_mode |
volume_padding_mode=symmetric |
minimum |
|
window_sampling |
采样的类型 |
window_sampling=uniform 固定尺寸,相同的概率分布 window_sampling=weighted 固定尺寸,根据intensity作为概率分布 window_sampling=balanced 固定尺寸,每个label拥有相同采样概率 window_sampling=resize 缩放图像到window尺寸 |
uniform |
queue_length |
采样时使用的buffer大小 |
queue_length=10 |
5 |
keep_prob |
如果网络中使用了dropout |
keep_prob=0.2 |
1.0 |
l [TRAINING]
Name |
解释 |
例子 |
默认值 |
optimizer |
优化器类型 |
optimizer=momentum |
adam |
sample_per_volume |
每个输入图像采样的次数 |
sample_per_volume=5 |
1 |
lr |
学习率 |
lr=0.0001 |
0.1 |
loss_type |
loss计算方式 |
loss_type=CrossEntropy |
Dice |
starting_iter |
启动的iter |
starting_iter=0 |
0 |
save_every_n |
保存的间隔 |
save_every_n=50 |
500 |
tensorboard_every_n |
tensorboard记录的间隔 |
tensorboard_every_n=50 |
20 |
max_iter |
最大iter数 |
max_iter=3000 |
10000 |
max_checkpoints |
保存的最多checkpoint数 |
max_checkpoints=5 |
100 |
训练时验证
validation_every_n |
训练时进行验证的间隔 |
validation_every_n=10 |
-1 |
validation_max_iter |
验证时iter的数量 |
validation_max_iter=5 |
1 |
exclude_fraction_for_validation |
验证集的比重 |
exclude_fraction_for_validation=0.2 |
0.0 |
exclude_fraction_for_inference |
测试集的比重 |
exclude_fraction_for_inference=0.1 |
0.0 |
数据增强
rotation_angle |
旋转 |
rotation_angle=-10.0, |
‘’ |
scaling_percentage |
缩放 |
scaling_percentage=-20.0, |
‘’ |
random_flipping_axes |
翻转 |
random_flipping_axes=1,2 |
-1 |
l [INFERENCE]
Name |
解释 |
例子 |
默认值 |
spatial_window_size |
网络输入尺寸大小 |
spatial_window_size=64,64,64 |
‘’ |
border |
输入尺寸的边框 |
border=5,5,5 |
0,0,0 |
inference_iter |
使用指定iter保存的权重文件 |
inference_iter=1000 |
-1 |
save_seg_dir |
保存输出路径 |
save_seg_dir=output/test |
output |
output_postfix |
输出保存的后缀 |
output_postfix=_output |
_niftynet_out |
output_interp_order |
插值法 |
output_interp_order=0 |
0 |
dataset_to_infer |
使用的数据集,可选:’all’, ‘training’, |
dataset_to_infer=all |
‘’ |
iii.
Reader & Dataset
n niftynet.io.image_reader模块
ImageReader的主要作用是,遍历一组目录,搜索并返回一个图像的列表,以及使用iterative的方式将数据加载到内存中。
ImageReader会创建一个tf.data.Dataset的对象,这样使得模块可以很方便地接入到基于tensorflow的程序中。
ImageReader的特点:
l 设计用于支持医疗图像数据的格式
l 支持多模态输入数据
l 支持tf.data.Dataset
n niftynet.contrib.dataset_sampler
sampler将 image
reader作为输入,从每张图像中采取出结果输出。
在很多的医学图像处理的情况中,由于GPU显存的限制以及训练效率等的考虑,网络结构会对图像的部分进行处理而非整张图像。
iv.
Network
项目中包含了一些已经实现的网络:
- GAN:
–
simulator_gan
–
siple_gan
- Segmentation:
–
highres3dnet, highres3dnetsmall, highres3dnetlarge
–
toynet
–
unet
–
vnet
–
dense_vnet
–
deepmedic
–
scalenet
–
holisticnet
–
unet_2d
- classification:
–
resnet
–
se_resnet
- autoencoder:
–
vae
v.
Loss
已提供支持的loss计算方式
- Segmentation
- CrossEntropy
- CrossEntropy_Dense
- Dice
- Dice_NS
- Dice_Dense
- Dice_Dense_NS
- Tversky
- GDSC
- WGDL
- SensSpec
- Gan
- CrossEntropy
- Regression
- L1Loss
- L2Loss
- RMSE
- MAE
- Huber
- Classification
- CrossEntropy
- AutoEncoder
- VariationalLowerBound
支持的优化器类型
- adam
- gradientdescent
- momentum
- nesterov
- adagrad
- rmsprop
vi.
Event机制
NiftyNet项目的设计,使用了Signal和event handler模式,具体实现使用了blinker库。这样可以方便地将模型保存,tensorboard记录等操作进行配置。
目前可供注册的signal有:
- GRAPH_CREATED
- SESS_STARTED
- SESS_FINISHED
- ITER_STARTED
- ITER_FINISHED
信号处理函数注册到对应的信号后,由引擎负责调用。
vii.
Layer
网络层的相关设计都封装在Layer类中,可继承layer类,实现定制化结构
NiftyNet 项目了解的更多相关文章
- NiftyNet项目介绍
NiftyNet项目介绍 简述 NiftyNet是一款开源的卷积神经网络平台,旨在通过实现医学图像分析的深度学习方法和模块,支持快速原型和再现性,由WEISS (Wellcome EPSRC Ce ...
- NiftyNet 数据预处理
NiftyNet项目介绍 使用NiftyNet时,我们需要先将图像数据和标签进行一次简单的处理,得到对应的.csv文件. 对应文件格式为: img.csv image path img_name im ...
- 开源医学图像处理平台NiftyNet介绍
18年下半年10月份左右,老师分配有关NiftyNet平台的相关学习的任务,时隔5个月,决定整理一下以前的笔记,写成相应的博客! 目录 1.NiftyNet平台简介 2.NiftyNet平台架构设计 ...
- Fis3前端工程化之项目实战
Fis3项目 项目目录结构: E:. │ .gitignore │ fis-conf.js │ index.html │ package.json │ README.md │ ├─material │ ...
- 【原】Android热更新开源项目Tinker源码解析系列之三:so热更新
本系列将从以下三个方面对Tinker进行源码解析: Android热更新开源项目Tinker源码解析系列之一:Dex热更新 Android热更新开源项目Tinker源码解析系列之二:资源文件热更新 A ...
- 最近帮客户实施的基于SQL Server AlwaysOn跨机房切换项目
最近帮客户实施的基于SQL Server AlwaysOn跨机房切换项目 最近一个来自重庆的客户找到走起君,客户的业务是做移动互联网支付,是微信支付收单渠道合作伙伴,数据库里存储的是支付流水和交易流水 ...
- Hangfire项目实践分享
Hangfire项目实践分享 目录 Hangfire项目实践分享 目录 什么是Hangfire Hangfire基础 基于队列的任务处理(Fire-and-forget jobs) 延迟任务执行(De ...
- Travis CI用来持续集成你的项目
这里持续集成基于GitHub搭建的博客为项目 工具: zqz@ubuntu:~$ node --version v4.2.6 zqz@ubuntu:~$ git --version git versi ...
- 【原】Android热更新开源项目Tinker源码解析系列之一:Dex热更新
[原]Android热更新开源项目Tinker源码解析系列之一:Dex热更新 Tinker是微信的第一个开源项目,主要用于安卓应用bug的热修复和功能的迭代. Tinker github地址:http ...
随机推荐
- CDHtmlDialog 基本使用
跳转 Navigate("res://tt.exe/#138"); 138是html的资源号 输入框的Get,set HRESULT CTTDlg::OnButtonCancel( ...
- [Android Memory] Linux下malloc函数和OOM Killer
http://www.linuxidc.com/Linux/2010-09/28364.htm Linux下malloc函数主要用来在用户空间从heap申请内存,申请成功返回指向所分配内存的指针,申请 ...
- 怎么设置IDEA,去除单词拼写检查,或者添加自定义的单词
如图所示,添加自定义的单词,这样IDEA检查的时候,就不会报错了.估计默认是根据英文单词来释义的.
- Ajax的简单总结
1. Ajax的优势和不足 1.1 Ajax的优势 1. 不需要插件支持 Ajax不需要任何浏览器插件,就可以被绝大多数主流浏览器所支持,用户只需要允许JavaScript在浏览器上执行即可. 2. ...
- 【转】php中的会话机制(2)
原文:https://segmentfault.com/a/1190000000468220 发现,在调用session_start()的时候, session_start() 里面应该是有调用类似 ...
- (转) java中try/catch性能和原理
stackoverflow上有一个讨论,参与的人还挺多: https://stackoverflow.com/questions/141560/should-try-catch-Go-inside-o ...
- JQuery的get、post和ajax方法的使用
在JQuery中可以使用get,post和ajax方法给服务器端传递数据 get方法的使用(customForGet.js文件): function verify(){ //1.获取文本框的数据 // ...
- 二叉树(9)----打印二叉树中第K层的第M个节点,非递归算法
1.二叉树定义: typedef struct BTreeNodeElement_t_ { void *data; } BTreeNodeElement_t; typedef struct BTree ...
- 微信公众平台开发小记(ASP.NET)
微信的好东西,提供了很大的平台去发挥,公司最近推出微信公众账号,也接触了一些东西, 最终决定用asp.net来开发服务端程序. 微信公众平台的API很简单,利用XML来规范格式,并且所有的数据都在CD ...
- 浅析php中抽象类和接口的概念以及区别[转]
//抽象类的定义: abstract class ku{ //定义一个抽象类 abstract function kx(); ...... } function aa extends ku{ //实现 ...