我是如何使计算时间提速25.6倍的

我的原始文档:https://www.yuque.com/lart/blog/aemqfz

在显著性目标检测任务中有个重要的评价指标, E-measure, 需要使用在闭区间 [0, 255] 内连续变化的阈值对模型预测的灰度图二值化. 直接的书写方式就是使用 for 循环, 将对应的阈值送入指标得分计算函数中, 让其计算分割后的预测结果和真值mask之间的统计相似度.

在显著性目标检测中, 另一个指标, F-measure, 同样涉及到连续变化的阈值二值化处理, 但是该指标计算仅需要precision和recall, 这两项实际上仅需要正阳性(TP)和假阳性(FP)元素数量, 以及总的正(T)样本元素数量. T可以使用 np.count_nonzero(gt) 来计算, 而前两项则可以直接利用累计直方图的策略一次性得到所有的256个TP、FP数量对, 分别对应不同的阈值. 这样就可以非常方便且快速的计算出来这一系列的指标结果. 这实际上是对于F-measure计算的一种非常有效的加速策略.

但是不同的是, E-measure的计算方式(需要减去对应二值图的均值后进行计算)导致按照上面的这种针对变化阈值加速计算的策略并不容易变通, 至少我目前没有这样使用. 但是最后我找到了一种更加(相较于原始的 for 策略)高效的计算方式, 这里简单做一下思考和实验重现的记录.

选择使用更合适的函数

虽然运算主要基于 numpy 的各种函数, 但是针对同一个目的不同的函数实现方式也是有明显的速度差异的, 这里简单汇总下:

统计非零元素数量首选 np.count_nonzero(array)

我想到的针对二值图的几种不同的实现:

import time
import numpy as np # 快速统计numpy数组的非零值建议使用np.count_nonzero,一个简单的小实验
def cal_nonzero(size):
a = np.random.randn(size, size)
a = a > 0
start = time.time()
print(np.count_nonzero(a), time.time() - start)
start = time.time()
print(np.sum(a), time.time() - start)
start = time.time()
print(len(np.nonzero(a)[0]), time.time() - start)
start = time.time()
print(len(np.where(a)), time.time() - start) if __name__ == '__main__':
cal_nonzero(1000)
# 499950 6.723403930664062e-05
# 499950 0.0006949901580810547
# 499950 0.007088184356689453

可以看到, 最合适的是 np.count_nonzero(array) 了.

更快的交集计算方式

import time
import numpy as np # 快速统计numpy数组的非零值建议使用np.count_nonzero,一个简单的小实验
def cal_andnot(size):
a = np.random.randn(size, size)
b = np.random.randn(size, size)
a = a > 0
b = b < 0
start = time.time()
a_and_b_mul = a * b
_a_and__b_mul = (1 - a) * (1 - b)
print(time.time() - start)
start = time.time()
a_and_b_and = a & b
_a_and__b_and = ~a & ~b
print(time.time() - start) if __name__ == '__main__':
cal_andnot(1000)
# 0.0036919116973876953
# 0.0005502700805664062

可见, 对于bool数组, numpy的位运算是要更快更有效的. 而且bool数组可以直接用来索引矩阵即 array[bool_array] , 非常方便.

逻辑的改进

经过尽可能的挑选更加快速的计算函数之后, 目前速度受限的最大问题就是这个 for 循环中的256次矩阵运算了. 也就是这部分代码:

    ...
def step(self, pred: np.ndarray, gt: np.ndarray):
pred, gt = _prepare_data(pred=pred, gt=gt)
self.all_fg = np.all(gt)
self.all_bg = np.all(~gt)
self.gt_size = gt.shape[0] * gt.shape[1] if self.changeable_ems is not None:
changeable_ems = self.cal_changeable_em(pred, gt)
self.changeable_ems.append(changeable_ems)
adaptive_em = self.cal_adaptive_em(pred, gt)
self.adaptive_ems.append(adaptive_em) def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float:
adaptive_threshold = _get_adaptive_threshold(pred, max_value=1)
adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold)
return adaptive_em def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> list:
changeable_ems = [self.cal_em_with_threshold(pred, gt, threshold=th) for th in np.linspace(0, 1, 256)]
return changeable_ems def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
binarized_pred = pred >= threshold
if self.all_bg:
enhanced_matrix = 1 - binarized_pred
elif self.all_fg:
enhanced_matrix = binarized_pred
else:
enhanced_matrix = self.cal_enhanced_matrix(binarized_pred, gt)
em = enhanced_matrix.sum() / (gt.shape[0] * gt.shape[1] - 1 + _EPS)
return em def cal_enhanced_matrix(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
demeaned_pred = pred - pred.mean()
demeaned_gt = gt - gt.mean()
align_matrix = 2 * (demeaned_gt * demeaned_pred) / (demeaned_gt ** 2 + demeaned_pred ** 2 + _EPS)
enhanced_matrix = (align_matrix + 1) ** 2 / 4
return enhanced_matrix
...

可以看到, 这里对于每一个阈值都要计算一遍同样的流程, 如果每次的计算都比较耗时的话, 那么总体时间也就很难减下来. 所以需要探究如何降低这里的 cal_enhanced_matrix 的耗时.

前面的尝试都是在代码函数选择层面的改进, 但是对于这里, 这样的思路已经很难产生明显的效果了. 那么我们就应该转变思路了, 应该从计算流程本身上思考. 可以按照下面这一系列思考来引出最终的一种比较好的策略.

  • 这里计算为什么会那么慢?

    • 因为涉及到了大量的矩阵元素级的运算, 例如元素级减法、加法、乘法、平方、除法.
  • 大量的元素级运算是否可以优化?
    • 必须可以:<
  • 如何优化元素级运算?
    • 寻找规律性、重复性的计算, 将其合并、消减, 可以联想numpy的稀疏矩阵的思想.
  • 规律性、重复性的计算在哪里?
    • 去均值实际上是对每个元素减去了相同的一个值, 如果被减数可以优化, 那么这一步就可以被优化
    • 元素乘法和平方涉及到两部分, demeaned_gtdemeaned_pred, 如果这两个可以被优化, 那么这些运算就都可以被优化
    • 这些元素运算的连锁关系导致了只要我们优化了最初的predgt, 那么整个流程就都可以被优化
  • 如何优化predgt的表示?
    • 这里需要从二者本身的属性上入手
  • 二者最大的特点是什么?
    • 都是二值数组, 只有0和1
  • 那如何优化?
    • 实际上就借鉴了稀疏矩阵的思想, 既然存在大量的重复性, 那么我们就将数值与位置解耦, 优化表示方式
  • 如何解耦?
    • gt为例, 可以表示为0和1两种数据, 其中0对应背景, 1对应前景, 0的数量表示背景面积, 1的数量表示前景面积
  • 那如何使用该思想重构前面的计算呢?

到最后一个问题, 实际上核心策略已经出现, 就是"解耦", 将数值与位置解耦. 这里需要具体分析下, 我们直接将 predgt 拆分成数值和数量, 是可以比较好的处理 demeaned_* 项的表示的, 也就是:

# demeaned_pred = pred - pred.mean()
# demeaned_gt = gt - gt.mean()
pred_fg_numel = np.count_nonzero(binarized_pred)
pred_bg_numel = self.gt_size - pred_fg_numel
gt_fg_numel = np.count_nonzero(gt)
gt_bg_numel = self.gt_size - gt_fg_numel mean_pred_value = pred_fg_numel / self.gt_size
mean_gt_value = gt_fg_numel / self.gt_size demeaned_pred_fg_value = 1 - mean_pred_value
demeaned_pred_bg_value = 0 - mean_pred_value
demeaned_gt_fg_value = 1 - mean_gt_value
demeaned_gt_bg_value = 0 - mean_gt_value

接下来需要进一步优化后面的乘法和加法了, 因为这里同时涉及到了同一位置的 predgt 的值, 这就需要注意了, 因为二者前景与背景对应关系并不明确, 这就得分情况考虑了. 总体而言, 包含四种情况, 就是:

  1. pred: fg; gt: fg
  2. pred: fg; gt: bg
  3. pred: bg; gt: fg
  4. pred: bg; gt: bg

而这些区域实际上是对前面初步解耦区域的进一步细化, 所以我们重新整理思路, 可以将整个流程构造如下:

fg_fg_numel = np.count_nonzero(binarized_pred & gt)
fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) # bg_fg_numel = np.count_nonzero(~binarized_pred & gt)
bg_fg_numel = self.gt_fg_numel - fg_fg_numel
# bg_bg_numel = np.count_nonzero(~binarized_pred & ~gt)
bg_bg_numel = self.gt_size - (fg_fg_numel + fg_bg_numel + bg_fg_numel) parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel] mean_pred_value = (fg_fg_numel + fg_bg_numel) / self.gt_size
mean_gt_value = self.gt_fg_numel / self.gt_size demeaned_pred_fg_value = 1 - mean_pred_value
demeaned_pred_bg_value = 0 - mean_pred_value
demeaned_gt_fg_value = 1 - mean_gt_value
demeaned_gt_bg_value = 0 - mean_gt_value combinations = [(demeaned_pred_fg_value, demeaned_gt_fg_value), (demeaned_pred_fg_value, demeaned_gt_bg_value),
(demeaned_pred_bg_value, demeaned_gt_fg_value), (demeaned_pred_bg_value, demeaned_gt_bg_value)]

这里忽略掉了一些不必要的计算, 能直接使用现有量就使用现有的量.

针对前面的这些解耦, 后面就可以比较简单的书写了:

results_parts = []
for part_numel, combination in zip(parts_numel, combinations):
# align_matrix = 2 * (demeaned_gt * demeaned_pred) / (demeaned_gt ** 2 + demeaned_pred ** 2 + _EPS)
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
# enhanced_matrix = (align_matrix + 1) ** 2 / 4
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts.append(enhanced_matrix_value * part_numel) # enhanced_matrix = enhanced_matrix.sum()
enhanced_matrix = sum(results_parts)

由于不同区域元素结果一致, 而区域的面积也已知, 所以最终 cal_em_with_threshold 中的 enhanced_matrix.sum() 其实更适合放到 cal_enhanced_matrix 中, 可以一便计算出来.

为了尽可能重用现有变量, 我们其实反过来可以优化 cal_em_with_threshold :

binarized_pred = pred >= threshold
if self.all_bg:
enhanced_matrix = 1 - binarized_pred
elif self.all_fg:
enhanced_matrix = binarized_pred
else:
enhanced_matrix = self.cal_enhanced_matrix(binarized_pred, gt)
em = enhanced_matrix.sum() / (gt.shape[0] * gt.shape[1] - 1 + _EPS)

这里的 self.all_bgself.all_fg 实际上可以使用 self.gt_fg_numelself.gt_size 表示, 也就是只需计算一次 np.count_nonzero(array) 就可以了. 另外在 cal_em_with_thresholdif 的前两个分支中, 需要将 sum 整合到各个分支内部(else分支已经被整合到了 cal_enhanced_matrix 方法中), (1-binarized_pred).sum()binarized_pred.sum() 实际上就是表示背景像素数量和前景像素数量. 所以可以借助于更快的 np.count_nonzero(array) , 从而改成如下形式:

binarized_pred = pred >= threshold

if self.gt_fg_numel == 0:
binarized_pred_bg_numel = np.count_nonzero(~binarized_pred)
enhanced_matrix_sum = binarized_pred_bg_numel
elif self.gt_fg_numel == self.gt_size:
binarized_pred_fg_numel = np.count_nonzero(binarized_pred)
enhanced_matrix_sum = binarized_pred_fg_numel
else:
enhanced_matrix_sum = self.cal_enhanced_matrix(binarized_pred, gt)
em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)

效率对比

使用本地的845张灰度预测图和二值mask真值数据进行测试比较, 总体时间对比如下:

  • 'base': 503.5014679431915s
  • 'best': 19.27734637260437s

虽然具体时间可能还受硬件限制, 但是相对快慢还是比较明显的. 变为原来的19/504~=4%, 快了504/19~=26.5倍.

测试代码可见我的 github : https://github.com/lartpang/CodeForArticle/tree/main/sod_metrics

【Python】我是如何使计算时间提速25.6倍的的更多相关文章

  1. 我是如何使计算提速>150倍的

    我是如何使计算提速>150倍的 我的原始文档:https://www.yuque.com/lart/blog/lwgt38 书接上文<我是如何使计算时间提速25.6倍>. 上篇文章提 ...

  2. Python基础与科学计算常用方法

    Python基础与科学计算常用方法 本文使用的是Jupyter Notebook,Python3.你可以将代码直接复制到Jupyter Notebook中运行,以便更好的学习. 导入所需要的头文件 i ...

  3. python常用标准库(时间模块 time和datetime)

    常用的标准库 time时间模块 import time time -- 获取本地时间戳 时间戳又被称之为是Unix时间戳,原本是在Unix系统中的计时工具. 它的含义是从1970年1月1日(UTC/G ...

  4. flex安装时停在计算时间界面的解决办法

    现象:安装FLEX BUILDER4.6时停在计算时间界面,过了一会后弹出安装失败的对话框. 环境:WIN7 解决: 1.下载AdobeCreativeCloudCleanerTool, 地址:htt ...

  5. Python实现天数倒计时计算

    tips:在datetime模块里有一个计算时间差的 timedelta.让两个datetime对象相减就得到timedelta ###--Python实现天数倒计时计算 #tips:在datetim ...

  6. java为啥计算时间从1970年1月1日开始

    http://www.myexception.cn/program/1494616.html ————————————————————————————————————————————————————— ...

  7. python使用datetime模块计算各种时间间隔的方法

    python使用datetime模块计算各种时间间隔的方法 本文实例讲述了python使用datetime模块计算各种时间间隔的方法.分享给大家供大家参考.具体分析如下: python中通过datet ...

  8. python 文本相似度计算

    参考:python文本相似度计算 原始语料格式:一个文件,一篇文章. #!/usr/bin/env python # -*- coding: UTF-8 -*- import jieba from g ...

  9. Python实现进度条和时间预估的示例代码

    一.前言 很多人学习python,不知道从何学起.很多人学习python,掌握了基本语法过后,不知道在哪里寻找案例上手.很多已经做案例的人,却不知道如何去学习更加高深的知识.那么针对这三类人,我给大家 ...

随机推荐

  1. CF618F Double Knapsack

    题意简化 给定两个大小为 n 的集合A,B,要求在每个集合中选出一个子集,使得两个选出来的子集元素和相等 元素范围在 1~n ,n<=1e5 题目连接 题解 考虑前缀和 令A集合的前缀和为SA, ...

  2. ASP.NET Core Authentication系列(四)基于Cookie实现多应用间单点登录(SSO)

    前言 本系列前三篇文章分别从ASP.NET Core认证的三个重要概念,到如何实现最简单的登录.注销和认证,再到如何配置Cookie 选项,来介绍如何使用ASP.NET Core认证.感兴趣的可以了解 ...

  3. 如何在Windows Server 2012及更高版本中将域控制器降级

    如何在Windows Server 2012及更高版本中将域控制器降级 如果不降级就重装系统,会出问题,所以在将域控系统重装系统之前一定要先降级. 使用服务器管理器将 Windows Server 2 ...

  4. 寻找性能更优秀的动态 Getter 和 Setter 方案

    反射获取 PropertyInfo 可以对对象的属性值进行读取或者写入,但是这样性能不好.所以,我们需要更快的方案. 方案说明 就是用表达式编译一个 Action<TObj,TValue> ...

  5. 一次打包引发的思考,原来maven还能这么玩?

    持续原创输出,点击上方蓝字关注我 目录 前言 依赖关系 你会怎么做? 必知的几个参数 总结 前言 昨天有一个读者找我的交流工作心得,偶然间提到一个有趣的问题,如下: 「大致的意思」:公司最近在整多模块 ...

  6. Java程序员成长之路

    北哥在前文总结了程序员的核心能力,但在专业能力维度,只是做了大概的阐述,并没有详细展开.从今天开始,我会把我作为程序员成长过程中,学习的知识总结成系列文章陆续发出来,供大家学习参考. 本文是第一篇,关 ...

  7. 9 HTTP和HTTPS

    9 HTTP和HTTPS 状态码 定义 1xx 报告 接收到请求,继续进程 2xx 成功 步骤成功接收,被理解,并被接受 3xx 重定向 为了完成请求,必须采取进一步措施 4xx 客户端出错 请求包括 ...

  8. spring 中aop 切面实战

    切面相关注解: @Aspect : 声明该类为一个注解类 @Pointcut : 定义一个切点 @Before : 在切点之前执行 @After : 在切点之后执行 不管目标方法是否执行成功 @Aft ...

  9. PF_PACKET&&tcpdump

    linux下抓包原理 linux下的抓包是通过注册一种虚拟的底层网络协议来完成对网络设备消息的处理权.当网卡接收到一个网络报文之后,它会遍历系统中所有已经注册的网络协议,当抓包模块把自己伪装成一个网络 ...

  10. 329. Longest Increasing Path in a Matrix(核心在于缓存遍历过程中的中间结果)

    Given an integer matrix, find the length of the longest increasing path. From each cell, you can eit ...