转载于:Pytorch中的仿射变换(affine_grid)

参考:详细解读Spatial Transformer Networks (STN)

假设我们有这么一张图片:

 

下面我们将通过分别通过手动编码和pytorch方式对该图片进行平移、旋转、转置、缩放等操作,这些操作的数学原理在本文中不会详细讲解。

实现载入图片(注意,下面的代码都是在 jupyter 中进行):

  1. from torchvision import transforms
  2. from PIL import Image
  3. import matplotlib.pyplot as plt
  4.  
  5. %matplotlib inline
  6.  
  7. img_path = "图片文件路径"
  8. img_torch = transforms.ToTensor()(Image.open(img_path))
  9.  
  10. plt.imshow(img_torch.numpy().transpose(1,2,0))
  11. plt.show()

平移操作

普通方式

例如我们需要向右平移50px,向下平移100px。

  1. import numpy as np
  2. import torch
  3.  
  4. theta = np.array([
  5. [1,0,50],
  6. [0,1,100]
  7. ])
  8. # 变换1:可以实现缩放/旋转,这里为 [[1,0],[0,1]] 保存图片不变
  9. t1 = theta[:,[0,1]]
  10. # 变换2:可以实现平移
  11. t2 = theta[:,[2]]
  12.  
  13. _, h, w = img_torch.size()
  14. new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
  15. for x in range(w):
  16. for y in range(h):
  17. pos = np.array([[x], [y]])
  18. npos = t1@pos+t2
  19. nx, ny = npos[0][0], npos[1][0]
  20. if 0<=nx<w and 0<=ny<h:
  21. new_img_torch[:,ny,nx] = img_torch[:,y,x]
  22. plt.imshow(new_img_torch.numpy().transpose(1,2,0))
  23. plt.show()

图片变为:

图片平移-1

pytorch 方式

向右移动0.2,向下移动0.4:

  1. from torch.nn import functional as F
  2.  
  3. theta = torch.tensor([
  4. [1,0,-0.2],
  5. [0,1,-0.4]
  6. ], dtype=torch.float)
  7. grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
  8. output = F.grid_sample(img_torch.unsqueeze(0), grid)
  9. new_img_torch = output[0]
  10. plt.imshow(new_img_torch.numpy().transpose(1,2,0))
  11. plt.show()

得到的图片为:

 
图片平移-2

总结:

  • 要使用 pytorch 的平移操作,只需要两步:theta 的第三列为平移比例,向右为负,向下为负;

    • 创建 grid:grid = torch.nn.functional.affine_grid(theta, size),其实我们可以通过调节 size 设置所得到的图像的大小(相当于resize);
    • grid_sample 进行重采样:outputs = torch.nn.functional.grid_sample(inputs, grid, mode='bilinear')
  • theta 的第三列为平移比例,向右为负,向下为负;

我们通过设置 size 可以将图像resize:

  1. from torch.nn import functional as F
  2.  
  3. theta = torch.tensor([
  4. [1,0,-0.2],
  5. [0,1,-0.4]
  6. ], dtype=torch.float)
  7. # 修改size
  8. N, C, W, H = img_torch.unsqueeze(0).size()
  9. size = torch.Size((N, C, W//2, H//3))
  10. grid = F.affine_grid(theta.unsqueeze(0), size)
  11. output = F.grid_sample(img_torch.unsqueeze(0), grid)
  12. new_img_torch = output[0]
  13. plt.imshow(new_img_torch.numpy().transpose(1,2,0))
  14. plt.show()
修改size的效果

缩放操作

普通方式

放大1倍:

  1. import numpy as np
  2. import torch
  3.  
  4. theta = np.array([
  5. [2,0,0],
  6. [0,2,0]
  7. ])
  8. t1 = theta[:,[0,1]]
  9. t2 = theta[:,[2]]
  10.  
  11. _, h, w = img_torch.size()
  12. new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
  13. for x in range(w):
  14. for y in range(h):
  15. pos = np.array([[x], [y]])
  16. npos = t1@pos+t2
  17. nx, ny = npos[0][0], npos[1][0]
  18. if 0<=nx<w and 0<=ny<h:
  19. new_img_torch[:,ny,nx] = img_torch[:,y,x]
  20. plt.imshow(new_img_torch.numpy().transpose(1,2,0))
  21. plt.show()

结果为:

放大操作-1

由于没有使用插值算法,所以中间有很多部分是黑色的。

pytorch 方式

  1. from torch.nn import functional as F
  2.  
  3. theta = torch.tensor([
  4. [0.5, 0 , 0],
  5. [0 , 0.5, 0]
  6. ], dtype=torch.float)
  7. grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
  8. output = F.grid_sample(img_torch.unsqueeze(0), grid)
  9. new_img_torch = output[0]
  10. plt.imshow(new_img_torch.numpy().transpose(1,2,0))
  11. plt.show()

结果为:

放大操作-2

结论:可以看到,affine_grid 的放大操作是以图片中心为原点的。

旋转操作

普通操作

将图片旋转30度:

  1. import numpy as np
  2. import torch
  3. import math
  4.  
  5. angle = 30*math.pi/180
  6. theta = np.array([
  7. [math.cos(angle),math.sin(-angle),0],
  8. [math.sin(angle),math.cos(angle) ,0]
  9. ])
  10. t1 = theta[:,[0,1]]
  11. t2 = theta[:,[2]]
  12.  
  13. _, h, w = img_torch.size()
  14. new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
  15. for x in range(w):
  16. for y in range(h):
  17. pos = np.array([[x], [y]])
  18. npos = t1@pos+t2
  19. nx, ny = int(npos[0][0]), int(npos[1][0])
  20. if 0<=nx<w and 0<=ny<h:
  21. new_img_torch[:,ny,nx] = img_torch[:,y,x]
  22. plt.imshow(new_img_torch.numpy().transpose(1,2,0))
  23. plt.show()

结果为:

旋转操作-1
 

pytorch 操作

  1. from torch.nn import functional as F
  2. import math
  3.  
  4. angle = -30*math.pi/180
  5. theta = torch.tensor([
  6. [math.cos(angle),math.sin(-angle),0],
  7. [math.sin(angle),math.cos(angle) ,0]
  8. ], dtype=torch.float)
  9. grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
  10. output = F.grid_sample(img_torch.unsqueeze(0), grid)
  11. new_img_torch = output[0]
  12. plt.imshow(new_img_torch.numpy().transpose(1,2,0))
  13. plt.show()

结果为:

旋转操作-2

pytorch 以图片中心为原点进行旋转,并且在旋转过程中会发生图片缩放,如果选择角度变为 90°,图片为:

旋转 90° 结果
 

转置操作

普通操作

  1. import numpy as np
  2. import torch
  3.  
  4. theta = np.array([
  5. [0,1,0],
  6. [1,0,0]
  7. ])
  8. t1 = theta[:,[0,1]]
  9. t2 = theta[:,[2]]
  10.  
  11. _, h, w = img_torch.size()
  12. new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
  13. for x in range(w):
  14. for y in range(h):
  15. pos = np.array([[x], [y]])
  16. npos = t1@pos+t2
  17. nx, ny = npos[0][0], npos[1][0]
  18. if 0<=nx<w and 0<=ny<h:
  19. new_img_torch[:,ny,nx] = img_torch[:,y,x]
  20. plt.imshow(new_img_torch.numpy().transpose(1,2,0))
  21. plt.show()

结果为:

 
图片转置-1

pytorch 操作

我们可以通过size大小,保存图片不被压缩:

  1. from torch.nn import functional as F
  2.  
  3. theta = torch.tensor([
  4. [0, 1, 0],
  5. [1, 0, 0]
  6. ], dtype=torch.float)
  7. N, C, H, W = img_torch.unsqueeze(0).size()
  8. grid = F.affine_grid(theta.unsqueeze(0), torch.Size((N, C, W, H)))
  9. output = F.grid_sample(img_torch.unsqueeze(0), grid)
  10. new_img_torch = output[0]
  11. plt.imshow(new_img_torch.numpy().transpose(1,2,0))
  12. plt.show()

结果为:

图片转置-2

(转载)Pytorch中的仿射变换(affine_grid)的更多相关文章

  1. [转载]PyTorch中permute的用法

    [转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...

  2. [转载]Pytorch中nn.Linear module的理解

    [转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...

  3. 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

    原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...

  4. (原)CNN中的卷积、1x1卷积及在pytorch中的验证

    转载请注明处处: http://www.cnblogs.com/darkknightzh/p/9017854.html 参考网址: https://pytorch.org/docs/stable/nn ...

  5. 转pytorch中训练深度神经网络模型的关键知识点

    版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...

  6. pytorch中tensor数据和numpy数据转换中注意的一个问题

    转载自:(pytorch中tensor数据和numpy数据转换中注意的一个问题)[https://blog.csdn.net/nihate/article/details/82791277] 在pyt ...

  7. 详解Pytorch中的网络构造,模型save和load,.pth权重文件解析

    转载:https://zhuanlan.zhihu.com/p/53927068 https://blog.csdn.net/wangdongwei0/article/details/88956527 ...

  8. [转载]PyTorch上的contiguous

    [转载]PyTorch上的contiguous 来源:https://zhuanlan.zhihu.com/p/64551412 这篇文章写的非常好,我这里就不复制粘贴了,有兴趣的同学可以去看原文,我 ...

  9. [转载]Pytorch详解NLLLoss和CrossEntropyLoss

    [转载]Pytorch详解NLLLoss和CrossEntropyLoss 来源:https://blog.csdn.net/qq_22210253/article/details/85229988 ...

随机推荐

  1. 查看Linux的本机IP

    命令式 ifconfig -a 在限制inet addr中显示本机的ip地址

  2. Pat 1003 甲级

    #include <cstdlib> #include <cstring> #include <iostream> #include <cstdio> ...

  3. php解释器模式( interpreter pattern)

    ... <?php /* The interpreter pattern specifies how to evaluate language grammar or expressions. W ...

  4. 《快活帮》第九次团队作业:【Beta】Scrum meeting 1

    项目 内容 这个作业属于哪个课程 2016计算机科学与工程学院软件工程(西北师范大学) 这个作业的要求在哪里 实验十三 团队作业9:BETA冲刺与团队项目验收 团队名称 快活帮 作业学习目标 (1)掌 ...

  5. spark jdbc(mysql) 读取并发度优化

    转自:https://blog.csdn.net/lsshlsw/article/details/49789373 很多人在spark中使用默认提供的jdbc方法时,在数据库数据较大时经常发现任务 h ...

  6. new.target元属性 | 分别用es5、es6 判断一个函数是否使用new操作符

    函数内部有两个方法 [[call]] 和 [[construct]] (箭头函数没有这个方法),当使用new 操作符时, 函数内部调用 [[construct]], 创建一个新实例,this指向这个实 ...

  7. 面向IO编程--一切皆文件

    in Unix, everything is a file.This simplifies the manipulation of data and devices into a set of cor ...

  8. Common Substrings POJ - 3415 (后缀自动机)

    Common Substrings \[ Time Limit: 5000 ms\quad Memory Limit: 65536 kB \] 题意 给出两个字符串,要求两个字符串公共子串长度不小于 ...

  9. 使用jpillora/dnsmasq 提供可视化管理的dns server

    实际开发中dns 是一个比较重要的组件,一般大家可能会选择使用dnsmasq 但是缺少UI可视化,有些人可能会选择powerdns jpillora/dnsmasq 是一个对于dnsmasq 的包装, ...

  10. 【学习笔记】fwt&&fmt&&子集卷积

    前言:yyb神仙的博客 FWT 基本思路:将多项式变成点值表达,点值相乘之后再逆变换回来得到特定形式的卷积: 多项式的次数界都为\(2^n\)的形式,\(A_0\)定义为前一半多项式(下标二进制第一位 ...