本文包含代码案例和讲解,建议收藏,也顺便点个赞吧。欢迎各路朋友爱好者加我的微信讨论问题:cyx645016617.

在很多关于医学图像分割的竞赛、论文和项目中,发现 Dice 系数(Dice coefficient) 损失函数出现的频率较多,这里整理一下。使用图像分割,绕不开Dice损失,这个就好比在目标检测中绕不开IoU一样

1 概述

Dice损失和Dice系数(Dice coefficient)是同一个东西,他们的关系是:

\[DiceLoss = 1-DiceCoefficient
\]

1.2 Dice 定义

  • Dice系数, 根据 Lee Raymond Dice命名,是一种集合相似度度量函数,通常用于计算两个样本的相似度(值范围为 [0, 1])。
\[DiceCoefficient = \frac{2|X \bigcap Y|}{|X| + |Y|}
\]

其中\(|X| \bigcap |Y|\)表示X和Y集合的交集,|X|和|Y|表示其元素个数,对于分割任务而言,|X|和|Y|表示分割的ground truth和predict_mask

此外,我们可以得到Dice Loss的公式:

\[DiceLoss = 1- \frac{2|X \bigcap Y|}{|X| + |Y|}
\]

2 手推案例

这个Dice网上有一个非常好二分类的Dice Loss的手推的案例,非常好理解,过程分成两个部分:

  1. 先计算\(|X|\bigcap|Y|\)
  2. 再计算\(|X|\)和\(|Y|\)

    计算loss我们必然已经有了这两个参数,模型给出的output,也就是预测的mask;数据集中的ground truth(GT),也就是真实的mask。

在很多关于医学图像分割的竞赛、论文和项目中,发现 Dice 系数(Dice coefficient) 损失函数出现的频率较多,这里整理一下。使用图像分割,绕不开Dice损失,这个就好比在目标检测中绕不开IoU一样

1 概述

Dice损失和Dice系数(Dice coefficient)是同一个东西,他们的关系是:

\[DiceLoss = 1-DiceCoefficient
\]

1.2 Dice 定义

  • Dice系数, 根据 Lee Raymond Dice命名,是一种集合相似度度量函数,通常用于计算两个样本的相似度(值范围为 [0, 1])。
\[DiceCoefficient = \frac{2|X \bigcap Y|}{|X| + |Y|}
\]

其中\(|X| \bigcap |Y|\)表示X和Y集合的交集,|X|和|Y|表示其元素个数,对于分割任务而言,|X|和|Y|表示分割的ground truth和predict_mask

此外,我们可以得到Dice Loss的公式:

\[DiceLoss = 1- \frac{2|X \bigcap Y|}{|X| + |Y|}
\]

2 手推案例

这个Dice网上有一个非常好二分类的Dice Loss的手推的案例,非常好理解,过程分成两个部分:

  1. 先计算\(|X|\bigcap|Y|\)
  2. 再计算\(|X|\)和\(|Y|\)

    计算loss我们必然已经有了这两个参数,模型给出的output,也就是预测的mask;数据集中的ground truth(GT),也就是真实的mask。

当然还没完,还要把结果加和:

对于二分类问题,GT分割图是只有 0, 1 两个值的,因此可以有效的将在 Pred 分割图中未在 GT 分割图中激活的所有像素清零. 对于激活的像素,主要是惩罚低置信度的预测,较高值会得到更好的 Dice 系数.

关于计算\(|X|\)和\(|Y|\),如下:

其中需要注意的是,一半情况下,这个是直接对所有元素求和,当然有对所有元素先平方再求和的做法。总之就这么多,非常的简单好用。不过上面的内容是针对分割二分类的情况,对于多分类的情况和二分类基本相同

3 二分类代码实现

在实现的时候,往往会加上一个smooth,防止分母为0的情况出现。所以公式变成:

\[DiceLoss = 1- \frac{2|X \bigcap Y|+smooth}{|X| + |Y|+smooth}
\]

一般smooth为1

3.1 PyTorch实现

先是dice coefficient的实现,pred和target的shape为【batch_size,channels,...】,2D和3D的都可以用这个。

  1. def dice_coeff(pred, target):
  2. smooth = 1.
  3. num = pred.size(0)
  4. m1 = pred.view(num, -1) # Flatten
  5. m2 = target.view(num, -1) # Flatten
  6. intersection = (m1 * m2).sum()
  7. return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

当然dice loss就是1-dice ceofficient,所以可以写成:

  1. def dice_coeff(pred, target):
  2. smooth = 1.
  3. num = pred.size(0)
  4. m1 = pred.view(num, -1) # Flatten
  5. m2 = target.view(num, -1) # Flatten
  6. intersection = (m1 * m2).sum()
  7. return 1-(2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

3.2 keras实现

  1. smooth = 1. # 用于防止分母为0.
  2. def dice_coef(y_true, y_pred):
  3. y_true_f = K.flatten(y_true) # 将 y_true 拉伸为一维.
  4. y_pred_f = K.flatten(y_pred)
  5. intersection = K.sum(y_true_f * y_pred_f)
  6. return (2. * intersection + smooth) / (K.sum(y_true_f * y_true_f) + K.sum(y_pred_f * y_pred_f) + smooth)
  7. def dice_coef_loss(y_true, y_pred):
  8. return 1. - dice_coef(y_true, y_pred)

3.3 tensorflow实现

  1. def dice_coe(output, target, loss_type='jaccard', axis=(1, 2, 3), smooth=1e-5):
  2. """
  3. Soft dice (Sørensen or Jaccard) coefficient for comparing the similarity of two batch of data,
  4. usually be used for binary image segmentation
  5. i.e. labels are binary.
  6. The coefficient between 0 to 1, 1 means totally match.
  7. Parameters
  8. -----------
  9. output : Tensor
  10. A distribution with shape: [batch_size, ....], (any dimensions).
  11. target : Tensor
  12. The target distribution, format the same with `output`.
  13. loss_type : str
  14. ``jaccard`` or ``sorensen``, default is ``jaccard``.
  15. axis : tuple of int
  16. All dimensions are reduced, default ``[1,2,3]``.
  17. smooth : float
  18. This small value will be added to the numerator and denominator.
  19. - If both output and target are empty, it makes sure dice is 1.
  20. - If either output or target are empty (all pixels are background), dice = ```smooth/(small_value + smooth)``, then if smooth is very small, dice close to 0 (even the image values lower than the threshold), so in this case, higher smooth can have a higher dice.
  21. Examples
  22. ---------
  23. >>> outputs = tl.act.pixel_wise_softmax(network.outputs)
  24. >>> dice_loss = 1 - tl.cost.dice_coe(outputs, y_)
  25. References
  26. -----------
  27. - `Wiki-Dice <https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient>`__
  28. """
  29. inse = tf.reduce_sum(output * target, axis=axis)
  30. if loss_type == 'jaccard':
  31. l = tf.reduce_sum(output * output, axis=axis)
  32. r = tf.reduce_sum(target * target, axis=axis)
  33. elif loss_type == 'sorensen':
  34. l = tf.reduce_sum(output, axis=axis)
  35. r = tf.reduce_sum(target, axis=axis)
  36. else:
  37. raise Exception("Unknow loss_type")
  38. dice = (2. * inse + smooth) / (l + r + smooth)
  39. dice = tf.reduce_mean(dice)
  40. return dice

4 多分类

假设是一个10分类的任务,那么我们应该会有一个这样的模型预测结果:[batch_size,10,width,height],然后我们的ground truth需要改成one hot的形式,也变成[batch_size,10,width,height]。剩下的和二分类的代码基本相同了,先ground truth和预测结果对应元素相乘,然后对相乘的结果求和。就是最后需要对每一个类别和每一个样本都求一次平均就行了。

5 深入探讨Dice,IoU



上图就是我们常见的IoU方法,假设分子的两个集合,一个集合是Ground Truth,另外一个集合是神经网络给出的预测值。不要被图中的正方形的形状限制了想想,对于分割任务来说,一般是像素级的不规则图案

如果预测正确,也就是分子中的蓝色交汇的部分,称之为True Positive,属于True Positive的像素的数量就是分子的值。分母的值是Ground Truth的所有像素的数量和预测结果中所有像素的数量的和再减去重叠的部分的像素数量。

直接学过recall,precision,混淆矩阵,f1score的朋友一定对FN,TP,TN,FP这些不陌生:

  • 黄色区域:预测为negative,但是GT中是positive的False Negative区域;
  • 红色区域:预测为positive,但是GT中是Negative的False positive区域;

对于IoU的预测好坏的直观理解就是:



简单的说就是,重叠的越多,IoU越接近1,预测效果越好

现在让我们更好的从IoU过渡到Dice,我们先把IoU的算式写出来:

\[IoU = \frac{TP}{TP+FP+FN}
\]

Dice的算式,结合我们之前讲的内容,可以推导出,\(|X|\bigcap|Y|\)就是TP,\(|X|\)假设是GT的话就是FN+TP,\(|Y|\)假设是预测的mask,就是TP+FP,所以:

\[Dice_coefficient = \frac{2\times TP}{TP+FN + TP + FP}
\]

所以我们可以得到Dice和IoU之间的关系了,这里的之后的Dice默认表示Dice Coefficient

\[IoU = \frac{Dice}{2-Dice}
\]

这个函数图像如下图,我们只关注0~1这个区间就好了,可以发现:

  • IoU和Dice同时为0,同时为1;这很好理解,就是全预测正确和全部预测错误
  • 假设在相同的预测情况下,可以发现Dice给出的评价会比IoU高一些,哈哈哈。所以Dice的数据会更加好看一些。

参考文章:

  1. https://www.aiuai.cn/aifarm1159.html
  2. https://blog.csdn.net/py184473894/article/details/90383618

图像分割必备知识点 | Dice损失 理论+代码的更多相关文章

  1. 图像分割必备知识点 | Unet详解 理论+ 代码

    文章转自:微信公众号[机器学习炼丹术].文章转载或者交流联系作者微信:cyx645016617 喜欢的话可以参与文中的讨论.在文章末尾点赞.在看点一下呗. 0 概述 语义分割(Semantic Seg ...

  2. 图像分割必备知识点 | Unet++超详解+注解

    文章来自周纵苇大佬的知乎,是Unet++模型的一作大佬,其在2019年底详细剖析了Unet++模型,讲解的非常好.所以在此做一个搬运+个人的理解. 文中加粗部分为个人做的注解.需要讨论交流的朋友可以加 ...

  3. Hybrid App 应用开发中 9 个必备知识点复习(WebView / 调试 等)

    前言 我们大前端团队内部 ?每周一练 的知识复习计划继续加油,本篇文章是 <Hybrid APP 混合应用专题> 主题的第二期和第三期的合集. 这一期共整理了 10 个问题,和相应的参考答 ...

  4. Web前端-CSS必备知识点

    Web前端-CSS必备知识点 css基本内容,类选择符,id选择符,伪类,伪元素,结构,继承,特殊性,层叠,元素分类,颜色,长度,url,文本,字体,边框,块级元素,浮动元素,内联元素,定位. 链接: ...

  5. 软件测试就业必备知识点&自学软件测试-Dotest-2019

    软件测试就业必备知识点&自学测试&教学大纲-Dotest-2019

  6. ASP.NET MVC开发:Web项目开发必备知识点

    最近加班加点完成一个Web项目,使用Asp.net MVC开发.很久以前接触的Asp.net开发还是Aspx形式,什么Razor引擎,什么MVC还是这次开发才明白,可以算是新手. 对新手而言,那进行A ...

  7. Microsoft Dynamics CRM2011 必备知识点

    一.CRM基本知识 1.CRM2001 有几个服务端点? 答:对外公开的服务,如Web服务,WCF,Restful API 2.一个ERP系统,要访问CRM的数据,CRM2011有哪些现有的服务入口提 ...

  8. 使用html5中video自定义播放器必备知识点总结以及JS全屏API介绍

    一.video的js知识点: controls(控制器).autoplay(自动播放).loop(循环)==video默认的: 自定义播放器中一些JS中提供的方法和属性的记录: 1.play()控制视 ...

  9. MVC中权限的知识点及具体实现代码

    一:知识点部分 权限是做网页经常要涉及到的一个知识点,在使用MVC做权限设计时需要先了解以下知识: MVC中Url的执行是按照Controller->Action->View页面,但是我们 ...

随机推荐

  1. SpringCloud 与 SpringBoot版本问题

    如果SpringBoot版本与SpringCloud版本不一致,SpringBoot应用启动会报错: 解决方案: 版本对应关系可以在 https://start.spring.io/info 上查看: ...

  2. linux配置java

    https://www.cnblogs.com/zeze/p/5902124.html

  3. 助力全球抗疫:3D突发公共卫生事件管理平台

    前言 秋冬降临,北半球气温转凉.欧洲多个国家单日新增病例持续创新高,美国更是成为全球疫情最严重的国家.国内山东青岛.新疆喀什等地也相继发现多例病情.全球第二波疫情已经开始,国内疫情牵动人心,全球抗疫仍 ...

  4. 数据结构(C++)——链栈

    结点结构 typedef char ElemType; typedef struct LkStackNode{ ElemType data; LkStackNode *next; }*Stack,SN ...

  5. C#+Arduino Uno 实现声控系统完全实施手册

    话不多说先上视频,一看就懂 另外可参考这里:https://www.cnblogs.com/dehai/p/4285749.html ,这个近6年前的帖子 程序结构 程序分成上位机(PC端)与下位机( ...

  6. python中可迭代对象、迭代器、生成器

    可迭代对象 关注公众号"轻松学编程"了解更多. 1.列表生成式 list = [result for x in range(m, n)] g1 = (i for i in rang ...

  7. Java复制数组的方法

    java数组拷贝主要有四种方法,分别是循环赋值,System.arraycopy(),Arrays.copyOf()(或者Arrays.copyOfRange)和clone()方法.下面分别介绍一下这 ...

  8. 【杂谈】JS相关的线程模型整理

    1.JS是单线程吗? 是的,到目前为止,JS语言没有多线程的语法,它的执行引擎只支持单线程,也就是一个JavaScript进程内只有一个线程. 2.事件循环什么? 事件循环就是执行线程不断的从队列中取 ...

  9. 重磅解读:K8s Cluster Autoscaler模块及对应华为云插件Deep Dive

    摘要:本文将解密K8s Cluster Autoscaler模块的架构和代码的Deep Dive,及K8s Cluster Autoscaler 华为云插件. 背景信息 基于业务团队(Cloud BU ...

  10. leetcode144add-two-numbers

    题目描述 给定两个代表非负数的链表,数字在链表中是反向存储的(链表头结点处的数字是个位数,第二个结点上的数字是十位数...),求这个两个数的和,结果也用链表表示. 输入:(2 -> 4 -> ...