我是如何使计算提速>150倍的
我是如何使计算提速>150倍的
书接上文《我是如何使计算时间提速25.6倍》.
上篇文章提到, F-measure使用累计直方图可以进一步加速计算, 但是E-measure却没有改出来. 在写完上篇文章的那个晚上, 重新整理思路后, 我似乎想到了如何去使用累计直方图来再次提速.
速度的制约
虽然使用"解耦"的思路可以高效优化每一个阈值下指标的计算过程, 但是整体的 for
循环确实仍然会占用较大的时间. 又考虑到各个阈值下的计算实际上并无太大关联, 如果可以实现同时计算, 那必然可以进一步提升速度. 这里我们又要把目光放回到在计算F-measure时大放光彩的累计直方图的策略上.
在前面的解耦之后, 实际上获得的关键变量是 fg_fg_numel
和 fg_bg_numel
.
fg_fg_numel = np.count_nonzero(binarized_pred & gt)
fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
从这两个变量本身入手, 如果使用累计直方图的话, 实际上可以同时获得 >=不同阈值
下的前景像素(值为1)的数量, 计算的本质和 np.count_nonzero
是一样的东西. 所以我们可以进行直观的替换:
"""
函数内部变量命名规则:
pred属性(前景fg、背景bg)_gt属性(前景fg、背景bg)_变量含义
如果仅考虑pred或者gt,则另一个对应的属性位置使用`_`替换
"""
fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)
这样我们就获得了不同阈值下的对应的一系列 fg_fg_numel
和 fg_bg_numel
了. 这里需要注意的是, 使用的划分区间 bins
的设置. 由于默认的 histogram
划分的区间会包含最后一个端点, 所以比较合理的划分是 bins = np.linspace(0, 256, 257)
, 这样最后一个区间是 [255, 256]
, 就可以包含到最大的值, 又不会和 254
重复计数.
为了便于计算, 这里将后面会用到的 pred
前景统计 fg___numel_w_thrs
和背景统计 bg____numel_w_thrs
直接写出来, 便于使用:
fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs
后面的步骤和之前的基本一致, numpy的广播机制使得不需要改动太多. 由于这部分代码实际上再多处位置会被使用, 所以提取成一个单独的方法.
def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel):
bg_fg_numel = self.gt_fg_numel - fg_fg_numel
bg_bg_numel = pred_bg_numel - bg_fg_numel
parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel]
mean_pred_value = pred_fg_numel / self.gt_size
mean_gt_value = self.gt_fg_numel / self.gt_size
demeaned_pred_fg_value = 1 - mean_pred_value
demeaned_pred_bg_value = 0 - mean_pred_value
demeaned_gt_fg_value = 1 - mean_gt_value
demeaned_gt_bg_value = 0 - mean_gt_value
combinations = [
(demeaned_pred_fg_value, demeaned_gt_fg_value),
(demeaned_pred_fg_value, demeaned_gt_bg_value),
(demeaned_pred_bg_value, demeaned_gt_fg_value),
(demeaned_pred_bg_value, demeaned_gt_bg_value)
]
return parts_numel, combinations
后面计算 enhanced_matrix_sum
的部分也就顺理成章比较自然的可以写出来:
parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
)
# 这里虽然可以使用列表来收集各个results_part,但是列表之后还需要再转为numpy数组来求和,倒不如直接一次性申请好空间后面直接装入即可
results_parts = np.empty(shape=(4, 256), dtype=np.float64)
for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts[i] = enhanced_matrix_value * part_numel
enhanced_matrix_sum = results_parts.sum(axis=0)
整体梳理
主要逻辑已经搞定, 接下来就是将这些代码与原始的代码融合起来, 也就是整合原始代码的 cal_em_with_threshold
和 cal_enhanced_matrix
两个方法.
def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
binarized_pred = pred >= threshold
if self.gt_fg_numel == 0:
binarized_pred_bg_numel = np.count_nonzero(~binarized_pred)
enhanced_matrix_sum = binarized_pred_bg_numel
elif self.gt_fg_numel == self.gt_size:
binarized_pred_fg_numel = np.count_nonzero(binarized_pred)
enhanced_matrix_sum = binarized_pred_fg_numel
else:
enhanced_matrix_sum = self.cal_enhanced_matrix(binarized_pred, gt)
em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
return em
结合前面代码中计算出的各个阈值下的前背景元素的统计值, 上面这里的代码实际上可以通过使用现有运算结果进行化简, 即 if
的前两个分支. 另外阈值划分也不需要显式处理, 因为已经在累计直方图中搞定了. 所以这里的代码对于动态阈值计算的情况下, 是可以被合并到 cal_enhanced_matrix
的计算过程中的. 直接得到最终的整合后的方法:
def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
"""
函数内部变量命名规则:
pred属性(前景fg、背景bg)_gt属性(前景fg、背景bg)_变量含义
如果仅考虑pred或者gt,则另一个对应的属性位置使用`_`替换
"""
pred = (pred * 255).astype(np.uint8)
bins = np.linspace(0, 256, 257)
fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)
fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs
if self.gt_fg_numel == 0:
enhanced_matrix_sum = bg___numel_w_thrs
elif self.gt_fg_numel == self.gt_size:
enhanced_matrix_sum = fg___numel_w_thrs
else:
parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
)
results_parts = np.empty(shape=(4, 256), dtype=np.float64)
for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts[i] = enhanced_matrix_value * part_numel
enhanced_matrix_sum = results_parts.sum(axis=0)
em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
return em
还是为了重用, cal_em_with_threshold
(该方法需要保留, 因为还有另一种E-measure的计算情况需要用到该方法)可以被重构:
def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
"""
函数内部变量命名规则:
pred属性(前景fg、背景bg)_gt属性(前景fg、背景bg)_变量含义
如果仅考虑pred或者gt,则另一个对应的属性位置使用`_`替换
"""
binarized_pred = pred >= threshold
fg_fg_numel = np.count_nonzero(binarized_pred & gt)
fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
fg___numel = fg_fg_numel + fg_bg_numel
bg___numel = self.gt_size - fg___numel
if self.gt_fg_numel == 0:
enhanced_matrix_sum = bg___numel
elif self.gt_fg_numel == self.gt_size:
enhanced_matrix_sum = fg___numel
else:
parts_numel, combinations = self.generate_parts_numel_combinations(
fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel,
pred_fg_numel=fg___numel, pred_bg_numel=bg___numel,
)
results_parts = []
for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)):
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts.append(enhanced_matrix_value * part_numel)
enhanced_matrix_sum = sum(results_parts)
em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
return em
效率对比
使用本地的845张灰度预测图和二值mask真值数据进行测试比较, 重新跑了一遍, 总体时间对比如下:
方法 | 总体耗时(s) | 速度提升(倍) |
---|---|---|
'base' | 539.2173762321472s | x1 |
'best' | 19.94518733024597s | x27.0 (539.22/19.95) |
'cumsumhistogram' | 3.2935903072357178s | x163.8 (539.22/3.29) |
还是那句话, 虽然具体时间可能还受硬件限制, 但是相对快慢还是比较明显的.
测试代码可见我的 github
: https://github.com/lartpang/CodeForArticle/tree/main/sod_metrics
我是如何使计算提速>150倍的的更多相关文章
- 【Python】我是如何使计算时间提速25.6倍的
我是如何使计算时间提速25.6倍的 我的原始文档:https://www.yuque.com/lart/blog/aemqfz 在显著性目标检测任务中有个重要的评价指标, E-measure, 需要使 ...
- 转载:四两拨千斤:借助Spark GraphX将QQ千亿关系链计算提速20倍
四两拨千斤:借助Spark GraphX将QQ千亿关系链计算提速20倍 时间 2016-07-22 16:57:00 炼数成金 相似文章 (5) 原文 http://www.dataguru.cn/ ...
- 如何让 Xcode 在读写上提速100倍?
如何让 Xcode 在读写上提速100倍? 上个月参加了一场西雅图当地的线下 iOS 开发者聚会.Jeff Szuhay 作为一个有20+年开发经验的资深程序员,跟我讲了一套提高 iOS 开发效率的方 ...
- SmartIDE v0.1.18 已经发布 - 助力阿里国产IDE OpenSumi 插件安装提速10倍、Dapr和Jupyter支持、CLI k8s支持
SmartIDE v0.1.18 (cli build 3538) 已经发布,在过去的Sprint 18中,我们集中精力推进对 k8s 远程工作区 的支持,同时继续扩展SmartIDE对不同技术栈的支 ...
- python之提速千倍爆破一句话
看了一下冰河大佬写的文章特别有感:https://bbs.ichunqiu.com/thread-16952-1-1.html 简单描述一下: 利用传统的单数据提交模式. 比如下面这个一句话木马: & ...
- 提速1000倍,预测延迟少于1ms,百度飞桨发布基于ERNIE的语义理解开发套件
提速1000倍,预测延迟少于1ms,百度飞桨发布基于ERNIE的语义理解开发套件 11月5日,在『WAVE Summit+』2019 深度学习开发者秋季峰会上,百度对外发布基于 ERNIE 的语义理解 ...
- 这款 IDE 插件再次升级,让「小程序云」的开发部署提速 8 倍
今年3月份,在阿里云北京峰会上,阿里巴巴正式发布了“阿里巴巴小程序繁星计划”,截至当前,已经有成千上万的开发者加入这个计划,使得小程序得到蓬勃发展,然而不可避免的是,这些服务加重了对云端的开发部署.运 ...
- Python 之父爆料:明年至少令 Python 提速 1 倍!
大概在半年前,我偶然看到一篇文章,有人提出了给 Python 提速 5 倍的计划,并在寻找经费赞助.当时并没有在意,此后也没有看到这方面的消息. 但是,就在 5 月 13 日"2021 年 ...
- 图像转置的SSE优化(支持8位、24位、32位),提速4-6倍。
一.前言 转置操作在很多算法上都有着广泛的应用,在数学上矩阵转置更有着特殊的意义.而在图像处理上,如果说图像数据本身的转置,除了显示外,本身并无特殊含义,但是在某些情况下,确能有效的提高算法效率,比如 ...
随机推荐
- python模块导入(包)
模块 关注公众号"轻松学编程"了解更多. 1.1. 模块的概述 在计算机程序的开发过程中,随着程序代码越写越多,在一个文件里的代码就会越来越长,越来越不容易维护. 为了编写可维 ...
- rclone 云盘同步工具的正确打开方式
Rclone 是一款的命令行工具,支持在不同对象存储.网盘间同步.上传.下载数据. 官网网址:https://rclone.org/ Github 项目:https://github.com/ncw/ ...
- python开发基础(二)运算符以及数据类型之float(浮点类型)
# encoding: utf-8 # module builtins # from (built-in) # by generator 1.147 """ Built- ...
- 人体动作捕捉格式之BVH
BVH简介 BVH是BioVision公司推出的一种人体动作捕捉文件格式.这种文件以节点为核心元素,记录连续数帧内人体骨架的运动. BVH=? 研究一个东西的时候我比较喜欢先研究它的名字.BVH可以认 ...
- 第三方库文件Joi对数据进行验证的方法以及解决Joi.validate is not a function的问题
Joi:javaScript对象的规则描述语言和验证器 1.npm install joi@14.3.1 2.建立joi.js文件 3.导入第三方包joi const Joi = require('j ...
- linux netfilter rule match target 数据结构
对于netfilter 可以参考 https://netfilter.org/documentation/HOWTO/netfilter-hacking-HOWTO-3.html netfilter ...
- 剑指offer刷题(算法类_1)
斐波那契数列 007-斐波拉契数列 题目描述 题解 代码 复杂度 008-跳台阶 题目描述 题解 代码 复杂度 009-变态跳台阶 题目描述 题解 代码 复杂度 010-矩形覆盖 题目描述 题解 代码 ...
- HTTP介绍(一)
超文本传输协议(HTTP)是一种用于分布式,协作式超媒体信息系统的应用程序层协议.HTTP是万维网(World Wide Web)数据通信的基础,超文本文档包括指向用户可以轻松访问的其他资源的超链接, ...
- 【转】CentOS7 64位安装mysql教程
从最新版本的linux系统开始,默认的是 Mariadb而不是mysql!这里依旧以mysql为例进行展示 1.先检查系统是否装有mysql rpm -qa | grep mysql 这里返回空值,说 ...
- [USACO14JAN]Ski Course Rating G
题目链接:https://www.luogu.com.cn/problem/P3101 Slove 这题我们可以尝试建立一个图. 以相邻的两个点建边,边的权值为两个点高度差的绝对值,然后把边按照边权值 ...