pytorch torch.nn.functional实现插值和上采样
interpolate
torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None)
根据给定的size或scale_factor参数来对输入进行下/上采样
使用的插值算法取决于参数mode的设置
支持目前的temporal(1D, 如向量数据), spatial(2D, 如jpg、png等图像数据)和volumetric(3D, 如点云数据)类型的采样数据作为输入,输入数据的格式为minibatch x channels x [optional depth] x [optional height] x width,具体为:
- 对于一个temporal输入,期待着3D张量的输入,即minibatch x channels x width
- 对于一个空间spatial输入,期待着4D张量的输入,即minibatch x channels x height x width
- 对于体积volumetric输入,则期待着5D张量的输入,即minibatch x channels x depth x height x width
可用于重置大小的mode有:最近邻、线性(3D-only),、双线性, 双三次(bicubic,4D-only)和三线性(trilinear,5D-only)插值算法和area算法
参数:
input (Tensor) – 输入张量
size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]) –输出大小.
scale_factor (float or Tuple[float]) – 指定输出为输入的多少倍数。如果输入为tuple,其也要制定为tuple类型
mode (str) –
可使用的上采样算法,有
'nearest'
,'linear'
,'bilinear'
,'bicubic'
,'trilinear'和'area'
.默认使用
'nearest'
align_corners (bool, optional) –
几何上,我们认为输入和输出的像素是正方形,而不是点。如果设置为True,则输入和输出张量由其角像素的中心点对齐,从而保留角像素处的值。如果设置为False,则输入和输出张量由它们的角像素的角点对齐,插值使用边界外值的边值填充;
当scale_factor保持不变时
,使该操作独立于输入大小。仅当使用的算法为'linear'
,'bilinear', 'bilinear'
or'trilinear'时可以使用。
默认设置为
False
注意:
使用mode='bicubic'时,可能会导致overshoot问题,即它可以为图像生成负值或大于255的值。如果你想在显示图像时减少overshoot问题,可以显式地调用result.clamp(min=0,max=255)。
When using the CUDA backend, this operation may induce nondeterministic behaviour in be backward that is not easily switched off. Please see the notes on Reproducibility for background.
警告:
当align_corners = True时,线性插值模式(线性、双线性、双三线性和三线性)不按比例对齐输出和输入像素,因此输出值可以依赖于输入的大小。这是0.3.1版本之前这些模式的默认行为。从那时起,默认行为是align_corners = False,如下图:
上面的图是source pixel为4*4上采样为target pixel为8*8的两种情况,这就是对齐和不对齐的差别,会对齐左上角元素,即设置为align_corners = True时输入的左上角元素是一定等于输出的左上角元素。但是有时align_corners = False时左上角元素也会相等,官网上给的例子就不太能说明两者的不同(也没有试出不同的例子,大家理解这个概念就行了)
举例:
import torch
from torch import nn
import torch.nn.functional as F
input = torch.arange(, , dtype=torch.float32).view(, , , )
input
返回:
tensor([[[[., .],
[., .]]]])
x = F.interpolate(input, scale_factor=, mode='nearest')
x
返回:
tensor([[[[., ., ., .],
[., ., ., .],
[., ., ., .],
[., ., ., .]]]])
x = F.interpolate(input, scale_factor=, mode='bilinear', align_corners=True)
x
返回:
tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
[1.6667, 2.0000, 2.3333, 2.6667],
[2.3333, 2.6667, 3.0000, 3.3333],
[3.0000, 3.3333, 3.6667, 4.0000]]]])
也提供了一些Upsample的方法:
upsample
torch.nn.functional.upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None)
torch.nn.functional.upsample_nearest(input, size=None, scale_factor=None)
torch.nn.functional.upsample_bilinear(input, size=None, scale_factor=None)
因为这些现在都建议使用上面的interpolate方法实现,所以就不解释了
更加复杂的例子可见:pytorch 不使用转置卷积来实现上采样
pytorch torch.nn.functional实现插值和上采样的更多相关文章
- PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx
PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx 在写 PyTorch 代码时,我们会发现一些功能重复的操作,比如卷积.激活.池化等操作.这些操作分别可 ...
- [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList
1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...
- Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别
在写代码时发现我们在定义Model时,有两种定义方法: torch.nn.Conv2d()和torch.nn.functional.conv2d() 那么这两种方法到底有什么区别呢,我们通过下述代码看 ...
- torch.nn.functional中softmax的作用及其参数说明
参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/functional/#_1 class torch.nn.Soft ...
- 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系
从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...
- pytorch torch.nn 实现上采样——nn.Upsample
Vision layers 1)Upsample CLASS torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align ...
- Pytorch——torch.nn.Sequential()详解
参考:官方文档 源码 官方文档 nn.Sequential A sequential container. Modules will be added to it in the order th ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- 『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...
随机推荐
- mysql 5.6 rpm安装启动、配置参数、字符集修改等
linux 7 安装mysql server 注意:此mysql版本是el6 MySQL-server-5.6.35-1.el6.x86_64 一.安装部署: 1.yum:首先要配置yum源,yum安 ...
- 小程序框架之视图层 View~基础组件
框架为开发者提供了一系列基础组件,开发者可以通过组合这些基础组件进行快速开发.详细介绍请参考组件文档. 什么是组件: 组件是视图层的基本组成单元. 组件自带一些功能与微信风格一致的样式. 一个组件通常 ...
- 【Java】《Java程序设计基础教程》第二章学习
一.标识符 Java 中标识符的使用有如下规定:(1)标识符由字母.数字.美元符号”$”和下划线”_”组成,除此之外的任何其他符号是不能作为标识符使用的.(2)标识符中的第一个字符不能为数字. (3 ...
- sleep() 和 wait() 区别是什么?
sleep() 和 wait() 区别是什么? 1.每个对象都有一个锁来控制同步访问,Synchronized关键字可以和对象的锁交互,来实现同步方法或同步块.sleep()方法正在执行的线程主动让出 ...
- Vue创建组件的三种方式
1.使用 Vue.extend 来创建全局的Vue组件 <div id="app"> <!-- 如果要使用组件,直接,把组件的名称,以 HTML 标签的形式,引入 ...
- Storm 安装时 部分supervisor启动成功,并不在web ui上显示
今天帮公司搭建集群时,发现启动了三个Supervisor 发现只有一个显示在Web UI 上. 于是我就简单地检查了下另外两台没有启动的 storm supervisor的日志, 发现没有报出什么异常 ...
- map json 字符串 对象之间的相互转化
1.对象与字符串之间的互转 将对象转换成为字符串 String str = JSON.toJSONString(infoDo); 字符串转换成为对象 InfoDo infoDo = JSON.pars ...
- git 在 A 项目中引用 B 项目
git 在 A 项目中引用 B 项目 场景: 需要在项目calcDLL(http://XXX/XXXA.git) 中 引用 项目libindex(http://XXX/XXXB.git). 解决方 ...
- 最长不下降子序列 nlogn && 输出序列
最长不下降子序列实现: 利用序列的单调性. 对于任意一个单调序列,如 1 2 3 4 5(是单增的),若这时向序列尾部增添一个数 x,我们只会在意 x 和 5 的大小,若 x>5,增添成功,反之 ...
- FTP与HTTP上传文件的对比
许多站点,比如facebook或一些博客等都允许用户上传或下载文件,比如论坛或博客系统的图片. 在这种情况下,通常有两种选择上传文件到服务器,那就是FTP协议和HTTP协议. 以下列出了一些两者的不同 ...