网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型在部署到安卓端的时候出现各种问题。因此,本文会记录从PC端训练、导出到安卓端部署的各种细节。欢迎大家讨论、指教。

PC端系统:Ubuntu14

tensorflow版本:tensroflow1.14

安卓版本:9.0

PC端训练过程

数据集:自定义生成

训练框架:tensorflow slim  关于tensorflow slim如何安装,这里不再赘述,大家自行百度解决。

数据生成代码:生成50000张28*28大小三通道的验证码图片,共分10类,0-9,生成的数据保存在datasets/images/里面

# -*- coding: utf-8 -*-

import cv2
import numpy as np from captcha.image import ImageCaptcha def generate_captcha(text=''):
"""Generate a digit image."""
capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
image = capt.generate_image(text)
image = np.array(image, dtype=np.uint8)
return image if __name__ == '__main__':
output_dir = './datasets/images/'
for i in range(50000):
label = np.random.randint(0, 10)
image = generate_captcha(str(label))
image_name = 'image{}_{}.jpg'.format(i+1, label)
output_path = output_dir + image_name
cv2.imwrite(output_path, image)

训练:本次训练我用tensorflow slim 搭建了一个七层卷积的网络,最后测试准确率在96%~99%左右,模型1.2M,适合在移动端部署。训练的时候我做了两点工作

1、指明了模型的输入和输出节点的名字,PC端部署测试模型的时候要用到,也便于快速确定模型的输出数据到底是什么格式,移动端代码要与其保持一致

inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
.......
.......
prob_ = tf.identity(prob, name='prob')

2、训练结束的时候直接把模型保存成PB格式

        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['inputs','prob']) #训练完毕直接把模型保存为PB格式
with tf.gfile.FastGFile('model3.pb', mode='wb') as f: #模型的名字是model.pb
f.write(constant_graph.SerializeToString())

训练代码如下

# -*- coding: utf-8 -*-

"""Train a CNN model to classifying 10 digits.

Example Usage:
---------------
python3 train.py \
--images_path: Path to the training images (directory).
--model_output_path: Path to model.ckpt.
""" import cv2
import glob
import numpy as np
import os
import tensorflow as tf import model
from tensorflow.python.framework import graph_util flags = tf.app.flags flags.DEFINE_string('images_path', None, 'Path to training images.')
flags.DEFINE_string('model_output_path', None, 'Path to model checkpoint.')
FLAGS = flags.FLAGS def get_train_data(images_path):
"""Get the training images from images_path. Args:
images_path: Path to trianing images. Returns:
images: A list of images.
lables: A list of integers representing the classes of images. Raises:
ValueError: If images_path is not exist.
"""
if not os.path.exists(images_path):
raise ValueError('images_path is not exist.') images = []
labels = []
images_path = os.path.join(images_path, '*.jpg')
count = 0
for image_file in glob.glob(images_path):
count += 1
if count % 100 == 0:
print('Load {} images.'.format(count))
image = cv2.imread(image_file)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Assume the name of each image is imagexxx_label.jpg
label = float(image_file.split('_')[-1].split('.')[0])
images.append(image)
labels.append(label)
images = np.array(images)
labels = np.array(labels)
return images, labels def next_batch_set(images, labels, batch_size=128):
"""Generate a batch training data. Args:
images: A 4-D array representing the training images.
labels: A 1-D array representing the classes of images.
batch_size: An integer. Return:
batch_images: A batch of images.
batch_labels: A batch of labels.
"""
indices = np.random.choice(len(images), batch_size)
batch_images = images[indices]
batch_labels = labels[indices]
return batch_images, batch_labels def main(_):
inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
labels = tf.placeholder(tf.int32, shape=[None], name='labels') cls_model = model.Model(is_training=True, num_classes=10)
preprocessed_inputs = cls_model.preprocess(inputs)#预处理
prediction_dict = cls_model.predict(preprocessed_inputs)
loss_dict = cls_model.loss(prediction_dict, labels)
loss = loss_dict['loss']
postprocessed_dict = cls_model.postprocess(prediction_dict)
classes = postprocessed_dict['classes']
prob = postprocessed_dict['prob']
classes_ = tf.identity(classes, name='classes')
prob_ = tf.identity(prob, name='prob')
acc = tf.reduce_mean(tf.cast(tf.equal(classes, labels), 'float')) global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.exponential_decay(0.05, global_step, 150, 0.9) optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
train_step = optimizer.minimize(loss, global_step) saver = tf.train.Saver() images, targets = get_train_data(FLAGS.images_path) init = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init) for i in range(6000):
batch_images, batch_labels = next_batch_set(images, targets)
train_dict = {inputs: batch_images, labels: batch_labels} sess.run(train_step, feed_dict=train_dict) loss_, acc_,prob__,classes__ = sess.run([loss, acc, prob_,classes_], feed_dict=train_dict) train_text = 'step: {}, loss: {}, acc: {},classes:{}'.format(
i+1, loss_, acc_,classes__)
print(train_text)
print (prob__)
saver.save(sess, FLAGS.model_output_path)
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['inputs','prob']) #训练完毕直接把模型保存为PB格式
with tf.gfile.FastGFile('model3.pb', mode='wb') as f: #模型的名字是model.pb
f.write(constant_graph.SerializeToString())
if __name__ == '__main__':
tf.app.run()

这里尤其要注意,训练的时候图片是否做过预处理,比如减去均值和除法归一化操作,因为移动端需要保持和训练时候一样的操作。我的在训练的时候,预处理工作中包含了减去均值和除法归一化,并且把这两个OP打包直接放进了模型里面,也就是说图片数据进入模型之后会先进行预处理然后再进行正式的卷积等系列操作。所以,移动端的数据不需要单独写预处理的代码。很多时候,导出模型的时候并没有把预处理操作打包进模型,所以移动端要单独写几行关于减去均值和归一化的代码,然后再把数据送到分类模型当中。

另外一种把ckpt模型导出为pb模型的方式,代码如下

import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
#input_node_names = "inputs"
output_node_names = "inputs,classes"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 获得默认的图
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图 with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #恢复图并得到数据
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=input_graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点 # for op in graph.get_operations():
# print(op.name, op.values())
# 输入ckpt模型路径
input_checkpoint='model/model.ckpt'
# 输出pb模型的路径
out_pb_path="frozen_model.pb"
# 调用freeze_graph将ckpt转为pb
freeze_graph(input_checkpoint,out_pb_path)

把PB模型导出为tflite格式代码

import tensorflow as tf
#把pb文件路径改成自己的pb文件路径即可
path = "model2.pb" #如果是不知道自己的模型的输入输出节点,建议用tensorboard做可视化查看计算图,计算图里有输入输出的节点名称
inputs = ["inputs"]
outputs = ["prob"]
#转换pb模型到tflite模型
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, outputs)
#converter.post_training_quantize = True
tflite_model = converter.convert()
open("model3.tflite", "wb").write(tflite_model)

还有另外一种利用bazel把模型导出为tflite的办法

进入tensorflow源码目录,两步编译
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco
./bazel-bin/tensorflow/contrib/lite/toco/toco
--input_file=/media/bayes/69da5b29-ae56-4feb-93a1-2ce24323aa78/project/model2.pb
--output_file=/media/bayes/69da5b29-ae56-4feb-93a1-2ce24323aa78/project/model2.tflite
--input_format=TENSORFLOW_GRAPHDEF
--output_format=TFLITE
--inference_type=FLOAT
--input_shape=1,28,28,3
--input_array=inputs
--output_array=prob

PB模型测试模型准确率

# -*- coding: utf-8 -*-

"""Evaluate the trained CNN model.
Example Usage:
---------------
python3 infrence_pb.py \
--frozen_graph_path: Path to model frozen graph.
""" import numpy as np
import tensorflow as tf from captcha.image import ImageCaptcha flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS def generate_captcha(text=''):
capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
image = capt.generate_image(text)
image = np.array(image, dtype=np.uint8)
return image def main(_):
model_graph = tf.Graph()
with model_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='') with model_graph.as_default():
with tf.Session(graph=model_graph) as sess:
inputs = model_graph.get_tensor_by_name('inputs:0')
classes = model_graph.get_tensor_by_name('classes:0')
prob = model_graph.get_tensor_by_name('prob:0')
for i in range(10):
label = np.random.randint(0, 10)
image = generate_captcha(str(label))
image =
image_np = np.expand_dims(image, axis=0)
predicted_label,probs = sess.run([classes,prob],
feed_dict={inputs: image_np})
print(predicted_label, ' vs ', label)
print(probs) if __name__ == '__main__':
tf.app.run()

tflite格式测试模型准确率

# -*- coding:utf-8 -*-
import os
import cv2
import numpy as np
import time import tensorflow as tf test_image_dir = './test_images/'
#model_path = "./model/quantize_frozen_graph.tflite"
model_path = "./model3.tflite" # Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors() # Get input and output tensors.
input_details = interpreter.get_input_details()
print(str(input_details))
output_details = interpreter.get_output_details()
print(str(output_details)) #with tf.Session( ) as sess:
if 1:
file_list = os.listdir(test_image_dir) model_interpreter_time = 0
start_time = time.time()
# 遍历文件
for file in file_list:
print('=========================')
full_path = os.path.join(test_image_dir, file)
print('full_path:{}'.format(full_path)) img = cv2.imread(full_path )
res_img = cv2.resize(img,(28,28),interpolation=cv2.INTER_CUBIC)
# 变成长784的一维数据
#new_img = res_img.reshape((784))
new_img = np.array(res_img, dtype=np.uint8)
# 增加一个维度,变为 [1, 784]
image_np_expanded = np.expand_dims(new_img, axis=0)
image_np_expanded = image_np_expanded.astype('float32') # 类型也要满足要求 # 填装数据
model_interpreter_start_time = time.time()
interpreter.set_tensor(input_details[0]['index'], image_np_expanded) # 注意注意,我要调用模型了
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
model_interpreter_time += time.time() - model_interpreter_start_time # 出来的结果去掉没用的维度
result = np.squeeze(output_data)
print('result:{}'.format(result))
#print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded}))) # 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字
#print('result:{}'.format( (np.where(result==np.max(result)))[0][0] ))
used_time = time.time() - start_time
print('used_time:{}'.format(used_time))
print('model_interpreter_time:{}'.format(model_interpreter_time))

模型训练好以后,接下来要把模型部署到安卓端,其实这步很简单,只要替换安卓代码相应部分即可,安卓代码我会上传到CSDN,大家按需下载即可。那么主要留意更改哪些代码呢

#模型的输入大小
private int[] ddims = {1, 3, 28, 28};
#模型的名称
private static final String[] PADDLE_MODEL = {
"model3",
"mobilenet_quant_v1_224",
"mobilenet_v1_1.0_224",
"mobilenet_v2"
}; #标签的名称
BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel1.txt")));
#模型输出的数据类型,在PC端可以清楚地看到
float[][] labelProbArray = new float[1][10];

#输入数据预处理工作是否已经包含在模型里面
//  imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
// imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
// imgData.putFloat((((val & 0xFF) - 128f) / 128f));
imgData.putFloat(((val >> 16) & 0xFF) );
imgData.putFloat(((val >> 8) & 0xFF) );
imgData.putFloat((val & 0xFF) );

留一张测试图片,大家可以拿去测试,正确结果应该是0.0,安卓代码地址是这里,CSDN下载请搜索 anquangan

查看PB模型节点代码

#coding:utf-8

import tensorflow as tf
from tensorflow.python.framework import graph_util
tf.reset_default_graph() # 重置计算图
output_graph_path = 'model3.pb'
with tf.Session() as sess: tf.global_variables_initializer().run()
output_graph_def = tf.GraphDef()
# 获得默认的图
graph = tf.get_default_graph()
with open(output_graph_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")
# 得到当前图有几个操作节点
print("%d ops in the final graph." % len(output_graph_def.node)) tensor_name = [tensor.name for tensor in output_graph_def.node]
print(tensor_name)
print('---------------------------')
# 在log_graph文件夹下生产日志文件,可以在tensorboard中可视化模型
#summaryWriter = tf.summary.FileWriter('log_graph/', graph) for op in graph.get_operations():
# print出tensor的name和值
print(op.name, op.values())

tensorflow从训练自定义CNN网络模型到Android端部署tflite的更多相关文章

  1. 在C#下使用TensorFlow.NET训练自己的数据集

    在C#下使用TensorFlow.NET训练自己的数据集 今天,我结合代码来详细介绍如何使用 SciSharp STACK 的 TensorFlow.NET 来训练CNN模型,该模型主要实现 图像的分 ...

  2. 复现VGG19训练自定义图像分类

    1.复现VGG训练自定义图像分类,成功了哈哈. 需要代码工程可联系博主qq号,在左边连接可找到. 核心代码: # coding:utf-8 import tensorflow as tf import ...

  3. tensorflow笔记:多层CNN代码分析

    tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...

  4. tensorflow分布式训练

    https://blog.csdn.net/hjimce/article/details/61197190  tensorflow分布式训练 https://cloud.tencent.com/dev ...

  5. 深度学习笔记 (二) 在TensorFlow上训练一个多层卷积神经网络

    上一篇笔记主要介绍了卷积神经网络相关的基础知识.在本篇笔记中,将参考TensorFlow官方文档使用mnist数据集,在TensorFlow上训练一个多层卷积神经网络. 下载并导入mnist数据集 首 ...

  6. Tensorflow Mask-RCNN训练识别箱子的模型运行结果(练习)

    Tensorflow Mask-RCNN训练识别箱子的模型

  7. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直 ...

  8. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...

  9. 安装 tensorflow 1.1.0;以及安装其他相似版本tensorflow遇到的问题;tensorflow 1.13.2 cuda-10环境变量配置问题;Tensorflow 指定训练时如何指定使用的GPU;

    # 安装 2.7 环境conda create -n python2. python= conda activate python2. # 安装 1.1.0 gpu版本pip # 配置环境变量expo ...

随机推荐

  1. C#文本操作

    1.使用FileStream读写文件 文件头: 复制代码代码如下: using System;using System.Collections.Generic;using System.Text;us ...

  2. java如何连接Oracle数据库问题

    Oracle数据库纯属自学,不对请留言改正! 在学Oracle前相信已经大致知道mysql或sqlserver数据库,这个跟前面两个不大一样,你安装的时候让你输入一个密码,貌似是一个系统管理员密码,跟 ...

  3. ACM-Maximum Tape Utilization Ratio

    题目描述:Maximum Tape Utilization Ratio Tags: 贪婪策略 设有n 个程序{1,2,…, n }要存放在长度为L的磁带上.程序i存放在磁带上的长度是li ,1 < ...

  4. Python LMDB的使用

    在python中使用lmdb linux中,可以使用指令 pip install lmdb 安装lmdb包. ---- lmdb 数据库文件生成 增 改 删 查 1.生成一个空的lmdb数据库文件 # ...

  5. c++程序—布尔值

    #include<iostream> using namespace std; #include<string> int main() { //创建bool数据类型 bool ...

  6. TD信息通(无课表)使用体验

    首先,在注册账户的时候,TD信息通还是比较严谨的.用户名字符数.密码字符数.邮箱格式等都有要求,我认为,这对App的长远发展来说,是很重要的一个细节.而且,在登陆之前,会有一项关于是否自动登陆的选择, ...

  7. 在执行 php artisan key:generate ,报 Could not open input file: artisan 错误

    Could not open input file: artisan 必须保证命令是在项目根目录,如下图所示:

  8. POJ 3663:Costume Party

    Costume Party Time Limit: 1000MS   Memory Limit: 65536K Total Submissions: 12607   Accepted: 4977 De ...

  9. Windows系统自带选择文件的对话重写和居中处理

    class CMyFileDialog: public CFileDialogImpl<CMyFileDialog> { public: CMyFileDialog(BOOL bOpenF ...

  10. windows操作

    5.windows激活 数字权利许可工具激活 https://jingyan.baidu.com/article/9113f81b4d49232b3314c75e.html 4.网络连接不上 原因,v ...