tensorflow 旋转矩阵的函数实现方法

关键字: rot90, tensorflow

1. 背景

在做数据增强的操作过程中, 很多情况需要对图像旋转和平移等操作, 针对一些特殊的卷积(garbo conv)操作,还需要对卷积核进行旋转操作.

在tensorflow中似乎没有实现对4D tensor的旋转操作.

严格的说: tensorflow对tensor的翻转操作并未实现, 仅有针对3D tensor的tf.image.rot()

而在大多数的情况下使用的是4D形式的tensor, [B,W,H,C] 或者是3D的图像组成的batchs.

通过查看这篇文章的代码可以知道[1] 可以使用numpy的rot90()函数旋转, 但是rot90对象是ndarray, 针对tensorflow.tensor对象而言显然是无法使用的, 会抛出类似: 无法找到m.dim属性的异常.

也就是说无法使用numpy.rot90() 函数.

又知, tensorflow中提供有对矩阵的翻转, 转置,切片操作的函数,但是没有提供旋转90°, 180°,270°的操作.

因此可以参照numpy.rot90(m, k=1, axes=(0,1)) 的程序片段去自己动手实现.

rot90中的第一个参数m是操作对象, k是旋转的次数,k=1 代表逆时针旋转90度, k=2 代表逆时针旋转180度,以此类推

axes是代表旋转的操作在哪两个维度构成的平面上.

rot90的源代码如下:

  1. def rot90(m, k=1, axes=(0,1)):
  2. '''
  3. ......
  4. '''
  5. # 省略检测参数的操作
  6. k %= 4
  7. if k == 0:
  8. return m[:]
  9. if k == 2:
  10. return flip(flip(m, axes[0]), axes[1])
  11. axes_list = arange(0, m.ndim)
  12. (axes_list[axes[0]], axes_list[axes[1]]) = (axes_list[axes[1]],
  13. axes_list[axes[0]])
  14. if k == 1:
  15. return transpose(flip(m,axes[1]), axes_list)
  16. else:
  17. # k == 3
  18. return flip(transpose(m, axes_list), axes[1])

PS: 通过阅读上述的代码,也可以发现在tensorflow中直接使用rot90所抛出的异常是在这里出现的

  1. if axes[0] == axes[1] or absolute(axes[0] - axes[1]) == m.ndim

原因是: 程序把tensor对象当成np.ndarray操作了, 而tensor对象没有m.dim属性

2. 实现rot90操作

2.1 梳理程序流程

通过查看源代码可以梳理出程序流程图:

2.2 tensorflow 实现旋转操作

根据上述的流程图, 可以实现对tensorflow的rot90操作;

  1. def rot90(tensor,k=1,axes=[1,2],name=None):
  2. '''
  3. autor:lizh
  4. tensor: a tensor 4 or more dimensions
  5. k: integer, Number of times the array is rotated by 90 degrees.
  6. axes: (2,) array_like
  7. The array is rotated in the plane defined by the axes.
  8. Axes must be different.
  9. -----
  10. Returns
  11. -------
  12. tensor : tf.tensor
  13. A rotated view of `tensor`.
  14. See Also: https://www.tensorflow.org/api_docs/python/tf/image/rot90
  15. '''
  16. axes = tuple(axes)
  17. if len(axes) != 2:
  18. raise ValueError("len(axes) must be 2.")
  19. tenor_shape = (tensor.get_shape().as_list())
  20. dim = len(tenor_shape)
  21. if axes[0] == axes[1] or np.absolute(axes[0] - axes[1]) == dim:
  22. raise ValueError("Axes must be different.")
  23. if (axes[0] >= dim or axes[0] < -dim
  24. or axes[1] >= dim or axes[1] < -dim):
  25. raise ValueError("Axes={} out of range for tensor of ndim={}."
  26. .format(axes, dim))
  27. k%=4
  28. if k==0:
  29. return tensor
  30. if k==2:
  31. img180 = tf.reverse(tf.reverse(tensor, axis=[axes[0]]),axis=[axes[1]],name=name)
  32. return img180
  33. axes_list = np.arange(0, dim)
  34. (axes_list[axes[0]], axes_list[axes[1]]) = (axes_list[axes[1]],axes_list[axes[0]]) # 替换
  35. print(axes_list)
  36. if k==1:
  37. img90=tf.transpose(tf.reverse(tensor,axis=[axes[1]]), perm=axes_list, name=name)
  38. return img90
  39. if k==3:
  40. img270=tf.reverse( tf.transpose(tensor, perm=axes_list),axis=[axes[1]],name=name)
  41. return img270

2.3 代码测试

  1. # 加载库
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import tensorflow as tf
  5. # 手写体数据集 加载
  6. from tensorflow.examples.tutorials.mnist import input_data
  7. mnist = input_data.read_data_sets("/home/lizhen/data/MNIST/", one_hot=True)
  8. sess=tf.Session()
  9. #选取数据 4D
  10. images = mnist.train.images
  11. img_raw = images[0,:] # [0,784]
  12. img=tf.reshape(img_raw,[-1,28,28,1]) # img 现在是tensor
  13. # 绘图
  14. def fig_2D_tensor(tensor):# 绘图
  15. #plt.matshow(tensor, cmap=plt.get_cmap('gray'))
  16. plt.matshow(tensor) # 彩色图像
  17. # plt.colorbar() # 颜色条
  18. plt.show()
  19. # 显 显示 待旋转的图片
  20. fig_2D_tensor(sess.run(img)[0,:,:,0]) # 提取ndarray

简单的测试一下代码:

  1. img11_rot=rot90(img,2) # 旋转两次90
  2. fig_2D_tensor(sess.run(img11_rot)[0,:,:,0]) # 打印图像
  3. img12_rot=rot90(img,1,[1,1]) # 抛出异常, 测试 Axes must be different.
  4. img13_rot=rot90(img,1,[0,6]) # 抛出异常, 测试 Axes must be different.
  5. img14_rot=rot90(img,axes=[1,5])# 抛出异常,测试out of range.
  6. img14_rot=rot90(img,axes=[-1,2]) # -1的下标是倒数第二个,测试out of range.

测试结果:

3总结

okey了,现在可以用了.

.....

额,,,,,最近才发现tensorflow的最新版本,大约就在前几天发布的新版本(14天前, 1.10.1 )上已经添加了对2D,3D图像的操作,支持[B,W,H,C]格式的tensor做出旋转[2]

星期五, 07. 九月 2018 02:49下午

参考文献


  1. Understanding 2D Dilated Convolution Operation with Examples in Numpy and Tensorflow with Interactive Code ↩︎

  2. tensorflow/python/ops/image_ops#rot90 ↩︎

tensorflow: a Implementation of rotation ops (旋转的函数实现方法)的更多相关文章

  1. 【转】Unity3D Transform中有关旋转的属性和方法测试

    Transform有关旋转个属性和方法测试 一,属性 1,var eulerAngles : Vector3 public float yRotation = 5.0F; void Update()  ...

  2. 【微软100题】定义字符串的左旋转操作:把字符串前面的若干个字符移动到字符串的尾部。 如把字符串abcdef左旋转2位得到字符串cdefab。请实现字符串左旋转的函数。

    package test; /** * 定义字符串的左旋转操作:把字符串前面的若干个字符移动到字符串的尾部. 如把字符串abcdef左旋转2位得到字符串cdefab. 请实现字符串左旋转的函数. * ...

  3. TensorFlow之DNN(三):神经网络的正则化方法(Dropout、L2正则化、早停和数据增强)

    这一篇博客整理用TensorFlow实现神经网络正则化的内容. 深层神经网络往往具有数十万乃至数百万的参数,可以进行非常复杂的特征变换,具有强大的学习能力,因此容易在训练集上过拟合.缓解神经网络的过拟 ...

  4. Android 解决setRequestedOrientation之后手机屏幕的旋转不触发onConfigurationChanged方法

    最近在做播放器的时候遇到一个问题,在屏幕方向改变之后需要切换播放器全屏/非全屏的时候,在重写了onConfigurationChanged方法并在manifest.xml配置文件中添加 android ...

  5. tensorflow构建CNN模型时的常用接口函数

    (1)tf.nn.max_pool()函数 解释: tf.nn.max_pool(value, ksize, strides, padding, data_format='NHWC', name=No ...

  6. Matrix控制平移、旋转和缩放的方法

    1.setTranslate(float ds,float dy):控制Matrix进行平移.2.setSkew(float kx,float ky,float px,float py):控制Matr ...

  7. JMeter接口HTTP请求implementation不选java会报错解决方法

    1.若不对c参数和d参数进行URL编码则需要选择implementation为java: 2.若想不设implementation值,则需进行c参数d参数URLEncoding import java ...

  8. TensorFlow 常用函数与方法

    摘要:本文主要对tf的一些常用概念与方法进行描述. tf函数 TensorFlow 将图形定义转换成分布式执行的操作, 以充分利用可用的计算资源(如 CPU 或 GPU.一般你不需要显式指定使用 CP ...

  9. QT 实现图片旋转的两种方法

    第一种方案 使用 QPixmap 的 transformed 函数来实现旋转,这个函数默认是以图片中心为旋转点,不能设置旋转的中心点,使用如下: QMatrix matrix; matrix.rota ...

随机推荐

  1. python学习,day2:字典

    字典的增删改查 # coding=utf-8 # Author: RyAn Bi info = { 'stu1101':'Tenglan Wu', 'stu1102':'longze Luola', ...

  2. BZOJ 2457 双端队列

           Sherry 现在碰到了一个棘手的问题,有N个整数需要排序.        Sherry 手头能用的工具就是若干个双端队列.        她需要依次处理这 N 个数,对于每个数, Sh ...

  3. P4027 [NOI2007]货币兑换

    传送门 首先有一个显然的贪心,每次操作都要做到底,为了最优不会出现只卖一部分或者只买一部分的操作 所以设 $f[i]$ 表示前 $i$ 天得到的最大价值,那么对于每一个 $i$,枚举所有 $j< ...

  4. zero-copy总结

    基本概念 零拷贝,通常在java NIO编程中会使用,比如netty网络工具包. 其真实意思是: 网卡或者其他外设进行io操作时不经过CPU, 而是直接和主memory交互,不经过CPU寄存器,这样可 ...

  5. 一个迷你的 Node.js 基于 Express 的 MVR 模式的 API工程 的分析

    1. 工程说明 该工程是基于 Express 库,编写的一个 API 查询返回的一个微型应用. API Resource 就是把 API 的内容当做网络资源去处理.工程中的路由访问也是返回 API 内 ...

  6. CDH集群安装配置(四)- mysql 的安装

    安装mysql,并且创建相关的表(只需要在chd1上面安装而且需要root权限)1.1 查看Centos自带mysql是否已经安装 yum list installed | grep mysql 卸载 ...

  7. Django 配置访问静态文件

    1.settings.py 首先在 settings 文件中,引用 os 模块: import os   定义根目录: BASE_DIR = os.path.dirname(os.path.dirna ...

  8. PL/SQL 游标

    本随笔不是原创,只是学习笔记,用于加深记忆,原创地址PL/SQL --> 游标 一.游标的相关概念和特性 1.定义: 映射到结果集中的某一行的特定位置,类似与C语言中的指针.即通过游标方式定位到 ...

  9. MySQL比较运算符的子查询

    使用比较运算符的子查询 =.>.<.>=.<=.<>.!=.<=> 语法结构 operand comparison_operator subquery ...

  10. SSH框架学习步骤

    Hibernate 对象状态 关系映射 SQL语句 缓存抓取 struts action的分发配置 参数传递  文件上传 spring IOC AOP