Pytorch AdaptivePooing操作转Pooling操作

多数的前向推理框架不支持AdaptivePooing操作,此时需要将AdaptivePooing操作转换为普通的Pooling操作。AdaptivePooling与Max/AvgPooling相互转换提供了一种转换方法,但我在Pytorch1.6中的测试结果是错误的。通过查看Pytorch源码(pytorch-master\aten\src\ATen\native\AdaptiveAveragePooling.cpp)我找出了正确的转换方式。

  inline int start_index(int a, int b, int c) {
return (int)std::floor((float)(a * c) / b);
} inline int end_index(int a, int b, int c) {
return (int)std::ceil((float)((a + 1) * c) / b);
} template <typename scalar_t>
static void adaptive_avg_pool2d_single_out_frame(
scalar_t *input_p,
scalar_t *output_p,
int64_t sizeD,
int64_t isizeH,
int64_t isizeW,
int64_t osizeH,
int64_t osizeW,
int64_t istrideD,
int64_t istrideH,
int64_t istrideW)
{
at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) {
for (auto d = start; d < end; d++)
{
/* loop over output */
int64_t oh, ow;
for(oh = 0; oh < osizeH; oh++)
{
int istartH = start_index(oh, osizeH, isizeH);
int iendH = end_index(oh, osizeH, isizeH);
int kH = iendH - istartH; for(ow = 0; ow < osizeW; ow++)
{
int istartW = start_index(ow, osizeW, isizeW);
int iendW = end_index(ow, osizeW, isizeW);
int kW = iendW - istartW; /* local pointers */
scalar_t *ip = input_p + d*istrideD + istartH*istrideH + istartW*istrideW;
scalar_t *op = output_p + d*osizeH*osizeW + oh*osizeW + ow; /* compute local average: */
scalar_t sum = 0;
int ih, iw;
for(ih = 0; ih < kH; ih++)
{
for(iw = 0; iw < kW; iw++)
{
scalar_t val = *(ip + ih*istrideH + iw*istrideW);
sum += val;
}
} /* set output to local average */
*op = sum / kW / kH;
}
}
}
});
}

上述代码段中isizeH,isizeW分别表示输入张量的宽高osizeH,osizeW则表示输出宽高。关注第二个for循环for(oh = 0; oh < osizeH; oh++){.....}中的内容。假设输入的宽高均为223isizeH = isizeW = 223,输出的宽高均为7osizeH = osizeW = 224,然后简单分析一下oh=0,1,2时的情况:

  • oh=0, istartH = 0, iendH = ceil(223/7)=32, kH = 32
  • oh=1, istartH = floor(223/7) = 31, iendH = ceil(223*2/7)=64, kH = 33
  • oh=2, istartH = floor(223*2/7) = 63, iendH = ceil(223*3/7)=96, kH = 33

这里的kH就是kernel_size的大小. oh=0时的kernel_size比其他情况要小,所以需要在输入上添加padding,让oh=0时的kernel_size与其他情况相同。添加的padding大小为1,等价于让istartH从-1开始,即kH = 32-(-1) = 33. 下一个需要获取的参数是stride,stride = istartH[oh=i]-istartH[oh=i-1], 在上述例子中即为32。按照上述的例子分析输入宽高为224的情况可以发现padding=0,所以padding也是一个需要转换的参数。下面给出3个参数的转换公式:

  • stride = ceil(input_size / output_size)
  • kernel_size = ceil(2 * input_size / output_size) - floor(input_size / output_size)
  • padding = ceil(input_size / output_size) - floor(input_size / output_size)

在上述的代码中最后部分,可以看见均值使用*op = sum / kW / kH计算得到的。这表明在边缘部分计算均值没有考虑padding,所以对应的AvgPool中的count_include_pad应该设为False。下面贴出我的测试代码:

def test(size):
import numpy as np
import torch x = torch.randn(1,1,size,size) input_size = np.array(x.shape[2:])
output_size = np.array([7,7]) # stride = ceil(input_size / output_size)
# kernel_size = ceil(2 * input_size / output_size) - floor(input_size / output_size)
# padding = ceil(input_size / output_size) - floor(input_size / output_size) stride = numpy.ceil(input_size / output_size).astype(int)
kernel_size = (numpy.ceil(2 * input_size / output_size) - numpy.floor(input_size / output_size)).astype(int)
padding = (numpy.ceil(input_size / output_size) - numpy.floor(input_size / output_size)).astype(int)
print(stride)
print(kernel_size)
print(padding)
avg1 = nn.AdaptiveAvgPool2d(list(output_size))
avg2 = nn.AvgPool2d(kernel_size=kernel_size.tolist(), stride=stride.tolist(), padding=padding.tolist(), ceil_mode=False, count_include_pad=False)
max1 = nn.AdaptiveMaxPool2d(list(output_size))
max2 = nn.MaxPool2d(kernel_size=kernel_size.tolist(), stride=stride.tolist(), padding=padding.tolist(), ceil_mode=False ) avg1_out = avg1(x)
avg2_out = avg2(x)
max1_out = max1(x)
max2_out = max2(x)
print(avg1_out-avg2_out)
print(max1_out-max2_out)
print(torch.__version__)
  • inH = inW=224时的输出

  • inH = inW=223时的输出

Pytorch AdaptivePooing操作转Pooling操作的更多相关文章

  1. python基础操作以及hdfs操作

    目录 前言 基础操作 hdfs操作 总结 一.前言        作为一个全栈工程师,必须要熟练掌握各种语言...HelloWorld.最近就被"逼着"走向了python开发之路, ...

  2. [WCF编程]10.操作:单向操作

    一.单向操作概述 WCF提供了单向操作,一旦客户端调用,WCF会生成一个请求,但没有相关的应答信息返回给客户端.所以,单向操作是不能有返回值,服务抛出的任何异常都不会传递给客户端. 理想情况下,一旦客 ...

  3. Linq查询操作之聚合操作(count,max,min,sum,average,aggregate,longcount)

    在Linq中有一些这样的操作,根据集合计算某一单一值,比如集合的最大值,最小值,平均值等等.Linq中包含7种操作,这7种操作被称作聚合操作. 1.Count操作,计算序列中元素的个数,或者计算满足一 ...

  4. Linq查询操作之排序操作

    在Linq中排序操作可以按照一个或多个关键字对序列进行排序.其中第一个排序关键字为主要关键字,第二个排序关键字为次要关键字.Linq排序操作共包含以下5个基本的操作. 1.OrderBy操作,根据排序 ...

  5. Linq查询操作之投影操作

    投影操作,乍一看不知道在说啥.那么什么是投影操作呢?其实就是Select操作,名字起的怪怪的.和Linq查询表达式中的select操作是一样的.它能够选择数据源中的元素,并指定元素的表现形式.投影操作 ...

  6. Laravel框架数据库CURD操作、连贯操作

    这篇文章主要介绍了Laravel框架数据库CURD操作.连贯操作.链式操作总结,本文包含大量数据库操作常用方法,需要的朋友可以参考下 一.Selects 检索表中的所有行 $users = DB::t ...

  7. Laravel框架数据库CURD操作、连贯操作总结

    这篇文章主要介绍了Laravel框架数据库CURD操作.连贯操作.链式操作总结,本文包含大量数据库操作常用方法,需要的朋友可以参考下 一.Selects 检索表中的所有行 复制代码代码如下: $use ...

  8. IOS文件操作的两种方式:NSFileManager操作和流操作

    1.常见的NSFileManager文件方法 -(NSData *)contentsAtPath:path //从一个文件读取数据 -(BOOL)createFileAtPath: path cont ...

  9. ThinkPHP - 前置操作+后置操作

    前置操作和后置操作   系统会检测当前操作(不仅仅是index操作,其他操作一样可以使用)是否具有前置和后置操作,如果存在就会按照顺序执行,前置和后置操作的方法名是在要执行的方法前面加 _before ...

随机推荐

  1. 【简记】SpringBoot禁用Swagger

    楔子 Swagger 是 Java Web 开发中常用的接口文档生成类库,在开发和前后端联调时使用它来模拟接口调用能提高开发效率.但是,在生产环境可能并不需要它,一个原因是启用它会延长程序启动时间(动 ...

  2. Java多线程专题6: Queue和List

    合集目录 Java多线程专题6: Queue和List CopyOnWriteArrayList 如何通过写时拷贝实现并发安全的 List? CopyOnWrite(COW), 是计算机程序设计领域中 ...

  3. 微服务架构 | 10.1 使用 Sleuth 追踪服务调用链

    目录 前言 1. Sleuth 基础知识 1.1 Sleuth 原理 2. 在服务中使用 Sleuth 追踪 2.1 引入 pom.xml 依赖文件 2.2 查看日志信息 最后 前言 参考资料: &l ...

  4. 「JOISC 2014 Day1」 历史研究

    「JOISC 2014 Day1」 历史研究 Solution 子任务2 暴力,用\(cnt\)记录每种权值出现次数. 子任务3 这不是一个尺取吗... 然后用multiset维护当前的区间,动态加, ...

  5. Spring学习二:Spring Bean 定义

    Bean 定义 被称作 bean 的对象是构成应用程序的支柱也是由 Spring IoC 容器管理的.bean 是一个被实例化,组装,并通过 Spring IoC 容器所管理的对象.这些 bean 是 ...

  6. AndroidStudio项目提交到github最详细步骤【转】

    感谢大佬:https://www.cnblogs.com/imqsl/p/6763133.html 在使用studio开发的项目过程中有时候我们想将项目发布到github上,以前都是用一种比较麻烦的方 ...

  7. Visual Studio 下error C2471: 无法更新程序数据库

    转载请注明来源:https://www.cnblogs.com/hookjc/ 解决方案:修改项目属性 右击项目 --> "属性" 1. "C/C++" ...

  8. C++ XML解析之TinyXML

    转载请注明来源:https://www.cnblogs.com/hookjc/ 使用TinyXML进行C++ XML解析,感觉使用起来比较简单,很容易上手,本文给出一个使用TinyXML进行XML解析 ...

  9. [转]有关ListIterator接口的add与remove方法探究

    原文地址: http://www.java123.net/v/492971.html 应用案例: http://820199753.iteye.com/blog/2230032 ListIterato ...

  10. Spack 内置函数

    1.Map函数:通过函数传递源的每个元素,并形成新的分布式数据集. %spark #并行化集合生成RDD var data = sc.parallelize(List(10,20,30)) %输出结果 ...