tf.py_func的一些使用笔记——TensorFlow1.x
tensorflow.py_func是TensorFlow1.x版本下的函数,在TensorFlow.2.x已经不建议使用了,但是依然可以通过tf.compat.v1.py_func的方式来进行调用。
可以说TensorFlow1.x下的py_func函数在TensorFlow2.x下除了通过tf.compat.v1.py_func的方式来进行调用
就再也没有等价的使用方法了,具体可以看TensorFlow2.x的API文档:
https://tensorflow.google.cn/api_docs/python/tf/compat/v1/py_func
----------------------------------------------------------
这里需要着重说明一点,很多人认为TensorFlow1.x中的tf.py_func等价于TensorFlow2.x中的tf.py_function,其实不然。在TensorFlow2.x中除了对tf.py_func进行v1版本保留和兼容的tf.compat.v1.py_func
,其实是没有完全同tf.py_func等价的函数。如果说在 TensorFlow2.x中 比较相近的函数应该是tf.numpy_function而不是tf.py_function。
在TensorFlow2.x中对tf.numpy_function的解释:
https://tensorflow.google.cn/api_docs/python/tf/numpy_function
在TensorFlow2.x中对tf.py_function的解释:
https://tensorflow.google.cn/api_docs/python/tf/py_function
----------------------------------------------------------
TensorFlow2.x中 tf.numpy_function 和 TensorFlow1.x中 tf.py_func
(tf.compat.v1.py_func
)中唯一的区别是:
tf.py_func
中是可以设置函数是否考虑状态的,而tf.numpy_function中是必须要考虑状态的(没有定义不考虑状态的设置)。
This name was deprecated and removed in TF2, but tf.numpy_function
is a near-exact replacement, just drop the stateful
argument (all tf.numpy_function
calls are considered stateful).
=============================================
对TensorFlow1.x中 tf.py_func
进行下一步解释:
tf.py_func其实是将python函数包装成TensorFlow的一个操作operation,tf.py_func的输入可以是numpy,可以是tensor,也可以是Variable,其输入只能是tensor。
tf.py_func定义的操作是属于TensorFlow的计算图的,在定义tf.py_func时是不会具体执行的,只有在具体的tf.Session中还可以执行,但是tf.py_func并不同于其他的TensorFlow的operation,因为tf.py_func定义的操作是运行在python空间下的而不是运行在TensorFlow空间下的。
tf.py_func定义后,在session中运行时的基本原理就是将输入的变量(不论是tensor还是numpy.array)转换为python空间下的numpy.array变量,在经过numpy运算后在将获得的numpy.array结果转换为tensor,给到TensorFlow的计算图。
其实,tf.py_func的功能完全可以手动实现类似的,就是手动的把tensor变量转为numpy.array,然后运算好后把结果手动转为tensor,tf.py_func最大的好处就是把这一过程给自动化了,不过随之也使这个运算过程变得难以理解了。从tf.py_func的原理我们就可以知道,虽然tf.py_func可以作为TensorFlow计算图的一部分挂在计算图上,但是由于其本质是将TensorFlow空间变量转为python空间变量后经过运算再转为TensorFlow空间变量,中间经过了命名空间和运算空间的转换,因此tf.py_func是不可以进行梯度反传的,或许我们更可以把这个操作看做是一种简易的为TensorFlow提供支持的python库。
================================================
2022年10月13日更新
如果tf.py_func包装的python函数的参数是string类型,那么传到包装的函数内时会被自动转为bytes类型,也就是string变bytes,这一点需要注意,否则真的是不知道什么地方报错的。
例子:
import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
print(a, b)
return np.array(len(a+b), dtype=np.float32) x = "abc"
y = "bde"
ans = tf.py_func(fun, (x, y), (tf.float32, ), name="ab_op")
print("="*30, "result:")
print(ans) print(sess.run(ans))
运行结果:
可以看到,由tensorflow空间传参到python空间会自动的将string类型转为bytes类型。
在python3.x版本中,可以使用bytes.decode()的方法将传入的bytes类型转会string类型,具体:
修改后的代码:
import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
print(a, b)
a = a.decode()
b = b.decode()
print(a, b)
return np.array(len(a+b), dtype=np.float32) x = "abc"
y = "bde"
ans = tf.py_func(fun, (x, y), (tf.float32, ), name="ab_op")
print("="*30, "result:")
print(ans) print(sess.run(ans))
重点部分:
================================================
一些例子:
以下代码均为TensorFlow1.x版本:
import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+1, b+1 x = np.array([1.0,2.0,3.0], dtype=np.float32)
y = np.array([4.0,5.0,6.0], dtype=np.float32) ans=tf.py_func(fun, (x, y), (tf.float32, tf.float32), name="ab_python") print("="*30, "result:")
print(ans)
print(sess.run(ans))
运行结果:
可以看到,tf.py_func的执行其实是为TensorFlow定义了一个operation,而tf.py_func所包装的python函数只有在TensorFlow执行计算图的时候才会被真正执行。
tf.py_func为包装的python函数所传入的参数可以是numpy.array类型,也可以是tensor类型,也可以是Variable类型,但是不管在tf.py_func中传入的参数是什么类型,最后传入到所包装的python函数中都会被转为numpy.array类型,而包装后的函数在session开始执行后所返回给计算图的数据类型也会被转换为tensor类型。
--------------------------------------------------------
包装的参数为tensor:
import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+1, b+1 x = tf.constant([1.0,2.0,3.0], dtype=np.float32)
y = tf.constant([4.0,5.0,6.0], dtype=np.float32)
ans =tf.py_func(fun, (x, y), (tf.float32, tf.float32), name="ab_op") print("="*30, "result:")
print(ans)
print("session is running!!!")
print(sess.run(ans))
运行结果:
--------------------------------------------------------
包装的参数为Variable:
import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+1, b+1 x = tf.Variable([1.0,2.0,3.0], dtype=np.float32)
y = tf.Variable([4.0,5.0,6.0], dtype=np.float32)
ans = tf.py_func(fun, (x, y), (tf.float32, tf.float32), name="ab_op")
print("="*30, "result:")
print(ans) sess.run(tf.global_variables_initializer())
print(sess.run(ans))
运行结果:
---------------------------
包装的参数为Variable:
import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+b x = tf.Variable([1.0,2.0,3.0], dtype=np.float32)
y = tf.Variable([4.0,5.0,6.0], dtype=np.float32)
ans = tf.py_func(fun, (x, y), (tf.float32, ), name="ab_op")
print("="*30, "result:")
print(ans) sess.run(tf.global_variables_initializer())
print(sess.run(ans))
运行结果:
---------------------------------------------
包装的参数为Variable,求反传梯度报错:
import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+b x = tf.Variable([1.0,2.0,3.0], dtype=np.float32)
y = tf.Variable([4.0,5.0,6.0], dtype=np.float32)
x2 = tf.Variable([1.0,2.0,3.0], dtype=np.float32)
y2 = tf.Variable([4.0,5.0,6.0], dtype=np.float32)
ans = tf.py_func(fun, (x, y), (tf.float32, ), name="ab_op")
ans2 = x2 + y2
print("="*30, "result:")
print(ans) sess.run(tf.global_variables_initializer())
print(sess.run(ans)) op2 = tf.gradients(ans2, (x2, y2))
print("Ops2 Gradients: \n", sess.run(op2)) op = tf.gradients(ans, (x, y))
print("Ops Gradients: \n", sess.run(op))
运行结果:
证明:
tf.py_func包装后的函数是不可以进行反传的。
其实,tf.py_func就是在tensorflow计算图执行的时候调用python代码,而调用python代码时运行在python的代码空间中,自然是不支持反传的。
==================================================
tf.py_func的一些使用笔记——TensorFlow1.x的更多相关文章
- tf.py_func
在 faster rcnn的tensorflow 实现中看到这个函数 rois,rpn_scores=tf.py_func(proposal_layer,[rpn_cls_prob,rpn_bbox ...
- Tensorflow之调试(Debug) && tf.py_func()
Tensorflow之调试(Debug)及打印变量 tensorflow调试tfdbg 几种常用方法: 1.通过Session.run()获取变量的值 2.利用Tensorboard查看一些可视化统计 ...
- 使用多块GPU进行训练 1.slim.arg_scope(对于同等类型使用相同操作) 2.tf.name_scope(定义名字的范围) 3.tf.get_variable_scope().reuse_variable(参数的复用) 4.tf.py_func(构造函数)
1. slim.arg_scope(函数, 传参) # 对于同类的函数操作,都传入相同的参数 from tensorflow.contrib import slim as slim import te ...
- tf.contrib.layers.fully_connected参数笔记
tf.contrib.layers.fully_connected 添加完全连接的图层. tf.contrib.layers.fully_connected( inputs, num_ou ...
- tf.split函数的用法(tensorflow1.13.0)
tf.split(input, num_split, dimension): dimension指输入张量的哪一个维度,如果是0就表示对第0维度进行切割:num_split就是切割的数量,如果是2就表 ...
- TensorFlow学习笔记(一):数据操作指南
扩充 TensorFlow tf.tile 对数据进行扩充操作 import tensorflow as tf temp = tf.tile([1,2,3],[2]) temp2 = tf.tile( ...
- tf.data
以往的TensorFLow模型数据的导入方法可以分为两个主要方法,一种是使用feed_dict另外一种是使用TensorFlow中的Queues.前者使用起来比较灵活,可以利用Python处理各种输入 ...
- tf调试函数
Tensorflow之调试(Debug)及打印变量 参考资料:https://wookayin.github.io/tensorflow-talk-debugging 几种常用方法: 1.通过Se ...
- R2CNN项目部分代码学习
首先放出大佬的项目地址:https://github.com/yangxue0827/R2CNN_FPN_Tensorflow 那么从输入的数据开始吧,输入的数据要求为tfrecord格式的数据集,好 ...
- tensorflow_目标识别object_detection_api,RuntimeError: main thread is not in main loop,fig = plt.figure(frameon=False)_tkinter.TclError: no display name and no $DISPLAY environment variable
最近在使用目标识别api,但是报错了: File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/script_o ...
随机推荐
- String忽略大小写方法compareToIgnoreCase源码及Comparator自定义比较器
String忽略大小写方法compareToIgnoreCase源码及Comparator自定义比较器 //源码 public int compareToIgnoreCase(String str) ...
- ZynqMP PL固件通过U-BOOT从指定位置加载FPGA BIT
原因 PL固件可能经常修改,而BOOT.BIN和文件系统.内核实际上基本不会变,在一个平台上可以用同一份.如果每次修改都要重新打包PL 固件到BOOT.BIN,操作起来非常麻烦.所以希望PL 的固件可 ...
- java并发和排序的简单例子(Runnable+TreeSet)
很多时候并发需要考虑线程安全,但也有很多时候和线程安全毛关系都没有,因为并发最大的作用是并行,线程安全仅仅是并发的一个子话题. 例如常常会用于并发运算,并发i/o. 下文是一个练习笔记. 运行环境:w ...
- idea远程debug(物理机、docker、k8s)
IDEA远程DEBUG 1:物理机部署的Springboot项目远程DEBUG 1.1:idea配置 点击"Edit Configurations",再点击+,选择Remote, ...
- 在Linux驱动中使用regmap
背景 在学习SPI的时候,看到了某个rtc驱动中用到了regmap,在学习了对应的原理以后,也记录一下如何使用. 介绍 在Linu 3.1开始,Linux引入了regmap来统一管理内核的I2C, S ...
- 全志T3+FPGA国产核心板——Pango Design Suite的FPGA程序加载固化
本文主要基于紫光同创Pango Design Suite(PDS)开发软件,演示FPGA程序的加载.固化,以及程序编译等方法.适用的开发环境为Windows 7/10 64bit. 测试板卡为全志T3 ...
- 从PDF到OFD,国产化浪潮下多种文档格式导出的完美解决方案
前言 近年来,中国在信息技术领域持续追求自主创新和供应链安全,伴随信创上升为国家战略,一些行业也开始明确要求文件导出的格式必须为 OFD 格式.OFD 格式目前在政府.金融.税务.教育.医疗等需要文件 ...
- 【VMware vSAN】vSAN Data Protection Part 2:配置管理。
上篇文章"vSAN Data Protection Part 1:安装部署."介绍了如何安装及部署 VMware Snapshot Service Appliance 设备,并在 ...
- Spring 常见的事务管理、事务的传播特性、隔离级别
事务管理 事务:多个操作,要么同时成功,要么失败后一起回滚 具备ACID四种特性 Atomic(原子性) Consistency(一致性) lsolation(隔离性) Durablility(持久性 ...
- QAnything AI开源的企业级本地知识库问答解决方案,致力于支持任意格式文件或数据库的问答
QAnything AI简介 QAnything ai是一个本地知识库问答系统,旨在支持多种文件格式和数据库,允许离线安装和使用.您可以简单地删除任何格式的任何本地存储文件,并获得准确.快速和可靠的答 ...