图像超分辨重构的原理,输入一张像素点少,像素较低的图像, 输出一张像素点多,像素较高的图像

而在作者的文章中,作者使用downsample_up, 使用imresize(img, []) 将图像的像素从原理的384,384降低到96, 96, 从而构造出高水平的图像和低水平的图像

作者使用了三个部分构成网络,

第一部分是生成网络,用于进行图片的生成,使用了16层的残差网络,最后的输出结果为tf.nn.tanh(),即为-1, 1, 因为图像进行了-1,1的预处理

第二部分是判别网络, 用于进行图片的判别操作,对于判别网络而言,是希望将生成的图片判别为假,将真的图片判别为真

第三部分是VGG19来提取生成图片和真实图片的conv5层卷积层的输出结果,用于生成局部部位的损失值mse

损失值说明:

d_loss:

d_loss_1: tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))  # 真实图像的判别结果的损失值

d_loss_2: tl.cost.sigmoid_cross_entrpopy(logits_fake, tf.zeros_like(logits_real)) # 生成图像的判别结果的损失值

g_loss:

g_gan_loss: 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_real))  # 损失值表示为 -log(D(g(lr))) # 即生成的图像被判别为真的损失值

mse_loss: tl.cost.mean_squared_error(net_g.outputs, t_target_image)  # 计算真实值与生成值之间的像素差

vgg_loss: tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs) # 用于计算生成图片和真实图片经过vgg19的卷积层后,特征图之间的差异,用来获得特征细节的差异性

训练说明:

首先进行100次迭代,用来优化生成网络,使用tf.train.AdamOptimer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_var)

等生成网络迭代好以后,开始迭代生成网络和判别网络,以及VGG19的损失值缩小

生成网络:使用了16个残差模块,在残差模块的输入与下一层的输出之间又进行一次残差直连

判别网络:使用的是feature_map递增的卷积层构造成的判别网路

代码说明:

第一步:将参数从config中导入到main.py

第二步:使用tl.file.exists_or_mkdir() 构造用于储存图片的文件夹,同时定义checkpoint的文件夹

第三步:使用sorted(tl.files.load_file_list) 生成图片的列表, 使用tl.vis.read_images() 进行图片的读入

第四步:构建模型的构架Model

第一步:定义输入参数t_image = tf.placeholder('float32', [batch_size, 96, 96, 3]), t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3])

第二步: 使用SGRAN_g 用来生成最终的生成网络,net_g, 输入参数为t_image, is_training, reuse

第三步: 使用SGRAN_d 用来生成判别网络,输出结果为net_d网络架构,logits_real, 输入参数为t_target_image, is_training, reuse, 同理输入t_image, 获得logits_fake

第四步: 使用net_g.print_params(False) 和 net_g.print_layers() 不打印参数,打印每一层

第五步:将net_g.outputs即生成的结果和t_target_image即目标图像的结果输入到Vgg_19_simple_api, 获得vgg_net, 以及conv第五层的输出结果

第一步:tf.image.resize_images()进行图片的维度变换,为了可以使得其能输入到VGG_19中

第二步:将变化了维度的t_target_image 输入到Vgg_19_simple_api, 获得net_vgg, 和 vgg_target_emb即第五层卷积的输出结果

第三步:将变化了维度的net_g.outputs 输入到Vgg_19_simple_api, 获得 vgg_pred_emb即第五层卷积的输出结果

第六步: 构造net_g_test = SGRAN_g(t_image, False, True) 用于进行训练中的测试图片

第五步:构造模型loss,还有trian_ops操作

第一步: loss的构造, d_loss 和 g_loss的构造

第一步: d_loss的构造, d_loss_1 + d_loss_2

第一步: d_loss_1: 构造真实图片的判别损失值,即tl.cost.softmax_cross_entropy(logits_real, tf.ones_like(logits_real))

第二步: d_loss_2: 构造生成图片的判别损失值, 即tl.cost.softmax_cross_entropy(logits_fake, tf.ones_like(logits_fake))

第二步: g_loss的构造,g_gan_loss, mse_loss, vgg_loss

第一步: g_gan_loss, 生成网络被判别网络判别为真的概率,使用tl.cost.softmax_cross_entropy(logits_fake, tf.ones_like(logits_fake))

第二步:mse_loss 生成图像与目标图像之间的像素点差值,使用tl.cost.mean_squared_error(t_target_image, net_g.outputs)

第三步:vgg_loss  将vgg_target_emb.outputs与vgg_pred_emb.outputs获得第五层卷积层输出的mse_loss

第二步:构造train_op,包括 g_optim_init用预训练, 构造g_optim, d_optim

第一步:g_var = tl.layers.get_variables_with_name(‘SGRAN_g') 生成网络的参数获得

第二步: d_var = tl.layers.get_variable_with_name('SGRAN_d') 判别网络的参数获得

第三步: 使用with tf.variable_scope('learning_rate'): 使用lr_v = tf.Variable(lr_init)

第四步:定义train_op, g_optim_init, g_optim, d_optim

第一步:构造g_optim_init 使用tf.train.Adaoptimer(lr_v, beta1=betal).minimize(mse_loss, var_list=g_var)

第二步:构造g_optim 使用tf.train.Adaoptimer(lr_v, beta1=betal).minimize(g_loss, var_list=g_var)

第三步:构造d_optim 使用tf.train.Adaoptimer(lr_v, beta1=betal).minimize(d_loss, var_list=d_var)

第六步:使用tl.files.load_and_assign_npz() 载入训练好的sess参数

第一步: 使用tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))

第二步: tl.layers.initialize_global_variables(sess)

第三步: 使用tf.file.load_and_assign_npz 进行g_net的参数下载, 否者就下载g_{}_init的参数下载

第四步:使用tf.file.load_and_assgin_npz进行d_net的参数下载

第七步:下载VGG网络,将其运用到net_vgg

第一步:使用np.load(path, encoding='latin1').item() 下载参数

第二步:循环sorted(npz.items()) 进行参数循环,将其添加到params

第三步:使用tl.files.assign_params(sess, params, net_vgg) 将参数运用到net_vgg

第八步:进行参数的训练操作

第一步:从图片中跳出一个batch_size的数据构成测试集

第一步: 使用tl.prepro.threading_data fn = crop_sub_imgs_fn, 使用crop进行裁剪操作

第二步: 使用tl.prepro.threading_data fn = downsample 使用imresize进行图片的维度压缩

第二步:进行预训练操作

第一步:循环迭代, 获得一个batch的数据,使用crop_sub_imgs_fn 和 downsample构造低水平的数据和高水平的数据

第二步:使用sess.run, g_optim_init进行图片的预训练

第三步:进行训练操作

第一步:循环迭代,获得一个batch的数据,使用crop_sub_imgs_fn 和 downsample构造低水平的数据和高水平的数据

第二步:使用sess.run, g_optim 和 d_optim 进行图片的训练操作

第九步:进行evaluate图片的测试阶段

第一步: 构造图片展示的文件夹,使用tf.files.exits_files_mkdir

第二步: 使用tl.files.load_file_list 和 tl.vis.read_images读入图片

第三步:根据索引选择一张图片,/127.5 - 1 进行归一化处理

第四步:使用tf.placeholder('float32', [1, None, None, 3]) 构造输入的t_image

第五步: 使用SGRAN_g(t_image, False, False) 构造net_g

第六步:使用tf.Session() 构造sess,使用tl.files.load_and_assign_npz下载训练好的sess, network=net_g

第七步:使用sess.run([net_g.outputs], feed_dict={t_image:[valid_lr_img]}) 获得图片

第八步:使用tl.vis.save_images(outputs[0])保存图片

第九步:使用scipy.misc.imresize() 将低像素的图片扩大为原来的四倍,与重构的图像作对比

代码: main.py  主函数

  1. import tensorlayer as tl
  2. import tensorflow as tf
  3. import numpy as np
  4. from config import config
  5. from model import *
  6. import os
  7. import time
  8. import scipy
  9.  
  10. ## 添加参数
  11. batch_size = config.TRAIN.batch_size
  12. lr_init = config.TRAIN.lr_init
  13. betal = config.TRAIN.betal
  14.  
  15. ### initialze G
  16. n_epoch_init = config.TRAIN.n_epoch_init
  17. ### adversarial learning
  18. n_epoch = config.TRAIN.n_epoch
  19. lr_decay = config.TRAIN.lr_decay
  20. decay_every = config.TRAIN.decay_every
  21.  
  22. ni = int(np.sqrt(batch_size))
  23.  
  24. def train():
  25.  
  26. # 创建用于进行图片储存的文件
  27. save_dir_ginit = 'sample/{}_ginit'.format(tl.global_flag['mode'])
  28. save_dir_gan = 'sample/{}_gan'.format(tl.global_flag['mode'])
  29. tl.files.exists_or_mkdir(save_dir_ginit)
  30. tl.files.exists_or_mkdir(save_dir_gan)
  31. checkpoint = 'checkpoint'
  32. tl.files.exists_or_mkdir(checkpoint)
  33.  
  34. train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
  35. train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
  36.  
  37. train_hr_img = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=8)
  38. train_lr_img = tl.vis.read_images(train_lr_img_list, path=config.TRAIN.lr_img_path, n_threads=8)
  39.  
  40. # 构造输入
  41. t_image = tf.placeholder('float32', [batch_size, 96, 96, 3])
  42. t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3])
  43. # 构造生成的model,获得生成model的输出net_g
  44. net_g = SRGAN_g(t_image, True, False)
  45. # 构造判别网络,判别net_g.output, t_target_image, net_d表示整个网络
  46. net_d, logist_real = SRGAN_d(t_target_image, True, False)
  47. _, logist_fake = SRGAN_d(net_g.outputs, True, True)
  48. # 构造VGG网络
  49.  
  50. net_g.print_params(False)
  51. net_g.print_layers()
  52. net_d.print_params(False)
  53. net_d.print_layers()
  54.  
  55. # 进行输入数据的维度变换,将其转换为224和224
  56. target_image_224 = tf.image.resize_images(t_target_image, [224, 224], method=0, align_corners=False)
  57. pred_image_224 = tf.image.resize_images(net_g.outputs, [224, 224], method=0, align_corners=False)
  58.  
  59. net_vgg, vgg_target_emb = Vgg_19_simple_api((target_image_224 + 1) / 2, reuse=False)
  60. _, vgg_pred_emb = Vgg_19_simple_api((net_g + 1) / 2, reuse=True)
  61. # 进行训练阶段的测试
  62. net_g_test = SRGAN_g(t_image, False, True)
  63.  
  64. #### ========== DEFINE_TRAIN_OP =================###
  65. d_loss_1 = tl.cost.sigmoid_cross_entropy(logist_real, tf.ones_like(logist_real))
  66. d_loss_2 = tl.cost.sigmoid_cross_entropy(logist_fake, tf.zeros_like(logist_fake))
  67. d_loss = d_loss_1 + d_loss_2
  68.  
  69. g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logist_fake, tf.ones_like(logist_fake))
  70. mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
  71. vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_target_emb.outputs, vgg_pred_emb.outputs, is_mean=True)
  72. g_loss = g_gan_loss + mse_loss + vgg_loss
  73.  
  74. g_var = tl.layers.get_variables_with_name('SRGAN_g', True, True)
  75. d_var = tl.layers.get_variables_with_name('SRGAN_d', True, True)
  76.  
  77. with tf.variable_scope('learning_rate'):
  78. lr_v = tf.Variable(lr_init, trainable=False)
  79.  
  80. g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=betal).minimize(mse_loss, var_list=g_var)
  81. g_optim = tf.train.AdamOptimizer(lr_v, beta1=betal).minimize(g_loss, var_list=g_var)
  82. d_optim = tf.train.AdamOptimizer(lr_v, beta1=betal).minimize(d_loss, var_list=d_var)
  83.  
  84. ###======================RESTORE_MODEL_SESS ==================###
  85. sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
  86. tl.layers.initialize_global_variables(sess)
  87. if tl.files.load_and_assign_npz(sess, checkpoint + '/g_{}.npz'.format(tl.global_flag['mode'], network=net_g)) is False:
  88. tl.files.load_and_assign_npz(sess, checkpoint + '/g_init_{}.npz'.format(tl.global_flag['mode'], network=net_g))
  89. tl.files.load_and_assign_npz(sess, checkpoint + '/d_{}.npz'.format(tl.global_flag['mode'], network=net_d))
  90.  
  91. ### ================== load vgg params =================== ###
  92. vgg_npy_path = 'vgg19.npy'
  93. if not os.path.isfile(vgg_npy_path):
  94. print('Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg')
  95. exit()
  96.  
  97. npz = np.load(vgg_npy_path, encoding='latin1').item()
  98. params = []
  99. for var in sorted(npz.items()):
  100. W = np.asarray(var[1][0])
  101. b = np.asarray(var[1][1])
  102. params.extend([W, b])
  103.  
  104. tl.files.assign_params(sess, params, net_vgg)
  105.  
  106. print('ok')
  107.  
  108. ###======================== TRAIN =======================###
  109. sample_imgs = train_hr_img[0:batch_size]
  110. # 进行随机裁剪,保证其维度为384
  111. sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
  112. # 进行像素的降低
  113. sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
  114. # 进行图片的保存
  115. tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
  116. tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
  117. tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
  118. tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')
  119.  
  120. ###======================== initial train G =====================###
  121.  
  122. for epoch in range(n_epoch_init):
  123.  
  124. n_iter = 0
  125. init_loss_total = 0
  126.  
  127. for idx in range(0, len(train_hr_img), batch_size):
  128.  
  129. b_img_384 = tl.prepro.threading_data(train_hr_img[idx:idx+batch_size], fn=crop_sub_imgs_fn, is_random=False)
  130. b_img_96 = tl.prepro.threading_data(b_img_384, fn=downsample_fn)
  131.  
  132. _, MSE_LOSS = sess.run([g_optim_init, mse_loss], feed_dict={t_image:b_img_96, t_target_image:b_img_384})
  133.  
  134. init_loss_total += MSE_LOSS
  135.  
  136. if (epoch != 0) and (epoch % 10 == 0):
  137. out = sess.run(net_g_test.outputs, feed_dict={t_image:sample_imgs_96})
  138. print('[*] save image')
  139. tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)
  140.  
  141. if (epoch != 0) and (epoch % 10 ==0):
  142.  
  143. tl.files.save_npz(net_g.all_params, name=checkpoint + '/g_init_{}.npz'.format(tl.global_flag['mode']))
  144.  
  145. ### ======================== train GAN ================== ###
  146.  
  147. for epoch in range(0, n_epoch+1):
  148.  
  149. if epoch != 0 and epoch % decay_every == 0:
  150. new_lr = lr_decay ** (epoch // decay_every)
  151. sess.run(tf.assign(lr_v, new_lr * lr_v))
  152. log = '** new learning rate: %f(for GAN)' % (lr_init * new_lr)
  153. print(log)
  154.  
  155. elif epoch == 0:
  156. sess.run(tf.assign(lr_v, lr_init))
  157. log = '** init lr: %f decay_every_init: %d, lr_decay: %f(for GAN)'%(lr_init, decay_every, lr_decay)
  158. print(log)
  159.  
  160. epoch_time = time.time()
  161. total_d_loss, total_g_loss, n_iter = 0, 0, 0
  162.  
  163. for idx in range(0, len(train_hr_img), batch_size):
  164.  
  165. b_img_384 = tl.prepro.threading_data(train_hr_img[idx:idx+batch_size], fn=crop_sub_imgs_fn, is_random=False)
  166. b_img_96 = tl.prepro.threading_data(b_img_384, fn=downsample_fn)
  167.  
  168. _, errD = sess.run([d_optim, d_loss], feed_dict={t_image:b_img_96, t_target_image:b_img_384})
  169. _, errG, errM, errV, errA = sess.run([g_optim, g_loss, mse_loss, vgg_loss, g_gan_loss], feed_dict={t_image:b_img_96, t_target_image:b_img_384})
  170.  
  171. total_d_loss += errD
  172. total_g_loss += errG
  173.  
  174. if epoch != 0 and epoch % 10 == 0:
  175. out = sess.run(net_g_test.outputs, feed_dict={t_image:sample_imgs_96})
  176. print('[*] save image')
  177. tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d' % epoch)
  178.  
  179. if epoch != 0 and epoch % 10 == 0:
  180.  
  181. tl.files.save_npz(net_g.all_params, name = checkpoint + '/g_{}.npz'.format(tl.global_flag['mode']))
  182. tl.files.save_npz(net_d.all_params, name= checkpoint + '/d_{}.npz'.format(tl.global_flag['mode']))
  183.  
  184. def evaluate():
  185.  
  186. save_dir = 'sample/{}'.format(tl.global_flag['mode'])
  187. tl.files.exists_or_mkdir(save_dir)
  188. checkpoints = 'checkpoints'
  189.  
  190. evaluate_hr_img_list = sorted(tl.files.load_file_list(config.VALID.hr_img_path, regx='.*.png', printable=False))
  191. evaluate_lr_img_list = sorted(tl.files.load_file_list(config.VALID.lr_img_path, regx='.*.png', printable=False))
  192.  
  193. valid_lr_imgs = tl.vis.read_images(evaluate_lr_img_list, path=config.VALID.lr_img_path, n_threads=8)
  194. valid_hr_imgs = tl.vis.read_images(evaluate_hr_img_list, path=config.VALID.hr_img_path, n_threads=8)
  195.  
  196. ### ==================== DEFINE MODEL =================###
  197. imid = 64
  198. valid_lr_img = valid_lr_imgs[imid]
  199. valid_hr_img = valid_hr_imgs[imid]
  200.  
  201. valid_lr_img = (valid_lr_img / 127.5) - 1
  202.  
  203. t_image = tf.placeholder('float32', [1, None, None, 3])
  204. net_g = SGRAN_g(t_image, False, False)
  205.  
  206. sess = tf.Session()
  207. tl.files.load_and_assign_npz(sess, checkpoints + '/g_{}.npz'.format(tl.global_flag['mode']), network=net_g)
  208.  
  209. output = sess.run([net_g.outputs], feed_dict={t_image:[valid_lr_img]})
  210.  
  211. tl.vis.save_images(output[0], [ni, ni], save_dir + '/valid_gen.png')
  212. tl.vis.save_images(valid_lr_img, [ni, ni], save_dir + '/valid_lr.png')
  213. tl.vis.save_images(valid_hr_img, [ni, ni], save_dir + '/valid_hr.png')
  214.  
  215. size = valid_hr_img.shape
  216. out_bicu = scipy.misc.imresize(valid_lr_img, [size[0]*4, size[1]*4], interp='bicubic', mode=None)
  217. tl.vis.save_images(out_bicu, [ni, ni], save_dir + '/valid_out_bicu.png')
  218.  
  219. if __name__ == '__main__':
  220. import argparse
  221. parse = argparse.ArgumentParser()
  222. parse.add_argument('--mode', type=str, default='srgan', help='srgan evaluate')
  223. args = parse.parse_args()
  224.  
  225. tl.global_flag['mode'] = args.mode
  226. if tl.global_flag['mode'] == 'srgan':
  227. train()
  228.  
  229. elif tl.global_flag['mode'] == 'evaluate':
  230. evaluate()

model.py 构建模型

  1. import tensorflow as tf
  2. import tensorlayer as tl
  3. from tensorlayer.layers import *
  4. import time
  5.  
  6. def SRGAN_g(input_image, is_train, reuse):
  7.  
  8. w_init = tf.random_normal_initializer(stddev=0.2)
  9. b_init = None
  10. g_init = tf.random_normal_initializer(1, 0.02)
  11.  
  12. with tf.variable_scope('SRGAN_g', reuse=reuse):
  13. n = InputLayer(input_image, name='in')
  14. n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
  15. temp = n
  16. for i in range(16):
  17. nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n64s1/c1/%d' % i)
  18. nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%d' % i)
  19. nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n64s1/c2/%d' % i)
  20. nn = BatchNormLayer(nn, act=None, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%d'%i)
  21. nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add_%d' % i)
  22. n = nn
  23.  
  24. n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n64s1/c3')
  25. n = BatchNormLayer(n, act=None, is_train=is_train, gamma_init=g_init, name='n64s1/b3')
  26. n = ElementwiseLayer([temp, n], tf.add, name='add3')
  27.  
  28. # 进行反卷积操作
  29. n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n64s1/c4')
  30. n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshuffler2/1')
  31.  
  32. n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n64s1/c5')
  33. n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshuffle2/2')
  34.  
  35. n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
  36.  
  37. return n
  38.  
  39. def SRGAN_d(input_image, is_training=True, reuse=False):
  40.  
  41. w_init = tf.random_normal_initializer(stddev=0.2)
  42. b_init = None
  43. g_init = tf.random_normal_initializer(1.0, stddev=0.02)
  44. lrelu = lambda x: tl.act.lrelu(x, 0.2)
  45. df_dim = 64
  46. with tf.variable_scope('SRGAN_d', reuse=reuse):
  47. tl.layers.set_name_reuse(reuse)
  48. net_in = InputLayer(input_image, name='input/image')
  49. net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lrelu, padding='SAME', W_init=w_init, name='h0/c')
  50.  
  51. net_h1 = Conv2d(net_h0, df_dim*2, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, name='h1/c')
  52. net_h1 = BatchNormLayer(net_h1, act=lrelu, is_train=is_training, gamma_init=g_init, name='h1/bn')
  53. net_h2 = Conv2d(net_h1, df_dim*4, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, name='h2/c')
  54. net_h2 = BatchNormLayer(net_h2, act=lrelu, is_train=is_training, gamma_init=g_init, name='h2/bn')
  55. net_h3 = Conv2d(net_h2, df_dim*8, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, name='h3/c')
  56. net_h3 = BatchNormLayer(net_h3, act=lrelu, is_train=is_training, gamma_init=g_init, name='h3/bn')
  57. net_h4 = Conv2d(net_h3, df_dim*16, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, name='h4/c')
  58. net_h4 = BatchNormLayer(net_h4, act=lrelu, is_train=is_training, gamma_init=g_init, name='h4/bn')
  59. net_h5 = Conv2d(net_h4, df_dim*32, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, name='h5/c')
  60. net_h5 = BatchNormLayer(net_h5, act=lrelu, is_train=is_training, gamma_init=g_init, name='h5/bn')
  61. net_h6 = Conv2d(net_h5, df_dim*16, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, name='h6/c')
  62. net_h6 = BatchNormLayer(net_h6, act=lrelu, is_train=is_training, gamma_init=g_init, name='h6/bn')
  63. net_h7 = Conv2d(net_h6, df_dim*8, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, name='h7/c')
  64. net_h7 = BatchNormLayer(net_h7, act=lrelu, is_train=is_training, gamma_init=g_init, name='h7/bn')
  65.  
  66. net = Conv2d(net_h7, df_dim*2, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, name='reg/c')
  67. net = BatchNormLayer(net, act=lrelu, is_train=is_training, gamma_init=g_init, name='reg/bn')
  68. net = Conv2d(net, df_dim*2, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='reg/c2')
  69. net = BatchNormLayer(net, act=lrelu, is_train=is_training, gamma_init=g_init, name='reg/bn2')
  70. net = Conv2d(net, df_dim*8, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='reg/c3')
  71. net = BatchNormLayer(net, act=lrelu, is_train=is_training, gamma_init=g_init, name='reg/bn3')
  72.  
  73. net_h8 = ElementwiseLayer([net_h7, net], tf.add, name='red/add')
  74. net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2)
  75.  
  76. net_ho = FlattenLayer(net_h8, name='ho/flatten')
  77. net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity, W_init=w_init, name='ho/dense')
  78. logits = net_ho.outputs
  79. net_ho.outputs = tf.nn.sigmoid(net_ho.outputs)
  80.  
  81. return net_ho, logits
  82.  
  83. def Vgg_19_simple_api(input_image, reuse):
  84.  
  85. VGG_MEAN = [103.939, 116.779, 123.68]
  86. # 将输入的rgb图像转换为bgr
  87. with tf.variable_scope('VGG19', reuse=reuse) as vs:
  88. start_time = time.time()
  89. print('build the model')
  90. input_image = input_image * 255
  91. red, green, blue = tf.split(input_image, 3, 3)
  92. assert red.get_shape().as_list()[1:] == [224, 224, 1]
  93. assert green.get_shape().as_list()[1:] == [224, 224, 1]
  94. assert blue.get_shape().as_list()[1:] == [224, 224, 1]
  95.  
  96. bgr = tf.concat([blue-VGG_MEAN[0], green-VGG_MEAN[1], red-VGG_MEAN[2]], axis=3)
  97. assert input_image.get_shape().as_list()[1:] == [224, 224, 3]
  98.  
  99. net_in = InputLayer(bgr, name='input')
  100. # 构建网络
  101. """conv1"""
  102. network = Conv2d(net_in, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv1_1')
  103. network = Conv2d(network, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv1_2')
  104. network = MaxPool2d(network, (2, 2), (2, 2), padding='SAME', name='pool1')
  105. '''conv2'''
  106. network = Conv2d(network, 128, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv2_1')
  107. network = Conv2d(network, 128, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv2_2')
  108. network = MaxPool2d(network, (2, 2), (2, 2), padding='SAME', name='pool2')
  109. '''conv3'''
  110. network = Conv2d(network, 256, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv3_1')
  111. network = Conv2d(network, 256, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv3_2')
  112. network = Conv2d(network, 256, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv3_3')
  113. network = Conv2d(network, 256, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv3_4')
  114. network = MaxPool2d(network, (2, 2), (2, 2), padding='SAME', name='pool3')
  115. '''conv4'''
  116. network = Conv2d(network, 512, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv4_1')
  117. network = Conv2d(network, 512, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv4_2')
  118. network = Conv2d(network, 512, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv4_3')
  119. network = Conv2d(network, 512, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv4_4')
  120. network = MaxPool2d(network, (2, 2), (2, 2), padding='SAME', name='pool4')
  121. '''conv5'''
  122. network = Conv2d(network, 512, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv5_1')
  123. network = Conv2d(network, 512, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv5_2')
  124. network = Conv2d(network, 512, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv5_3')
  125. network = Conv2d(network, 512, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='conv5_4')
  126. network = MaxPool2d(network, (2, 2), (2, 2), padding='SAME', name='pool5')
  127. conv = network
  128. """fc6-8"""
  129. network = FlattenLayer(network, name='flatten')
  130. network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc6')
  131. network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc7')
  132. network = DenseLayer(network, n_units=1000, act=tf.identity, name='fc8')
  133. print('finish the bulid %fs' % (time.time() - start_time))
  134. return network, conv

config.py  参数文件

  1. from easydict import EasyDict as edict
  2. import json
  3.  
  4. config = edict()
  5.  
  6. config.TRAIN = edict()
  7.  
  8. # Adam
  9. config.TRAIN.batch_size = 1
  10. config.TRAIN.lr_init = 1e-4
  11. config.TRAIN.betal = 0.9
  12.  
  13. ### initialize G
  14. config.TRAIN.n_epoch_init = 100
  15.  
  16. ### adversarial_leaning
  17. config.TRAIN.n_epoch = 2000
  18. config.TRAIN.lr_decay = 0.1
  19. config.TRAIN.decay_every = int(config.TRAIN.n_epoch / 2)
  20.  
  21. ## train set location
  22. config.TRAIN.hr_img_path = r'C:\Users\qq302\Desktop\srdata\DIV2K_train_HR'
  23. config.TRAIN.lr_img_path = r'C:\Users\qq302\Desktop\srdata\DIV2K_train_LR_bicubic\X4'
  24.  
  25. # valid set location
  26.  
  27. config.VALID = edict()
  28.  
  29. config.VALID.hr_img_path = r'C:\Users\qq302\Desktop\srdata\DIV2K_valid_HR'
  30. config.VALID.lr_img_path = r'C:\Users\qq302\Desktop\srdata\DIV2K_valid_LR_bicubic/X4'

utils.py  操作文件

  1. from tensorlayer.prepro import *
  2.  
  3. def crop_sub_imgs_fn(img, is_random=True):
  4.  
  5. x = crop(img, wrg=384, hrg=384, is_random=is_random)
  6. # 进行 -1 - 1 的归一化
  7. x = x / 127.5 - 1
  8. return x
  9.  
  10. def downsample_fn(img):
  11.  
  12. x = imresize(img, [96, 96], interp='bicubic', mode=None)
  13. # 存在一定的问题
  14. x = x / 127.5 - 1
  15. return x

深度原理与框架-图像超分辨重构-tensorlayer的更多相关文章

  1. 深度学习原理与框架-图像补全(原理与代码) 1.tf.nn.moments(求平均值和标准差) 2.tf.control_dependencies(先执行内部操作) 3.tf.cond(判别执行前或后函数) 4.tf.nn.atrous_conv2d 5.tf.nn.conv2d_transpose(反卷积) 7.tf.train.get_checkpoint_state(判断sess是否存在

    1. tf.nn.moments(x, axes=[0, 1, 2])  # 对前三个维度求平均值和标准差,结果为最后一个维度,即对每个feature_map求平均值和标准差 参数说明:x为输入的fe ...

  2. 图像超分辨-DBPN

    本文译自2018CVPR DeepBack-Projection Networks For Super-Resolution 代码: github 特点:不同于feedback net,引入back ...

  3. 人工智能范畴及深度学习主流框架,IBM Watson认知计算领域IntelligentBehavior介绍

    人工智能范畴及深度学习主流框架,IBM Watson认知计算领域IntelligentBehavior介绍 工业机器人,家用机器人这些只是人工智能的一个细分应用而已.图像识别,语音识别,推荐算法,NL ...

  4. 人工智能深度学习Caffe框架介绍,优秀的深度学习架构

    人工智能深度学习Caffe框架介绍,优秀的深度学习架构 在深度学习领域,Caffe框架是人们无法绕过的一座山.这不仅是因为它无论在结构.性能上,还是在代码质量上,都称得上一款十分出色的开源框架.更重要 ...

  5. 从Theano到Lasagne:基于Python的深度学习的框架和库

    从Theano到Lasagne:基于Python的深度学习的框架和库 摘要:最近,深度神经网络以“Deep Dreams”形式在网站中如雨后春笋般出现,或是像谷歌研究原创论文中描述的那样:Incept ...

  6. 人工智能范畴及深度学习主流框架,谷歌 TensorFlow,IBM Watson认知计算领域IntelligentBehavior介绍

    人工智能范畴及深度学习主流框架,谷歌 TensorFlow,IBM Watson认知计算领域IntelligentBehavior介绍 ================================ ...

  7. atitit.http get post的原理以及框架实现java php

    atitit.http get post的原理以及框架实现java php 1. 相关的设置 1 1.1. urlencode 1 1.2. 输出流的编码 1 1.3. 图片,文件的post 1 2. ...

  8. Dubbo原理与框架设计

    Dubbo是常用的开源服务治理型RPC框架,在之前osgi框架下不同bundle之间的方法调用时用到过.其工作原理和框架设计值得开源技术爱好者学习和研究. 一.Dubbo的工作原理 调用关系说明 服务 ...

  9. 深度学习Keras框架笔记之AutoEncoder类

    深度学习Keras框架笔记之AutoEncoder类使用笔记 keras.layers.core.AutoEncoder(encoder, decoder,output_reconstruction= ...

随机推荐

  1. DataGrip for Mac破解步骤详解 亲测好用

    https://blog.csdn.net/le945926/article/details/81912085

  2. logback root level logger level 日志级别覆盖?继承?

    1. logback-spring.xml 配置 <appender name="STDOUT" class="ch.qos.logback.core.Consol ...

  3. FPGA Asynchronous FIFO设计思路

    FPGA Asynchronous FIFO设计思路 将一个多位宽,且在不停变化的数据从一个时钟域传递到另一个时钟域是比较困难的. 同步FIFO的指针比较好确定,当FIFO counter达到上限值时 ...

  4. CURL 支持 GET、PUT、POST、DELETE请求

    一个方法解决所有的 curl 请求的问题. <?php function curlTypeData( $method, $url, $data=false, $json=false ) { $d ...

  5. javascript+html5+canvse+3d俄罗斯方块

    javascript+html5+canvse+3d俄罗斯方块 必须使用支持html5的浏览器打开,比如firefox,chrome 得分:0速度:1000 // 你的浏览器不支持 <canva ...

  6. ECMA6 New Features

    花了一些时间把ECMA6的新特性进行了回顾,给自己建立了思维索引,大部分内容借鉴了阮一峰大神的博客. refers: http://es6.ruanyifeng.com/#docs/arraybuff ...

  7. OpenLayers加载谷歌地图服务

    谷歌地图的地址如下: 谷歌交通地图地址:http://www.google.cn/maps/vt/pb=!1m4!1m3!1i{z}!2i{x}!3i{y}!2m3!1e0!2sm!3i3800725 ...

  8. 负载均衡中的session保持

    什么叫负载均衡中的session保持 当我们需要做负载均衡时,服务端肯定有多台服务器,用户每次请求进来,会根据负载均衡算法被分配到某一台机器上,假设用户需要进行一段连续操作时,在第一台机器登陆后,下一 ...

  9. 小程序中添加客服按钮contact-button

    小程序的客服系统,是微信做的非常成功的一个功能,开发者可以很方便的通过一行代码,就可实现客服功能. 1. 普通客服按钮添加 <button open-type='contact' session ...

  10. 存储过程中调用webservice

    存储过程中调用webservice其实是在数据库中利用系统函数调用OLE. 1.查找webservice api 可得到MSSOAP.SoapClient. 2.查找API 接口可得到mssoapin ...