P和C：Position attention module（位置注意力模块）与 Channel attention module（通道注意力模块）
- import tensorflow as tf
- import numpy as np
- import math
- import keras
- from keras.layers import Conv2D,Reshape,Input
- import numpy as np
- import matplotlib.pyplot as plt
- """ Channel attention module"""
def channel_attention(features, gamma=0.05):
    """Apply a parameter-free channel attention step to a channels-last batch.

    Channel-to-channel affinities are computed directly from the feature map
    (no learned projections), turned into per-channel attention weights,
    used to re-mix the channels, and the result is added back to the input
    scaled by ``gamma``.

    Args:
        features: array-like of shape (N, H, W, C), any numeric dtype
            (e.g. the uint8 output of ``tf.image.decode_jpeg`` plus a
            batch axis).
        gamma: weight of the attention output in the residual sum.

    Returns:
        (out, attention): ``out`` is float32 with the same shape as
        ``features``; ``attention`` is float32 of shape (N, C, C) whose
        rows sum to 1.
    """
    # Cast once so uint8 image data and float results are never mixed.
    x = np.asarray(features, dtype=np.float32)
    batch, height, width, channels = x.shape
    # (N, HW, C): column c holds every spatial response of channel c.
    flat = x.reshape(batch, height * width, channels)
    # energy[n, i, j] = <channel i, channel j>  ->  (N, C, C)
    energy = np.matmul(flat.transpose(0, 2, 1), flat)
    # Row maximum minus energy (always >= 0). Softmax is shift-invariant,
    # so this equals softmax(-energy): less-correlated channels get more
    # weight. keepdims broadcasting replaces any manual tiling of the max.
    energy_new = np.max(energy, axis=-1, keepdims=True) - energy
    # Numerically stable softmax over the last axis; raw pixel energies are
    # far too large to exponentiate directly.
    weights = np.exp(energy_new - np.max(energy_new, axis=-1, keepdims=True))
    attention = weights / np.sum(weights, axis=-1, keepdims=True)
    # out[n, p, i] = sum_j attention[n, i, j] * flat[n, p, j]; with channels
    # last, attention must be transposed on the right-hand side.
    mixed = np.matmul(flat, attention.transpose(0, 2, 1))
    # Restore (N, H, W, C) so the residual sum lines up element-wise.
    out = gamma * mixed.reshape(batch, height, width, channels) + x
    return out, attention


if __name__ == '__main__':
    # TensorFlow 1.x is only used to read and decode the JPEG; the attention
    # math is plain NumPy. The context manager closes the session.
    with tf.Session() as sess:
        image = sess.run(tf.image.decode_jpeg(tf.read_file('img.jpg')))
    # Add a batch axis: (H, W, C) -> (1, H, W, C).
    x1 = np.expand_dims(image, axis=0)
    print("x1.shape:", x1.shape)
    out, attention = channel_attention(x1, gamma=0.05)
    print("attention:", attention)
    print("out:", out.shape)
- import tensorflow as tf
- import numpy as np
- import math
- import keras
- from keras.layers import Conv2D,Reshape,Input
- from keras.regularizers import l2
- from keras.layers.advanced_activations import ELU, LeakyReLU
- from keras import Model
- import cv2
- """
- Important:
- 1、A为CxHxW => Conv+BN+ReLU => B, C 都为CxHxW
- 2、Reshape B, C to CxN (N=HxW)
- 3、Transpose B to B’
- 4、Softmax(Matmul(B’, C)) => spatial attention map S为NxN(HWxHW)
- 5、如上式1, 其中sji测量了第i个位置在第j位置上的影响
- 6、也就是第i个位置和第j个位置之间的关联程度/相关性, 越大越相似.
- 7、A => Conv+BN+ReLU => D 为CxHxW => reshape to CxN
- 8、Matmul(D, S’) => CxHxW, 这里设置为DS
- 9、Element-wise sum(scale parameter alpha * DS, A) => the final output E 为 CxHxW (式2)
- 10、alpha is initialized as 0 and gradually learns to assign more weight.
- """
- """
- inputs :
- x : input feature maps (N X H X W X C, channels-last as in the TensorFlow code below)
- returns :
- out : attention value + input feature
- attention: N X (HxW) X (HxW)
- """
- """ Position attention module"""
def position_attention(query, key, value, features, gamma=0.05):
    """Apply a position (spatial) attention step to projected feature maps.

    Every spatial position attends to every other position: attention
    weights come from query/key dot products, the value features are
    aggregated with those weights, and the result is added back to
    ``features`` scaled by ``gamma``. Channels-last layout throughout.

    Args:
        query, key: arrays of shape (N, H, W, C'), the query/key projections.
        value: array of shape (N, H, W, C), the value projection; it should
            have the same channel count as ``features`` so the residual sum
            lines up.
        features: array of shape (N, H, W, C), the input of the module.
        gamma: weight of the attention output in the residual sum.

    Returns:
        (out, attention): ``out`` is float32 of shape (N, H, W, C);
        ``attention`` is float32 of shape (N, H*W, H*W) whose rows sum to 1.
        The attention matrix grows as (H*W)**2, so keep inputs small.
    """
    x = np.asarray(features, dtype=np.float32)
    batch, height, width, _ = x.shape
    positions = height * width
    # Flatten the spatial grid: (N, H, W, c) -> (N, HW, c).
    q = np.asarray(query, dtype=np.float32).reshape(batch, positions, -1)
    k = np.asarray(key, dtype=np.float32).reshape(batch, positions, -1)
    v = np.asarray(value, dtype=np.float32).reshape(batch, positions, -1)
    # energy[n, i, j] = <q_i, k_j>  ->  (N, HW, HW); row i scores how much
    # position i attends to each position j.
    energy = np.matmul(q, k.transpose(0, 2, 1))
    # Numerically stable softmax over j.
    weights = np.exp(energy - np.max(energy, axis=-1, keepdims=True))
    attention = weights / np.sum(weights, axis=-1, keepdims=True)
    # out[n, i, c] = sum_j attention[n, i, j] * v[n, j, c]:
    # (N, HW, HW) @ (N, HW, C) -> (N, HW, C), then back to the spatial grid.
    attended = np.matmul(attention, v).reshape(batch, height, width, -1)
    return gamma * attended + x, attention


if __name__ == '__main__':
    # TensorFlow 1.x decodes the JPEG; the context manager closes the session.
    with tf.Session() as sess:
        image = sess.run(tf.image.decode_jpeg(tf.read_file('img.jpg')))
    # Add a batch axis and work in float32: (H, W, C) -> (1, H, W, C).
    x1 = np.expand_dims(image, axis=0).astype(np.float32)
    print(x1.shape)
    _, height, width, channels = x1.shape
    # 1x1 conv projections (Conv + ReLU). Weights are he_normal random
    # initialisations; nothing is trained in this demo.
    inputs = Input(shape=(height, width, channels))
    query_conv = Conv2D(1, (1, 1), activation='relu', kernel_initializer='he_normal')(inputs)
    key_conv = Conv2D(1, (1, 1), activation='relu', kernel_initializer='he_normal')(inputs)
    value_conv = Conv2D(channels, (1, 1), activation='relu', kernel_initializer='he_normal')(inputs)
    # One inference-only model yields all three projections in a single
    # forward pass; compile() only configures training, so it is skipped.
    projector = Model(inputs=[inputs], outputs=[query_conv, key_conv, value_conv])
    query, key, value = projector.predict(x1)
    out, attention = position_attention(query, key, value, x1, gamma=0.05)
    print("attention:", attention.shape)
    print("out:", out.shape)
随机推荐
- 洛谷P1039 侦探推理(模拟)
侦探推理 题目描述 明明同学最近迷上了侦探漫画《柯南》并沉醉于推理游戏之中,于是他召集了一群同学玩推理游戏.游戏的内容是这样的,明明的同学们先商量好由其中的一个人充当罪犯(在明明不知情的情 ...
- Python【每日一问】13
问:请简述一下python的GIL 答:GIL 锁,全局解释器锁,仅在CPython解释器中,作用就是,限制多线程同时执行,保证同一时间内只有一个线程在执行.
- sqlserver 中NOLOCK、HOLDLOCK、UPDLOCK、TABLOCK、TABLOCKX
https://www.cnblogs.com/sthinker/p/5922967.html
- 基于Modbus-TCP/IP的台达PLC通讯协议解析
客户端发送:19 B2 00 00 00 06 06 03 00 27 00 02 上面是modbus客户端发出的报文内容,为modbus tcp/ip协议格式,其前面的六个字节为头字节( heade ...
- 注解 - Excel 校验工具
注解类: @Retention(RetentionPolicy.RUNTIME) public @interface ExcelValidate { public boolean ignoreBlan ...
- MYSQL数据库中中文乱码问题
show variables like 'character%'; set character_set_database=gbk; 把记事本中的代码引入到mysql数据库中:source +addre ...
- Linux 信号signal处理函数
转自:http://www.cnblogs.com/taobataoma/archive/2007/08/30/875662.html alarm(设置信号传送闹钟) 相关函数 signal,slee ...
- 初次接触Linux
最近由于工作需求,需要接触Linux系统. 使用VMware虚拟机,安装ubuntu系统.网上教程很多. 配置opencv环境.这是我参考的网上帖子https://blog.csdn.net/fish ...
- python大法好——Python XML解析
Python XML解析 什么是XML? XML 被设计用来传输和存储数据. XML是一套定义语义标记的规则,这些标记将文档分成许多部件并对这些部件加以标识. 它也是元标记语言,即定义了用于定义其他与 ...
- swoole结合支持thinkphp 5.0版本
安装swoole pecl install swoole 修改PHP配置文件php.ini加入 extension=swoole.so 有可能不需要人工去加,安装时自动加入进来了, 查看swoole扩 ...