采用TensorFlow内部函数直接对模型进行冻结
# enhance_raw.py
# transform from single frame into multi-frame enhanced single raw
from __future__ import division
import os, time, scipy.io
import tensorflow as tf
import numpy as np
import rawpy
import glob
from model_sid_latest import network_enhance_raw
import platform
import functools

from tensorflow.python.tools import freeze_graph

# Hide every CUDA device from TensorFlow so the graph is built and run on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

if platform.system() == 'Windows':
    data_dir = 'D:/data/LightOnOff/'
elif platform.system() == 'Linux':
    data_dir = './dataset/LightOnOff/'
else:
    # Raise rather than `assert False`: asserts vanish under `python -O`,
    # which would leave `data_dir` undefined and fail later with a NameError.
    raise RuntimeError('platform not supported!')

checkpoint_dir = './model_light_on_off/'
result_dir = './out_light_on_off/'
log_dir = './log_light_on_off/'
learning_rate = 1e-4
save_model_every_n_epoch = 10
max_epoch = 20000
if platform.system() == 'Windows':
    save_output_every_n_steps = 1
else:
    save_output_every_n_steps = 100

# BBF100-2 raw frame size in pixels; the packed 4-channel Bayer image built by
# pack_raw_bbf is half of this in each dimension.
bbf_w = 4032
bbf_h = 3024

# Network patch size, in packed (half-resolution) pixels.
# (A 512 x 512 patch used to be assigned here and was immediately overridden.)
patch_h = 800
patch_w = 1024

# White (maximum) level and black level of the raw data, in raw digital numbers.
max_level = 1023
black_level = 64

tf.reset_default_graph()

# set up dataset: one entry of data_dir per training sample directory
train_ids = sorted(os.listdir(data_dir))


def preprocess(raw, bl, wl):
    """Subtract the black level from a raw frame and scale by the usable range.

    Args:
        raw: an opened ``rawpy`` image; only ``raw_image_visible`` is read.
        bl: black level; subtracted, and the result is clamped at 0.
        wl: white level; a raw value equal to ``wl`` maps to 1.0.

    Returns:
        A float array with the 2-D shape of ``raw_image_visible``. Values above
        ``wl`` are not clipped, so results can exceed 1.0.
    """
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - bl, 0)
    return im / (wl - bl)


def pack_raw_bbf(path):
    """Read a DNG and pack its Bayer mosaic into an (H/2, W/2, 4) uint16 image.

    Channels are always ordered R, G, B, G whatever the sensor's CFA layout.
    They are multiplied by the camera white-balance gains (normalised so that
    green has gain 1), clipped to 1.0, and re-scaled to integer codes in
    ``[0, max_level - black_level]``.

    Raises:
        ValueError: if the CFA layout is not RGGB, BGGR, GRBG or GBRG.
    """
    bl = black_level
    wl = max_level
    # The context manager releases the rawpy/LibRaw handle even on error;
    # previously the file handle was never closed.
    with rawpy.imread(path) as raw:
        im = preprocess(raw, bl, wl)
        pattern = np.array(raw.raw_pattern)
        wb = np.array(raw.camera_whitebalance)

    H, W = im.shape[0], im.shape[1]
    top_left = im[0:H:2, 0:W:2]
    top_right = im[0:H:2, 1:W:2]
    bottom_left = im[1:H:2, 0:W:2]
    bottom_right = im[1:H:2, 1:W:2]

    # rawpy colour indices: 0 = R, 1 = G, 2 = B, 3 = second G.  Either green
    # index may sit at the top-left, so accept both (only 1 was accepted before).
    green = (1, 3)
    if pattern[0, 0] == 0:  # CFA=RGGB
        planes = (top_left, top_right, bottom_right, bottom_left)
    elif pattern[0, 0] == 2:  # BGGR
        planes = (bottom_right, top_right, top_left, bottom_left)
    elif pattern[0, 0] in green and pattern[0, 1] == 0:  # GRBG
        planes = (top_right, top_left, bottom_left, bottom_right)
    elif pattern[0, 0] in green and pattern[0, 1] == 2:  # GBRG
        planes = (bottom_left, top_left, top_right, bottom_right)
    else:
        raise ValueError('unsupported CFA pattern %s in %s' % (pattern.tolist(), path))
    out = np.stack(planes, axis=2)

    # Give the second green the same gain as the first, then normalise so the
    # green gain is 1.
    wb[3] = wb[1]
    wb = wb / wb[1]
    out = np.minimum(out * wb, 1.0)
    # Optional brightness normalisation (disabled): scale so the mean of the
    # green channel becomes 0.2, i.e.
    # out = np.minimum(out * 0.2 / np.maximum(1e-6, np.mean(out[:, :, 1])), 1.0)

    # np.uint16 on a float array truncates toward zero.
    return np.uint16(out * (wl - bl))


def raw2rgb(raw):  # GRBG
    """Scatter a packed (h, w, 4) R,G,B,G image back onto a GRBG mosaic.

    Returns a (2h, 2w, 3) float array in which each pixel holds only the colour
    sampled at that site (the other two channels are 0), laid out per 2x2 cell as

        G R
        B G

    Raises:
        ValueError: if ``raw`` is not 3-dimensional.
    """
    if len(raw.shape) != 3:
        raise ValueError('expected a (h, w, 4) packed image, got shape %s' % (raw.shape,))
    h, w = raw.shape[0] * 2, raw.shape[1] * 2
    rgb = np.zeros([h, w, 3])
    rgb[0:h:2, 0:w:2, 1] = raw[:, :, 1]
    rgb[0:h:2, 1:w:2, 0] = raw[:, :, 0]
    rgb[1:h:2, 0:w:2, 2] = raw[:, :, 2]
    rgb[1:h:2, 1:w:2, 1] = raw[:, :, 3]
    return rgb


def max_in_all(left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center):
    """Return the element-wise maximum of the nine equally-shaped arrays."""
    return functools.reduce(
        np.maximum,
        (left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center))


def demosaic(rgb):
    """Fill the holes of a sparse mosaic with a per-channel 3x3 maximum.

    This is not true demosaicing. Each interior pixel of channels 0-2 is
    replaced by the maximum over its 3x3 neighbourhood, which spreads sampled
    values into the zero-filled sites left by ``raw2rgb``. The one-pixel border
    is left unchanged. ``rgb`` is modified in place and also returned.
    """
    for chn_id in range(3):
        plane = rgb[:, :, chn_id]
        # Fixed: `left` previously sliced the top neighbour (rows shifted, not
        # columns), so the true left neighbour was never considered.
        left = plane[1:-1, 0:-2]
        left_top = plane[0:-2, 0:-2]
        top = plane[0:-2, 1:-1]
        top_right = plane[0:-2, 2:]
        right = plane[1:-1, 2:]
        right_bottom = plane[2:, 2:]
        bottom = plane[2:, 1:-1]
        bottom_left = plane[2:, 0:-2]
        center = plane[1:-1, 1:-1]
        # max_in_all materialises a new array before this assignment, so the
        # in-place write never reads already-overwritten neighbours.
        plane[1:-1, 1:-1] = max_in_all(left, left_top, top, top_right, right,
                                       right_bottom, bottom, bottom_left, center)
    return rgb


def gray_ps(rgb):
    """Grey level of an (h, w, 3) RGB image.

    Channels are raised to the power 2.2, mixed with weights
    0.2973 / 0.6274 / 0.0753, and raised back to the power 1/2.2. A small
    epsilon keeps the result nonzero on black pixels so callers can divide by it.
    """
    linear = (np.power(rgb[:, :, 0], 2.2) * 0.2973
              + np.power(rgb[:, :, 1], 2.2) * 0.6274
              + np.power(rgb[:, :, 2], 2.2) * 0.0753)
    return np.power(linear, 1 / 2.2) + 1e-7


def gamma_correction(x, curve_ratio):
    """Apply a tone curve to the grey level of ``x`` while keeping channel ratios.

    Each pixel is multiplied by g**curve_ratio / g, where g = gray_ps(x). Its grey
    level therefore becomes approximately g**curve_ratio. The result is clipped
    to at most 1.0.
    """
    gray_scale = np.expand_dims(gray_ps(x), axis=-1)
    gray_scale_new = np.power(gray_scale, curve_ratio)
    return np.minimum(x * gray_scale_new / gray_scale, 1.0)
# True: run one forward pass, export the graph, freeze it and stop.
# False: run the training loop. Previously this was toggled by commenting code
# in and out, which left the training tail unreachable after exit(0).
freeze_model_and_exit = True

# allow_growth lets TensorFlow allocate GPU memory on demand instead of
# reserving it all up front. It is irrelevant when CUDA_VISIBLE_DEVICES is "".
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

# Graph: one packed 4-channel raw patch in, one enhanced patch out. The input
# placeholder is named 'input' so it can be located in the frozen graph.
in_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4], name='input')
gt_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4])
out_im = network_enhance_raw(in_im, patch_h, patch_w)
norm_im = tf.minimum(tf.maximum(out_im, 0.0), 1.0)

# MS-SSIM is only logged; the optimised loss is L1 + L2.
ssim_loss = 1 - tf.image.ssim_multiscale(norm_im[0], gt_im[0], 1.0)
l1_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(norm_im - gt_im), axis=-1))
l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(norm_im - gt_im), axis=-1))
G_loss = l1_loss + l2_loss

tf.summary.scalar('G_loss', G_loss)
tf.summary.scalar('MS-SSIM Loss', ssim_loss)
tf.summary.scalar('L1 Loss', l1_loss)
tf.summary.scalar('L2 Loss', l2_loss)

lr = tf.placeholder(tf.float32)
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)

# The freeze_graph call below refers to this saver's default op names
# ('save/restore_all', 'save/Const:0').
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)
elif freeze_model_and_exit:
    # Fail fast: freezing needs trained weights on disk.
    raise RuntimeError('no checkpoint found in %s to freeze' % checkpoint_dir)

# save the images for tracking training states
if not os.path.isdir(result_dir):
    os.mkdir(result_dir)

# Running loss per sample, indexed like train_ids (0 = not trained on yet).
# It was previously a fixed 500 rows, which overflowed with more samples.
g_loss = np.zeros((len(train_ids), 1))
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(log_dir, sess.graph)

# Each sample directory holds one ground-truth "*on*.dng" and one or more
# input "*off*.dng" frames.
gt_files = []
input_files = []
for train_id in train_ids:
    sample_dir = os.path.join(data_dir, train_id)
    gt_candidates = sorted(glob.glob(sample_dir + '/*on*.dng'))
    off_candidates = sorted(glob.glob(sample_dir + '/*off*.dng'))
    if not gt_candidates or not off_candidates:
        raise RuntimeError('missing *on*.dng or *off*.dng files in %s' % sample_dir)
    gt_files.append(gt_candidates[0])
    input_files.append(off_candidates)
# Lazily filled caches of packed images (one slot per file).
input_images = [[None] * len(files) for files in input_files]
gt_images = [None] * len(train_ids)

steps = 0
st = time.time()
for epoch in range(0, max_epoch):
    for ind in np.random.permutation(len(train_ids)):
        steps += 1
        sid = np.random.randint(0, len(input_files[ind]))
        if input_images[ind][sid] is None:
            input_images[ind][sid] = np.expand_dims(pack_raw_bbf(input_files[ind][sid]), axis=0)
        if gt_images[ind] is None:
            gt_images[ind] = np.expand_dims(np.maximum(pack_raw_bbf(gt_files[ind]), 0), axis=0)

        # random cropping in packed (half-resolution) coordinates; the upper
        # bound is exclusive, hence +1 so the last valid offset can be drawn.
        xx = np.random.randint(0, bbf_w // 2 - patch_w + 1)
        yy = np.random.randint(0, bbf_h // 2 - patch_h + 1)
        input_patch = np.float32(input_images[ind][sid][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level)
        gt_patch = np.float32(gt_images[ind][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level)

        # random flipping over the two spatial axes of the (1, h, w, 4) patches.
        # The second flip used axis=0 (the size-1 batch axis), which was a no-op.
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=1)
            gt_patch = np.flip(gt_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=2)
            gt_patch = np.flip(gt_patch, axis=2)

        feed = {in_im: input_patch, gt_im: gt_patch, lr: learning_rate}

        if freeze_model_and_exit:
            # One forward pass shows the graph runs on real data. Then export the
            # text GraphDef, fold the restored variables into constants, and stop.
            summary, output = sess.run([merged, out_im], feed_dict=feed)
            tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model_raw2raw.pb')
            freeze_graph.freeze_graph(
                input_graph='output_model/pb_model/model_raw2raw.pb',
                input_saver='',
                input_binary=False,  # write_graph above writes text by default
                # Freeze the checkpoint that was actually restored into this
                # graph (was hard-coded to './model_light_on_off/0.ckpt').
                input_checkpoint=ckpt.model_checkpoint_path,
                # NOTE(review): presumably the output op name inside
                # network_enhance_raw; confirm against model_sid_latest.
                output_node_names='gen/output',
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                output_graph='output_model/pb_model/frozen_model.pb',
                clear_devices=True,
                initializer_nodes="")
            # `exit` is a site-module helper; SystemExit is always available.
            raise SystemExit(0)

        summary, _, G_current, output = sess.run(
            [merged, G_opt, G_loss, out_im], feed_dict=feed)
        g_loss[ind] = G_current

        if steps % save_output_every_n_steps == 0:
            loss_ = np.mean(g_loss[np.where(g_loss)])
            cost_ = (time.time() - st) / save_output_every_n_steps
            st = time.time()
            print("%d %d Loss=%.6f Speed=%.6f" % (epoch, steps, loss_, cost_))
            writer.add_summary(summary, global_step=steps)
            # save the current output image for network inspection
            out_ = np.minimum(np.maximum(output, 0), 1)
            in_rgb = gamma_correction(demosaic(raw2rgb(input_patch[0])), 0.35)
            gt_rgb = gamma_correction(demosaic(raw2rgb(gt_patch[0])), 0.35)
            out_rgb = gamma_correction(demosaic(raw2rgb(out_[0])), 0.35)
            temp = np.concatenate((in_rgb, gt_rgb, out_rgb), axis=1)
            # scipy.misc is not loaded by `import scipy.io`; import it explicitly.
            # NOTE: scipy.misc.toimage was removed in SciPy 1.2 and needs Pillow.
            import scipy.misc
            scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0, cmax=255) \
                .save(result_dir + '/%d_%s_00.jpg' % (epoch, train_ids[ind]))

        # clean up the memory if necessary
        if platform.system() == 'Windows':
            input_images[ind][sid] = None
            gt_images[ind] = None

    if epoch % save_model_every_n_epoch == 0:
        saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
        print('model saved.')
采用TensorFlow内部函数直接对模型进行冻结的更多相关文章
- tensorflow加载embedding模型进行可视化
1.功能 采用python的gensim模块训练的word2vec模型,然后采用tensorflow读取模型可视化embedding向量 ps:采用C++版本训练的w2v模型,python的gensi ...
- TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model
TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...
- tensorflow训练验证码识别模型
tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...
- 开园第一篇---有关tensorflow加载不同模型的问题
写在前面 今天刚刚开通博客,主要想法跟之前某位博主说的一样,希望通过博客园把每天努力的点滴记录下来,也算一种坚持的动力.我是小白一枚,有啥问题欢迎各位大神指教,鞠躬~~ 换了新工作,目前手头是OCR项 ...
- 【6】TensorFlow光速入门-python模型转换为tfjs模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862365.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- 【4】TensorFlow光速入门-保存模型及加载模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- 【TensorFlow】基于ssd_mobilenet模型实现目标检测
最近工作的项目使用了TensorFlow中的目标检测技术,通过训练自己的样本集得到模型来识别游戏中的物体,在这里总结下. 本文介绍在Windows系统下,使用TensorFlow的object det ...
- TensorFlow学习笔记12-word2vec模型
为什么学习word2vec模型? 该模型用来学习文字的向量表示.图像和音频可以直接处理原始像素点和音频中功率谱密度的强度值, 把它们直接编码成向量数据集.但在"自然语言处理" ...
- tensorflow之逻辑回归模型实现
前面一篇介绍了用tensorflow实现线性回归模型预测sklearn内置的波士顿房价,现在这一篇就记一下用逻辑回归分类sklearn提供的乳腺癌数据集,该数据集有569个样本,每个样本有30维,为二 ...
随机推荐
- java中的函数
1.函数:定义在类中的具有特定功能的一段独立小程序.函数也称之为方法. 为了提高代码的复用性,对代码进行抽取. 将这个部分定义成一个独立的功能.方便使用. java中对功能的定义通过函数来实现的.2函 ...
- Linux驱动----系统滴答时钟的使用
2019-3-12系统滴答定时器中断使用 定义一个timer 其实就是使用系统的滴答定时器产生一个中断. 初始化timer init_timer函数 实现如下 void fastcall ini ...
- 操作mongodb
MongoDB数据库是以k-v形式存储在磁盘上的. import pymongoclient = pymongo.MongoClient(host='10.29.3.40',port=27017)db ...
- LeetCode Weekly Contest 117
已经正式在实习了,好久都没有刷题了(应该有半年了吧),感觉还是不能把思维锻炼落下,所以决定每周末刷一次LeetCode. 这是第一周(菜的真实,只做了两题,还有半小时不想看了,冷~). 第一题: 96 ...
- Vue 组件&组件之间的通信 之 子组件向父组件传值
子组件向父组件传值:子组件通过$emit()方法以事件形式向父组件发送消息传值: 使用步骤: 定义组件:现有自定义组件com-a.com-b,com-a是com-b的父组件: 准备获取数据:父组件c ...
- C# RSA加解密与验签,AES加解密,以及与JAVA平台的密文加解密
前言: RSA算法是利用公钥与密钥对数据进行加密验证的一种算法.一般是拿私钥对数据进行签名,公钥发给友商,将数据及签名一同发给友商,友商利用公钥对签名进行验证.也可以使用公钥对数据加密,然后用私钥对数 ...
- 获取手机当前显示的ViewController
//获取手机当前显示的ViewController - (UIViewController*)currentViewController{ UIViewController* vc = [UIAppl ...
- [c/c++] programming之路(27)、union共用体
共用体时刻只有一个变量,结构体变量同时并存 一.创建共用体的三种形式 #include<stdio.h> #include<stdlib.h> #include<stri ...
- raphael参数说明
大纲 :first-child { margin-top: 0px; } .markdown-preview:not([data-use-github-style]) h1, .markdown-pr ...
- PKUWC 2017 Day 2 简要题解
*注意:题面请移步至loj查看. 从这里开始 Problem A 随机算法 Problem B 猎人杀 Problem C 随机游走 怎么PKU和THU都编了一些假算法,然后求正确率[汗]. 之前听说 ...