pytorch常用函数总结(持续更新)

torch.max(input,dim)

求取指定维度上的最大值,,返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。比如:

demo.shape
Out[7]: torch.Size([10, 3, 10, 10])
torch.max(demo,1)[0].shape
Out[8]: torch.Size([10, 10, 10])

torch.max(demo,1)[0]这其中的[0]取得就是返回的最大值,torch.max(demo,1)[1]就是返回的最大值对应的位置索引。例子如下:

a
Out[8]:
tensor([[1., 2., 3.],
[4., 5., 6.]])
a.max(1)
Out[9]:
torch.return_types.max(
values=tensor([3., 6.]),
indices=tensor([2, 2]))

class torch.nn.ParameterList(parameters=None)

submodules保存在一个list中。

ParameterList可以像一般的Python list一样被索引。而且ParameterList中包含的parameters已经被正确的注册,对所有的module method可见。

参数说明:

  • modules (list, optional) – a list of nn.Parameter

例子:

class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)]) def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, p in enumerate(self.params):
x = self.params[i // 2].mm(x) + p.mm(x)
return x

torch.cat()函数

cat是concatnate的意思:拼接,联系在一起。

先说cat( )的普通用法

如果我们有两个tensor是A和B,想把他们拼接在一起,需要如下操作:

C = torch.cat( (A,B),0 )  #按维数0拼接(竖着拼)

C = torch.cat( (A,B),1 )  #按维数1拼接(横着拼)

相当于将tensor按照指定维度进行拼接,比如A的shape为128*64*32*32,B的shape为 128*32*64*64,那么按照 torch.cat( (A,B),1)拼接的之后的形状为 128*96*64*64

注意:

两个tensor要想进行拼接,必须保证除了指定拼接的维度以外其他的维度形状必须相同,比如上面的例子,拼接A和B时,A的形状为128*64*32*32,B的形状为128*32*64*64,只有第二个维度的维数数值不同,其他的维度的维数都是相同的,所以拼接时可按维度1进行拼接(注意,维度的下标是从0开始的,比如 A 的形状对应的维度下标为:1280∗641∗322∗323128_0*64_1*32_2*32_31280​∗641​∗322​∗323​)

contiguous()函数的使用

contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形(如:tensor_var.contiguous().view() ),示例如下:

x = torch.Tensor(2,3)
y = x.permute(1,0) # permute:二维tensor的维度变换,此处功能相当于转置transpose
y.view(-1) # 报错,view使用前需调用contiguous()函数
y = x.permute(1,0).contiguous()
y.view(-1) # OK

具体原因有两种说法:

1 transpose、permute等维度变换操作后,tensor在内存中不再是连续存储的,而view操作要求tensor的内存连续存储,所以需要contiguous来返回一个contiguous copy;

2 维度变换后的变量是之前变量的浅拷贝,指向同一区域,即view操作会连带原来的变量一同变形,这是不合法的,所以也会报错;---- 这个解释有部分道理,也即contiguous返回了tensor的深拷贝contiguous copy数据;

原文链接:https://zhuanlan.zhihu.com/p/64376950

tensor.repeat()函数

该函数传入的参数个数不少于tensor的维数,其中每个参数代表的是对该维度重复多少次,也就相当于复制的倍数,结合例子更好理解,如下:

>>> import torch
>>>
>>> a = torch.randn(33, 55)
>>> a.size()
torch.Size([33, 55])
>>>
>>> a.repeat(1, 1).size()
torch.Size([33, 55])
>>>
>>> a.repeat(2,1).size()
torch.Size([66, 55])
>>>
>>> a.repeat(1,2).size()
torch.Size([33, 110])
>>>
>>> a.repeat(1,1,1).size()
torch.Size([1, 33, 55])
>>>
>>> a.repeat(2,1,1).size()
torch.Size([2, 33, 55])
>>>
>>> a.repeat(1,2,1).size()
torch.Size([1, 66, 55])
>>>
>>> a.repeat(1,1,2).size()
torch.Size([1, 33, 110])
>>>
>>> a.repeat(1,1,1,1).size()
torch.Size([1, 1, 33, 55])
>>>
>>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数,
>>> # 下面是一些错误示例
>>> a.repeat(2).size() # 1D < 2D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b = torch.randn(5,6,7)
>>> b.size() # 3D
torch.Size([5, 6, 7])
>>>
>>> b.repeat(2).size() # 1D < 3D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1).size() # 2D < 3D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1,1).size() # 3D = 3D, okay
torch.Size([10, 6, 7])
>>>

参考博客:https://blog.csdn.net/qq_29695701/article/details/89763168

pytorch常用函数总结(持续更新)的更多相关文章

  1. php常用函数(持续更新)

    每一种编程语言在用的过程中都会发现有时候要一种特定需求的功能函数,结果没有内置这样的函数,这个时候就需要自己根据已有函数编写尽可能简单的函数,下面是我在做php相关工作时积累下的函数,会持续更新,您要 ...

  2. JavaScript中常用函数(入门级)(持续更新)

    本文中枫竹梦介绍一些JavaScript中入门级的常用函数,对于已经过了入门的童鞋可选择略过,都是一些非常实用的函数.如果发现什么问题,欢迎讨论. 问题列表 Q1: 设计一个函数repeatIt(st ...

  3. php 常用函数集合(持续更新中...)

    php 常用函数集合 在php的开发中,巧妙的运用php自带的一些函数,会起到事半功倍的效果,在此,主要记录一些常用的函数 1.time(),microtime()函数 time():获取当前时间戳 ...

  4. Js 常用函数【持续更新】

    Js Math对象方法介绍:http://www.w3school.com.cn/jsref/jsref_obj_math.asp 1. 算数函数(Math) 1)Js小数取整 常用于:分页算法 js ...

  5. Oracle数据库常用函数使用--持续更新中

    NVL函数.NVL( string1, replace_with).如果string1为NULL,则NVL函数返回replace_with的值,否则返回原来的值. INSTR函数.用于查找指定字符串是 ...

  6. C语言中的常用函数_持续更新

    isspace函数: 背景:之前遇到scanf()输入时会把换行符留在输入队列的情况,如果下次要用到getchar(),但是会导致其先返回这个我们不需要的换行符:从而导致不希望出现的行为: 说明:检查 ...

  7. SqlServer一些常用函数(持续更新。。。)

    1. 字符串拼接: + 拼接 SELECT 'AA' + 'BB' A //AABB在2012版本sqlserver之后,可以使用cancat进行字符串拼接了. 2. 判断是否为空,并取另外的值 :I ...

  8. git常用命令(持续更新中)

    git常用命令(持续更新中) 本地仓库操作git int                                 初始化本地仓库git add .                       ...

  9. 【github&&git】4、git常用命令(持续更新中)

    git常用命令(持续更新中) 本地仓库操作git int                                 初始化本地仓库git add .                       ...

随机推荐

  1. C#LeetCode刷题-树状数组

    树状数组篇 # 题名 刷题 通过率 难度 218 天际线问题   32.7% 困难 307 区域和检索 - 数组可修改   42.3% 中等 315 计算右侧小于当前元素的个数   31.9% 困难 ...

  2. LeetCode406 queue-reconstruction-by-height详解

    题目详情 假设有打乱顺序的一群人站成一个队列. 每个人由一个整数对(h, k)表示,其中h是这个人的身高,k是排在这个人前面且身高大于或等于h的人数. 编写一个算法来重建这个队列. 注意: 总人数少于 ...

  3. 详解Python Graphql

    前言 很高兴现在接手的项目让我接触到了Python Graphql,百度上对其介绍相对较少也不够全面,几乎没有完整的中文文档,所以这边也借此机会学习一下Graphql. 什么是Graphql呢? Gr ...

  4. Oracle和Mysql分页的区别

    一.Mysql使用limit分页 select * from stu limit m, n; //m = (startPage-1)*pageSize,n = pageSize PS: (1)第一个参 ...

  5. 数据结构-二叉树(6)哈夫曼树(Huffman树)/最优二叉树

    树的路径长度是从树根到每一个结点的路径长度(经过的边数)之和. n个结点的一般二叉树,为完全二叉树时取最小路径长度PL=0+1+1+2+2+2+2+… 带权路径长度=根结点到任意结点的路径长度*该结点 ...

  6. Eligibility Traces and Plasticity on Behavioral Time Scales: Experimental Support of neoHebbian Three-Factor Learning Rules

    郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布! Abstract 大多数基本行为,如移动手臂抓住物体或走进隔壁房间探索博物馆,都是在几秒钟的时间尺度上进化的:相反,神经元动作电位则是在几 ...

  7. A Review on Generative Adversarial Networks: Algorithms, Theory, and Applications

    1 Introduction GANs由两个模型组成:生成器和鉴别器.生成器试图捕获真实示例的分布,以便生成新的数据样本.鉴别器通常是一个二值分类器,尽可能准确地将生成样本与真实样本区分开来.GANs ...

  8. Java算法——分治法

         一.基本概念 在计算机科学中,分治法是一种很重要的算法.字面上的解释是“分而治之”,就是把一个复杂的问题分成两个或更多的相同或相似的子问题,再把子问题分成更小的子问题……直到最后子问题可以简 ...

  9. 接口测试 Mock 实战 | 结合 jq 完成批量化的手工 Mock

    本文霍格沃兹测试学院学员学习实践笔记. 一.应用背景 因为本章的内容是使用jq工具配合完成,因此在开始部分会先花一定的篇幅介绍jq机器使用,如果读者已经熟悉jq,可以直接跳过这部分. 先来看应用场景, ...

  10. 谷歌蜂鸟算法对网站seo优化有何影响

    http://www.wocaoseo.com/thread-89-1-1.html       谷歌在过去三个月里,非常低调的推出了蜂鸟算法,据谷歌技术员表示,此种方法一出,将影响90%网站的排名, ...