先看函数参数:

torch.flatten(input, start_dim=0, end_dim=-1)

input: 一个 tensor,即要被“推平”的 tensor。

start_dim: “推平”的起始维度。

end_dim: “推平”的结束维度。

首先如果按照 start_dim 和 end_dim 的默认值,那么这个函数会把 input 推平成一个 shape 为 $[n]$ 的tensor,其中 $n$ 即 input 中元素个数。

如果我们要自己设定起始维度和结束维度呢?

我们要先来看一下 tensor 中的 shape 是怎么样的:

t = torch.tensor([[[1, 2, 2, 1],
[3, 4, 4, 3],
[1, 2, 3, 4]],
[[5, 6, 6, 5],
[7, 8, 8, 7],
[5, 6, 7, 8]]])
print(t, t.shape)

运行结果:

tensor([[[1, 2, 2, 1],
[3, 4, 4, 3],
[1, 2, 3, 4]], [[5, 6, 6, 5],
[7, 8, 8, 7],
[5, 6, 7, 8]]])
torch.Size([2, 3, 4])

我们可以看到,最外层的方括号内含两个元素,因此 shape 的第一个值是 $2$;类似地,第二层方括号里面含三个元素,shape 的第二个值就是 $3$;最内层方括号里含四个元素,shape 的第二个值就是 $4$。

示例代码:

x = torch.flatten(t, start_dim=1)
print(x, x.shape) y = torch.flatten(t, start_dim=0, end_dim=1)
print(y, y.shape)

运行结果:

tensor([[1, 2, 2, 1, 3, 4, 4, 3, 1, 2, 3, 4],
[5, 6, 6, 5, 7, 8, 8, 7, 5, 6, 7, 8]]) torch.Size([2, 12])
tensor([[1, 2, 2, 1],
[3, 4, 4, 3],
[1, 2, 3, 4],
[5, 6, 6, 5],
[7, 8, 8, 7],
[5, 6, 7, 8]]) torch.Size([6, 4])

可以看到,当 start_dim = $1$ 而 end_dim = $-1$ 时,它把第 $1$ 个维度到最后一个维度全部推平合并了。而当 start_dim = $0$ 而 end_dim = $1$ 时,它把第 $0$ 个维度到第 $1$ 个维度全部推平合并了。

(这里注意的一点是,维度是从第 $0$ 维开始的)

而且,pytorch中的 torch.nn.Flatten 类和 torch.Tensor.flatten 方法其实都是基于上面的 torch.flatten 函数实现的。

关于torch.flatten的笔记的更多相关文章

  1. 深度学习框架 Torch 7 问题笔记

    深度学习框架 Torch 7 问题笔记 1. 尝试第一个 CNN 的 torch版本, 代码如下: -- We now have 5 steps left to do in training our ...

  2. numpy中flatten学习笔记

    ndarray.flatten() 用法 用于返回一个折叠成一维的数组.该函数只能适用于numpy对象,即array或者mat,普通的list列表是不行的. 例子 # coding=utf-8 fro ...

  3. 训练一个图像分类器demo in PyTorch【学习笔记】

    [学习源]Tutorials > Deep Learning with PyTorch: A 60 Minute Blitz > Training a Classifier   本文相当于 ...

  4. 关于torchvision.models中VGG的笔记

    VGG 主要有两种结构,分别是 VGG16 和 VGG19,两者并没有本质上的区别,只是网络深度不一样. 对于给定的感受野,采用堆积的小卷积核是优于采用大的卷积核的,因为多层非线性层可以增加网络深度来 ...

  5. [深度学习] Pytorch学习(一)—— torch tensor

    [深度学习] Pytorch学习(一)-- torch tensor 学习笔记 . 记录 分享 . 学习的代码环境:python3.6 torch1.3 vscode+jupyter扩展 #%% im ...

  6. [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py 这篇文章来看下 ...

  7. tensorflow/pytorch/mxnet的pip安装,非源代码编译,基于cuda10/cudnn7.4.1/ubuntu18.04.md

    os安装 目前对tensorflow和cuda支持最好的是ubuntu的18.04 ,16.04这种lts,推荐使用18.04版本.非lts的版本一般不推荐. Windows倒是也能用来装深度GPU环 ...

  8. 轻量级CNN模型之squeezenet

    SqueezeNet 论文地址:https://arxiv.org/abs/1602.07360 和别的轻量级模型一样,模型的设计目标就是在保证精度的情况下尽量减少模型参数.核心是论文提出的一种叫&q ...

  9. Pytorch学习之源码理解:pytorch/examples/mnists

    Pytorch学习之源码理解:pytorch/examples/mnists from __future__ import print_function import argparse import ...

随机推荐

  1. mysql 两表索引优化

    建表语句 CREATE TABLE IF NOT EXISTS `class`( `id` INT(10) UNSIGNED NOT NULL AUTO_INCREMENT, `card` INT(1 ...

  2. VM ubuntu18.04.01虚拟机没办法联网

    sudo service network-manager stop sudo rm /var/lib/NetworkManager/NetworkManager.state sudo service ...

  3. SciPy fftpack(傅里叶变换)

    章节 SciPy 介绍 SciPy 安装 SciPy 基础功能 SciPy 特殊函数 SciPy k均值聚类 SciPy 常量 SciPy fftpack(傅里叶变换) SciPy 积分 SciPy ...

  4. Kubernetes 各版本镜像列表

    以下镜像列表由 kubeadm v1.11.1 导出,若使用预下载镜像离线部署的方式部署,请使用 kubeadm v1.11.1 版本 导出各版本镜像列表: kubeadm config images ...

  5. Spring源码深度解析-《源码构建》

    1.gradle构建eclipse项目时,gradle-5.0版本构建失败,gradle-3.3构建成功!Why 2.导入spring-framework-3.2.x/spring-beans之前先导 ...

  6. 基于LAMP实现后台活动发布和前端扫码签到系统

    目的 无论是公司.学校和社会团体,都会举办各式各样的活动,比如运动会.部门会议.项目会议.野炊.团建等.作为团队管理者来讲,当然希望能够把这类活动转移到线上形成完整的系统,类似于电子流的形式.本文以学 ...

  7. 清北学堂模拟赛2 T2 ball

    题目大意: 多组数据,每组给定n,m,表示将n个小球放进m个箱子,每个小球均有两个箱子(可能相同)可放,求所有小球均放好的方案mod998244353的总数. 思路: 算是我和题解思路肥肠相近的一道题 ...

  8. SNOI2019 选做

    施工中... d1t1 字符串 题面 考虑两个字符串 \(s_i,s_j(i<j)\) ,实质是 \(s[i+1,\dots j]\) 和 \(s[i,\dots ,j-1]\) 的字符串字典序 ...

  9. mysq 事务管理入门

    设置隔离级别:

  10. delphi base64编码

    需要uses IdCoderMIME: function TForm1.Base64E(Path: string): string;var filepath: string; filestream: ...