风格迁移: 在内容上尽量与基准图像保持一致,在风格上尽量与风格图像保持一致。

  • 1. 使用预训练的VGG19网络提取特征
  • 2. 损失函数之一是“内容损失”(content loss),代表合成的图像的特征与基准图像的特征之间的L2距离,保证生成的图像内容和基准图像保持一致。
  • 3. 损失函数之二是“风格损失”(style loss),代表合成图像的特征与风格图像的特征之间的Gram矩阵之间的差异,保证生成图像的风格和风格图像保持一致。
  • 4. 损失函数之三是“差异损失”(variation loss),代表合成的图像局部特征之间的差异,保证生成的图像局部特征的一致性,整体看上去自然不突兀。

基于keras的代码实现:

  1. # coding: utf-8
  2. from __future__ import print_function
  3. from keras.preprocessing.image import load_img, img_to_array
  4. import numpy as np
  5. from scipy.optimize import fmin_l_bfgs_b
  6. import time
  7. import argparse
  8. from scipy.misc import imsave
  9. from keras.applications import vgg19
  10. from keras import backend as K
  11. import os
  12. from PIL import Image, ImageFont, ImageDraw, ImageOps, ImageEnhance, ImageFilter
  13. # 输入参数
  14. parser = argparse.ArgumentParser(description='基于Keras的图像风格迁移.') # 解析器
  15. parser.add_argument('--style_reference_image_path', metavar='ref', type=str,default = './style.jpg',
  16. help='目标风格图片的位置')
  17. parser.add_argument('--base_image_path', metavar='ref', type=str,default = './base.jpg',
  18. help='基准图片的位置')
  19. parser.add_argument('--iter', type=int, default=25, required=False,
  20. help='迭代次数')
  21. parser.add_argument('--pictrue_size', type=int, default=500, required=False,
  22. help='图片大小.')
  23. # 获取参数
  24. args = parser.parse_args()
  25. base_image_path = args.base_image_path
  26. style_reference_image_path = args.style_reference_image_path
  27. iterations = args.iter
  28. pictrue_size = args.pictrue_size
  29. source_image = Image.open(base_image_path)
  30. source_image= source_image.resize((pictrue_size, pictrue_size))
  31. width, height = pictrue_size, pictrue_size
  32. def save_img(fname, image, image_enhance=True): # 图像增强
  33. image = Image.fromarray(image)
  34. if image_enhance:
  35. # 亮度增强
  36. enh_bri = ImageEnhance.Brightness(image)
  37. brightness = 1.2
  38. image = enh_bri.enhance(brightness)
  39. # 色度增强
  40. enh_col = ImageEnhance.Color(image)
  41. color = 1.2
  42. image = enh_col.enhance(color)
  43. # 锐度增强
  44. enh_sha = ImageEnhance.Sharpness(image)
  45. sharpness = 1.2
  46. image = enh_sha.enhance(sharpness)
  47. imsave(fname, image)
  48. return
  49. # util function to resize and format pictures into appropriate tensors
  50. def preprocess_image(image):
  51. """
  52. 预处理图片,包括变形到(1,width, height)形状,数据归一到0-1之间
  53. :param image: 输入一张图片
  54. :return: 预处理好的图片
  55. """
  56. image = image.resize((width, height))
  57. image = img_to_array(image)
  58. image = np.expand_dims(image, axis=0) # (width, height)->(1,width, height)
  59. image = vgg19.preprocess_input(image) # 0-255 -> 0-1.0
  60. return image
  61. def deprocess_image(x):
  62. """
  63. 将0-1之间的数据变成图片的形式返回
  64. :param x: 数据在0-1之间的矩阵
  65. :return: 图片,数据都在0-255之间
  66. """
  67. x = x.reshape((width, height, 3))
  68. x[:, :, 0] += 103.939
  69. x[:, :, 1] += 116.779
  70. x[:, :, 2] += 123.68
  71. # 'BGR'->'RGB'
  72. x = x[:, :, ::-1]
  73. x = np.clip(x, 0, 255).astype('uint8') # 以防溢出255范围
  74. return x
  75. def gram_matrix(x): # Gram矩阵
  76. assert K.ndim(x) == 3
  77. if K.image_data_format() == 'channels_first':
  78. features = K.batch_flatten(x)
  79. else:
  80. features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))
  81. gram = K.dot(features, K.transpose(features))
  82. return gram
  83. # 风格损失,是风格图片与结果图片的Gram矩阵之差,并对所有元素求和
  84. def style_loss(style, combination):
  85. assert K.ndim(style) == 3
  86. assert K.ndim(combination) == 3
  87. S = gram_matrix(style)
  88. C = gram_matrix(combination)
  89. S_C = S-C
  90. channels = 3
  91. size = height * width
  92. return K.sum(K.square(S_C)) / (4. * (channels ** 2) * (size ** 2))
  93. #return K.sum(K.pow(S_C,4)) / (4. * (channels ** 2) * (size ** 2)) # 居然和平方没有什么不同
  94. #return K.sum(K.pow(S_C,4)+K.pow(S_C,2)) / (4. * (channels ** 2) * (size ** 2)) # 也能用,花后面出现了叶子
  95. def eval_loss_and_grads(x): # 输入x,输出对应于x的梯度和loss
  96. if K.image_data_format() == 'channels_first':
  97. x = x.reshape((1, 3, height, width))
  98. else:
  99. x = x.reshape((1, height, width, 3))
  100. outs = f_outputs([x]) # 输入x,得到输出
  101. loss_value = outs[0]
  102. if len(outs[1:]) == 1:
  103. grad_values = outs[1].flatten().astype('float64')
  104. else:
  105. grad_values = np.array(outs[1:]).flatten().astype('float64')
  106. return loss_value, grad_values
  107. # an auxiliary loss function
  108. # designed to maintain the "content" of the
  109. # base image in the generated image
  110. def content_loss(base, combination):
  111. return K.sum(K.square(combination - base))
  112. # the 3rd loss function, total variation loss,
  113. # designed to keep the generated image locally coherent
  114. def total_variation_loss(x,img_nrows=width, img_ncols=height):
  115. assert K.ndim(x) == 4
  116. if K.image_data_format() == 'channels_first':
  117. a = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, 1:, :img_ncols - 1])
  118. b = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, :img_nrows - 1, 1:])
  119. else:
  120. a = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
  121. b = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
  122. return K.sum(K.pow(a + b, 1.25))
  123. # Evaluator可以只需要进行一次计算就能得到所有的梯度和loss
  124. class Evaluator(object):
  125. def __init__(self):
  126. self.loss_value = None
  127. self.grads_values = None
  128. def loss(self, x):
  129. assert self.loss_value is None
  130. loss_value, grad_values = eval_loss_and_grads(x)
  131. self.loss_value = loss_value
  132. self.grad_values = grad_values
  133. return self.loss_value
  134. def grads(self, x):
  135. assert self.loss_value is not None
  136. grad_values = np.copy(self.grad_values)
  137. self.loss_value = None
  138. self.grad_values = None
  139. return grad_values
  140. # 得到需要处理的数据,处理为keras的变量(tensor),处理为一个(3, width, height, 3)的矩阵
  141. # 分别是基准图片,风格图片,结果图片
  142. base_image = K.variable(preprocess_image(source_image)) # 基准图像
  143. style_reference_image = K.variable(preprocess_image(load_img(style_reference_image_path)))
  144. if K.image_data_format() == 'channels_first':
  145. combination_image = K.placeholder((1, 3, width, height))
  146. else:
  147. combination_image = K.placeholder((1, width, height, 3))
  148. # 组合以上3张图片,作为一个keras输入向量
  149. input_tensor = K.concatenate([base_image, style_reference_image, combination_image], axis=0) #组合
  150. # 使用Keras提供的训练好的Vgg19网络,不带3个全连接层
  151. model = vgg19.VGG19(input_tensor=input_tensor,weights='imagenet', include_top=False)
  152. model.summary() # 打印出模型概况
  153. '''
  154. Layer (type) Output Shape Param #
  155. =================================================================
  156. input_1 (InputLayer) (None, None, None, 3) 0
  157. _________________________________________________________________
  158. block1_conv1 (Conv2D) (None, None, None, 64) 1792 A
  159. _________________________________________________________________
  160. block1_conv2 (Conv2D) (None, None, None, 64) 36928
  161. _________________________________________________________________
  162. block1_pool (MaxPooling2D) (None, None, None, 64) 0
  163. _________________________________________________________________
  164. block2_conv1 (Conv2D) (None, None, None, 128) 73856 B
  165. _________________________________________________________________
  166. block2_conv2 (Conv2D) (None, None, None, 128) 147584
  167. _________________________________________________________________
  168. block2_pool (MaxPooling2D) (None, None, None, 128) 0
  169. _________________________________________________________________
  170. block3_conv1 (Conv2D) (None, None, None, 256) 295168 C
  171. _________________________________________________________________
  172. block3_conv2 (Conv2D) (None, None, None, 256) 590080
  173. _________________________________________________________________
  174. block3_conv3 (Conv2D) (None, None, None, 256) 590080
  175. _________________________________________________________________
  176. block3_conv4 (Conv2D) (None, None, None, 256) 590080
  177. _________________________________________________________________
  178. block3_pool (MaxPooling2D) (None, None, None, 256) 0
  179. _________________________________________________________________
  180. block4_conv1 (Conv2D) (None, None, None, 512) 1180160 D
  181. _________________________________________________________________
  182. block4_conv2 (Conv2D) (None, None, None, 512) 2359808
  183. _________________________________________________________________
  184. block4_conv3 (Conv2D) (None, None, None, 512) 2359808
  185. _________________________________________________________________
  186. block4_conv4 (Conv2D) (None, None, None, 512) 2359808
  187. _________________________________________________________________
  188. block4_pool (MaxPooling2D) (None, None, None, 512) 0
  189. _________________________________________________________________
  190. block5_conv1 (Conv2D) (None, None, None, 512) 2359808 E
  191. _________________________________________________________________
  192. block5_conv2 (Conv2D) (None, None, None, 512) 2359808
  193. _________________________________________________________________
  194. block5_conv3 (Conv2D) (None, None, None, 512) 2359808
  195. _________________________________________________________________
  196. block5_conv4 (Conv2D) (None, None, None, 512) 2359808 F
  197. _________________________________________________________________
  198. block5_pool (MaxPooling2D) (None, None, None, 512) 0
  199. =================================================================
  200. '''
  201. # Vgg19网络中的不同的名字,储存起来以备使用
  202. outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])
  203. loss = K.variable(0.)
  204. layer_features = outputs_dict['block5_conv2']
  205. base_image_features = layer_features[0, :, :, :]
  206. combination_features = layer_features[2, :, :, :]
  207. content_weight = 0.08
  208. loss += content_weight * content_loss(base_image_features,
  209. combination_features)
  210. feature_layers = ['block1_conv1','block2_conv1','block3_conv1','block4_conv1','block5_conv1']
  211. feature_layers_w = [0.1,0.1,0.4,0.3,0.1]
  212. # feature_layers = ['block5_conv1']
  213. # feature_layers_w = [1]
  214. for i in range(len(feature_layers)):
  215. # 每一层的权重以及数据
  216. layer_name, w = feature_layers[i], feature_layers_w[i]
  217. layer_features = outputs_dict[layer_name] # 该层的特征
  218. style_reference_features = layer_features[1, :, :, :] # 参考图像在VGG网络中第i层的特征
  219. combination_features = layer_features[2, :, :, :] # 结果图像在VGG网络中第i层的特征
  220. loss += w * style_loss(style_reference_features, combination_features) # 目标风格图像的特征和结果图像特征之间的差异作为loss
  221. loss += total_variation_loss(combination_image)
  222. # 求得梯度,输入combination_image,对loss求梯度, 每轮迭代中combination_image会根据梯度方向做调整
  223. grads = K.gradients(loss, combination_image)
  224. outputs = [loss]
  225. if isinstance(grads, (list, tuple)):
  226. outputs += grads
  227. else:
  228. outputs.append(grads)
  229. f_outputs = K.function([combination_image], outputs)
  230. evaluator = Evaluator()
  231. x = preprocess_image(source_image)
  232. img = deprocess_image(x.copy())
  233. fname = '原始图片.png'
  234. save_img(fname, img)
  235. # 开始迭代
  236. for i in range(iterations):
  237. start_time = time.time()
  238. print('迭代', i,end=" ")
  239. x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(), fprime=evaluator.grads, maxfun=20, epsilon=1e-7)
  240. # 一个scipy的L-BFGS优化器
  241. print('目前loss:', min_val,end=" ")
  242. # 保存生成的图片
  243. img = deprocess_image(x.copy())
  244. fname = 'result_%d.png' % i
  245. end_time = time.time()
  246. print('耗时%.2f s' % (end_time - start_time))
  247. if i%5 == 0 or i == iterations-1:
  248. save_img(fname, img, image_enhance=True)
  249. print('文件保存为', fname)

基准图像:

风格图像:

合成的艺术风格图像:

训练时候整体的loss是3个loss的和,每个loss都有一个系数,调整不同的系数,对应不同的效果。

“内容损失”(content loss)

以下图片分别对应内容损失系数为0.1、1、5、10的效果:

随着内容损失系数的增大,迭代优化会更加侧重于调整合成图像的内容,使得图像跟原始图像越来越接近。

“风格损失”(style loss)

风格损失是VGG网络5个CNN层的特征的融合,单纯增大风格损失系数对图像最终风格影响不大,以下是系数是1和100的对比:

系数相差100倍,但是图像风格并没有明显的改变。可能调整5个卷积特征不同的比例系数会有效果。

以下是单纯使用第1、2、3、4、5个卷积层特征的效果:

可见 5个卷积层特征里第3和第4个卷积层对图像的风格影响较大。

以下调整第3和第4个卷积层的系数,5个系数比为1:1:1:1:1和0.5:0.5:0.4:0.4:1

增大第3、4层比例之后,图像风格更加接近风格图像。

“差异损失”(variation loss)

图像差异损失衡量的是图像本身的局部特征之间的差异,系数越大,图像局部越接近,表现在图像上就是图像像素间过度自然,以下是系数是1、5、10的效果:

以上。

keras图像风格迁移的更多相关文章

  1. 图像风格迁移(Pytorch)

    图像风格迁移 最后要生成的图片是怎样的是难以想象的,所以朴素的监督学习方法可能不会生效, Content Loss 根据输入图片和输出图片的像素差别可以比较损失 \(l_{content} = \fr ...

  2. Keras实现风格迁移

    风格迁移 风格迁移算法经历多次定义和更新,现在应用在许多智能手机APP上. 风格迁移在保留目标图片内容的基础上,将图片风格引用在目标图片上. 风格本质上是指在各种空间尺度上图像中的纹理,颜色和视觉图案 ...

  3. fast neural style transfer图像风格迁移基于tensorflow实现

    引自:深度学习实践:使用Tensorflow实现快速风格迁移 一.风格迁移简介 风格迁移(Style Transfer)是深度学习众多应用中非常有趣的一种,如图,我们可以使用这种方法把一张图片的风格“ ...

  4. Distill详述「可微图像参数化」:神经网络可视化和风格迁移利器!

    近日,期刊平台 Distill 发布了谷歌研究人员的一篇文章,介绍一个适用于神经网络可视化和风格迁移的强大工具:可微图像参数化.这篇文章从多个方面介绍了该工具. 图像分类神经网络拥有卓越的图像生成能力 ...

  5. 基于 Keras 实现图像风格转移

     Style Transfer 这个方向火起来是从2015年Gatys发表的Paper A Neural Algorithm of Artistic Style(神经风格迁移) , 这里就简单提一下论 ...

  6. A Neural Algorithm of Artistic Style 图像风格转换 - keras简化版实现

    前言 深度学习是最近比较热的词语.说到深度学习的应用,第一个想到的就是Prisma App的图像风格转换.既然感兴趣就直接开始干,读了论文,一知半解:看了别人的源码,才算大概了解的具体的实现,也惊叹别 ...

  7. Gram格拉姆矩阵在风格迁移中的应用

    Gram定义 n维欧式空间中任意k个向量之间两两的内积所组成的矩阵,称为这k个向量的格拉姆矩阵(Gram matrix) 根据定义可以看到,每个Gram矩阵背后都有一组向量,Gram矩阵就是由这一组向 ...

  8. 『cs231n』通过代码理解风格迁移

    『cs231n』卷积神经网络的可视化应用 文件目录 vgg16.py import os import numpy as np import tensorflow as tf from downloa ...

  9. ng-深度学习-课程笔记-14: 人脸识别和风格迁移(Week4)

    1 什么是人脸识别( what is face recognition ) 在相关文献中经常会提到人脸验证(verification)和人脸识别(recognition). verification就 ...

随机推荐

  1. PHP的openssl_encrypt方法的Java实现

    <?php class OpenSSL3DES { /*密钥,22个字符*/ const KEY='09bd821d3e764f44899a9dc6'; /*向量,8个或10个字符*/ cons ...

  2. angular6 safe url pipe

    safe-url.pipe.ts import { Pipe, PipeTransform } from '@angular/core'; import { DomSanitizer } from ' ...

  3. 阿里云 oss 上传文件,js直传,.net 签名,回调

    后台签名 添加引用 string dir = string.Format("{0:yyyy-MM-dd}", date) + "/"; OssClient cl ...

  4. java8 简便的map和list操作

    如果你看到这篇文章,说明你对java繁琐的list和map操作产生了厌烦.在java中,频繁的操作基本上是获取到对象list,然后根据某个属性或者某几个属性的值,把list转为map,然后遍历其他对象 ...

  5. 【Angular 5】数据绑定、事件绑定和双向绑定

    本文为Angular5的学习笔记,IDE使用Visual Studio Code,内容是关于数据绑定,包括Property Binding.Class Binding.Style Binding. 在 ...

  6. Python3 Tcp未发送/接收完数据即被RST处理办法

    一.背景说明 昨天一个同事让帮忙写个服务,用于接收并返回他那边提交过来的数据,以便其查看提交的数据及格式是否正确. 开始想用django写个接口,但写接口接口名称就得是定死的,他那边只能向这接口提交数 ...

  7. 安装pyspider遇到的坑

    pyspider是国人写的一款开源爬虫框架,个人觉得这个框架用起来很方便,至于如何方便可以继续看下去. 作者博客:http://blog.binux.me/ 安装pyspider安装pyspider: ...

  8. learning makefile call func

  9. centos 日志文件

    以下介绍的是20个位于/var/log/ 目录之下的日志文件.其中一些只有特定版本采用,如dpkg.log只能在基于Debian的系统中看到./var/log/messages — 包括整体系统信息, ...

  10. Spring _day02_IoC注解开发入门

    1.Spring IoC注解开发入门 1.1 注解开发案例: 创建项目所需要的jar,四个基本的包(beans core context expression ),以及两个日志记录的包,还要AOP的包 ...