1. import tensorflow as tf
  2. import numpy as np
  3. import math
  4. import keras
  5. from keras.layers import Conv2D,Reshape,Input
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  9. """ Channel attention module"""
  11. if __name__ == '__main__':
  12. file = tf.read_file('img.jpg')
  13. x = tf.image.decode_jpeg(file)
  14. #print("Tensor:", x)
  15. sess = tf.Session()
  16. x1 = sess.run(x)
  17. print("x1:",x1)
  18. gamma = 0.05
  19. sess = tf.Session()
  20. x1 = sess.run(x)
  21. x1 = tf.expand_dims(x1, dim =0)
  22. print("x1.shape:", x1.shape)
  24. m_batchsize, height, width, C = x1.shape
  26. proj_query = Reshape((width * height, C))(x1)
  27. print("proj_query:", type(proj_query))
  28. print("proj_query:", proj_query.shape)
  29. proj_query = sess.run(proj_query)
  30. print(proj_query)
  31. proj_key = Reshape((width * height, C))(x1)
  32. proj_key = sess.run(proj_key).transpose(0, 2, 1)
  33. print(proj_key)
  34. print("proj_key:", type(proj_key))
  35. print("proj_key:", proj_key.shape)
  37. proj_query = proj_query.astype(np.float32)
  38. proj_key = proj_key.astype(np.float32)
  40. # N, C, C, bmm 批次矩阵乘法
  41. energy = tf.matmul(proj_key,proj_query)
  42. energy = sess.run(energy)
  43. print("energy:", energy)
  45. # 这里实现了softmax用最后一维的最大值减去了原始数据, 获得了一个不是太大的值
  46. # 沿着最后一维的C选择最大值, keepdim保证输出和输入形状一致, 除了指定的dim维度大小为1
  47. energy_new = tf.reduce_max(energy, -1, keep_dims=True)
  48. print("after_softmax_energy:",sess.run(energy_new))
  50. sess = tf.Session()
  51. e = energy_new
  52. print("b:", sess.run(energy_new))
  54. size = energy.shape[1]
  55. for i in range(size - 1):
  56. e = tf.concat([e, energy_new], axis=-1)
  58. energy_new = e
  59. print("energy_new2:", sess.run(energy_new))
  60. energy_new = energy_new - energy
  61. print("energy_new3:", sess.run(energy_new))
  63. attention = tf.nn.softmax(energy_new, axis=-1)
  64. print("attention:", sess.run(attention))
  66. proj_value = Reshape((width * height, C))(x1)
  67. proj_value = sess.run(proj_value)
  68. proj_value = proj_value.astype(np.float32)
  69. print("proj_value:", proj_value.shape)
  70. out = tf.matmul(proj_value, attention)
  72. out = sess.run(out)
  73. #plt.imshow(out)
  74. print("out1:", out)
  75. out = out.reshape(m_batchsize, width * height, C)
  76. #out1 = out.reshape(m_batchsize, C, height, width)
  77. print("out2:", out.shape)
  79. out = gamma * out + x
  80. #out = sess.run(out)
  81. #out = out.astype(np.int16)
  82. print("out3:", out)
  1. import tensorflow as tf
  2. import numpy as np
  3. import math
  4. import keras
  5. from keras.layers import Conv2D,Reshape,Input
  6. from keras.regularizers import l2
  7. from keras.layers.advanced_activations import ELU, LeakyReLU
  8. from keras import Model
  9. import cv2
"""
Important:

1. A is CxHxW => Conv+BN+ReLU => B and C, both CxHxW
2. Reshape B and C to CxN (N = HxW)
3. Transpose B to B'
4. Softmax(Matmul(B', C)) => spatial attention map S of size NxN (HWxHW)
5. As in Eq. 1, where s_ji measures the influence of the i-th position on the j-th position
6. That is, the degree of association/correlation between positions i and j; the larger the value, the more similar they are.
7. A => Conv+BN+ReLU => D of size CxHxW => reshape to CxN
8. Matmul(D, S') => CxHxW, denoted DS here
9. Element-wise sum(scale parameter alpha * DS, A) => the final output E of size CxHxW (Eq. 2)
10. alpha is initialized as 0 and gradually learns to assign more weight.
"""
  26. """
  27. inputs :
  28. x : input feature maps( N X C X H X W)
  29. returns :
  30. out : attention value + input feature
  31. attention: N X (HxW) X (HxW)
  32. """
  33. """ Position attention module"""
  34. if __name__ == '__main__':
  35. #x = tf.random_uniform([2, 7, 7, 3],minval=0,maxval=255,dtype=tf.float32)
  36. file = tf.read_file('img.jpg')
  37. x = tf.image.decode_jpeg(file)
  38. #x = cv2.imread('ROIVIA3.jpg')
  39. print(x)
  40. gamma = 0.05
  41. sess = tf.Session()
  42. x1 = sess.run(x)
  43. x1 = tf.expand_dims(x1, axis=0)
  44. print(x1.shape)
  45. in_dim = 3
  47. xlen = x1.shape[1]
  48. ylen = x1.shape[2]
  49. input = Input(shape=(xlen,ylen,3))
  50. query_conv = Conv2D(1, (1,1), activation='relu',kernel_initializer='he_normal')(input)
  51. key_conv = Conv2D(1, (1, 1), activation='relu', kernel_initializer='he_normal')(input)
  52. value_conv = Conv2D(3, (1, 1), activation='relu', kernel_initializer='he_normal')(input)
  53. print(query_conv)
  55. batchsize, height, width, C = x1.shape
  56. #print(C, height, width )
  57. # B => N, C, HW
  58. proj_query = Reshape(( width * height ,1))(query_conv)
  59. proj_key = Reshape(( width * height, 1))(key_conv)
  60. proj_value = Reshape((width * height, 3))(value_conv)
  61. print("proj_query:",proj_query)
  62. print("proj_key:", proj_key)
  63. print("proj_value:",proj_value.shape)
  64. model = Model(inputs=[input],outputs=[proj_query])
  65. model.compile(optimizer='adam',loss='binary_crossentropy')
  66. proj_query = model.predict(x1,steps=1)
  67. print("proj_query:",proj_query)
  68. # B' => N, HW, C
  69. proj_query = proj_query.transpose(0, 2, 1)
  70. print("proj_query2:", proj_query.shape)
  71. print("proj_query2:", type(proj_query))
  72. # C => N, C, HW
  73. model1 = Model(inputs=[input], outputs=[proj_key])
  74. model1.compile(optimizer='adam', loss='binary_crossentropy')
  75. proj_key = model1.predict(x1, steps=1)
  76. print("proj_key:", proj_key.shape)
  78. print(proj_key)
  79. # B'xC => N, HW, HW
  80. energy = tf.matmul(proj_key, proj_query)
  81. print("energy:",energy.shape)
  83. # S = softmax(B'xC) => N, HW, HW
  84. attention = tf.nn.softmax(energy, axis=-1)
  85. print("attention:", attention.shape)
  87. # D => N, C, HW
  88. model2 = Model(inputs=[input], outputs=[proj_value])
  89. model2.compile(optimizer='adam', loss='binary_crossentropy')
  90. proj_value = model2.predict(x1, steps=1)
  91. print("proj_value:",proj_value.shape)
  93. # DxS' => N, C, HW
  94. out = tf.matmul(proj_value, sess.run(attention).transpose(0, 2, 1))
  95. print("out:", out.shape)
  97. # N, C, H, W
  98. out = Reshape((height, width, 3))(out)
  99. print("out1:", out.shape)
  101. out = gamma * out + sess.run(x1)
  102. print("out2:", type(out))


