Tensorflow是Google开源的一套机器学习框架,支持GPU、CPU、Android等多种计算平台。本文将介绍在Tensorflow在Android上的使用。

Android使用Tensorflow框架需要引入两个文件libtensorflow_inference.so、libandroid_tensorflow_inference_java.jar。这两个文件可以使用官方预编译的文件。如果预编译的so不满足要求(比如不支持训练模型中的某些操作符运算),也可以自己通过bazel编译生成这两个文件。

将libandroid_tensorflow_inference_java.jar放在app下的libs目录下,so文件命名为libtensorflow_jni.so放在src/main/jniLibs目录下对应的ABI文件夹下。目录结构如下:



Android目录结构

同时在app的build.gradle中的dependencies模块下添加如下配置:

dependencies {
...
compile files('libs/libandroid_tensorflow_inference_java.jar')
...
}

使用tensorflow框架进行机器学习分为四个步骤:

  • 构造神经网络

  • 训练神经网络模型

  • 将训练好的模型输出为pb文件

  • ndroid上加载pb模型进行计算

前三步是模型的构造,我们通过python实现,下面给出了一个二分类的简单模型的构造过程,首先是训练过程:

# -*-coding:utf-8 -*-
from __future__ import print_function
import os
import tensorflow as tf
from numpy.random import RandomState os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' """
训练模型
"""
def train():
# 定义训练数据集batch大小为8
batch_size = 8 # 定义神经网络参数,参数体现出神经网络结构,一个输入层,一个输出层,一个隐藏层
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val")
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val") # 定义输入输出格式
x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input')
y_ = tf.placeholder(tf.float32, shape=(None, 1)) # 定义神经网络前向传播过程
a = tf.matmul(x, w1)
y = tf.matmul(a, w2, name="cal_node") # 定义交叉熵和反向传播算法
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy) # 生成随机训练集
rdm = RandomState(1)
dataset_size = 128 # 定义映射关系
X = rdm.rand(dataset_size, 2)
Y = [[int(x1 + x2 < 1)] for (x1, x2) in X] with tf.Session() as sess:
# 初始化所有参数
init_op = tf.global_variables_initializer()
sess.run(init_op) # print sess.run(w1)
# print sess.run(w2) STEPS = 500
for i in range(STEPS):
start = (i * batch_size) % dataset_size
end = min(start + batch_size, dataset_size) # 训练神经网络,更新神经网络参数
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]}) if i % 100 == 0:
total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy)) print(sess.run(w1))
print(sess.run(w2)) # 保存check point
saver = tf.train.Saver(tf.trainable_variables())
saver.save(sess, './model/checpt')

上面的代码首先定义神经网络,初始化训练数据,进行500次训练过程,并将训练结果checkpoints保存到model文件夹下,checkpoints包含了训练模型得到的参数信息,共生成四个相关的文件,如下图:

由于checkpoint文件众多,为了方便使用,我们通过下面的代码将它们生成一个pb文件,在android上只需要这个pb文件即可使用这个训练好的模型:

"""
存储pb模型
"""
def dump_graph_to_pb(pb_path):
with tf.Session() as sess:
check_point = tf.train.get_checkpoint_state("./model/")
if check_point:
saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta')
saver.restore(sess, check_point.model_checkpoint_path)
else:
raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path)) graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(",")) with tf.gfile.GFile(pb_path, "wb") as f:
f.write(graph_def.SerializeToString())

拿到生成的pb模型,我们可以在android上使用了。将pb文件在这main/assets下:

接下来就可以载入pb,进行计算了:

public class MainActivity extends AppCompatActivity {
private Graph graph_;
private Session session_;
private AssetManager assetManager; private static ExecutorService executorService;
private static Handler handler;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main); executorService = Executors.newFixedThreadPool(5); // 初始化tensorflow
initTensorFlow("outmodel.pb"); // 使用tensorflow进行计算
runTensorFlow();
}
...
}

通过如下方式载入pb模型,初始化tensorflow:

private boolean initTensorFlow(String modelFile) {
assetManager = getAssets();
// 新建Graph
graph_ = new Graph(); InputStream is = null;
try {
// 读取Assets pb文件
is = assetManager.open(modelFile);
} catch (IOException e) {
e.printStackTrace();
return false;
} try {
// 加载pb到Graph
TensorUtil.loadGraph(is, graph_);
is.close();
} catch (IOException e) {
e.printStackTrace();
return false;
}
// 初始化session
session_ = new Session(graph_);
if (session_ == null) {
return false;
} return true;
}

然后就可以使用tensorflow API进行运算了:

private void runTensorFlow() {
executorService.execute(generatePredictRunnable(handler));
} private Runnable generatePredictRunnable(Handler handler) {
return new Runnable() {
@Override
public void run() {
float[][] input = new float[1][2]; input[0][0] = 1;
input[0][1] = 2; // 定义输入tensor
Tensor inputTensor = Tensor.create(input); // 指定输入,输出节点,运行并得到结果
Tensor resultTensor = session_.runner()
.feed("x_input", inputTensor)
.fetch("cal_node")
.run()
.get(0); float[][] dst = new float[1][1];
resultTensor.copyTo(dst); // 处理结果
ArrayList<Float> resultList = new ArrayList<>();
for (float val : dst[0]) {
if (val != 0) {
resultList.add(val);
} else {
break;
}
}
}
};
}

上面就是通过python训练机器学习模型,并在android平台进行调用的完整流程。

原创作者:JackMeGo,原文链接:https://www.jianshu.com/p/eef4ab014a12



欢迎关注我的微信公众号「码农突围」,分享Python、Java、大数据、机器学习、人工智能等技术,关注码农技术提升•职场突围•思维跃迁,20万+码农成长充电第一站,陪有梦想的你一起成长。

Python+Android进行TensorFlow开发的更多相关文章

  1. 【tensorflow】1.安装Tensorflow开发环境,安装Python 的IDE--PyCharm

    ================================================== 安装Tensorflow开发环境,安装Python 的IDE--PyCharm 1.PyCharm ...

  2. Python+Android开发

    1 下载Scripting Layer for Android (SL4A) Scripting Layer for Android (SL4A) 是一个开源项目,目标是为android系统提供脚本语 ...

  3. Python C++ OpenCV TensorFlow手势识别(1-10) 毕设 定制开发

    Python C++ OpenCV TensorFlow手势识别(1-10) 毕设 支持定制开发 (MFC,QT, PyQt5界面,视频摄像头识别) QQ: 3252314061 效果如下:

  4. Tensorflow开发环境配置及其基本概念

    Tensorflow开发环境配置及其基本概念 1.1. 安装Tensorflow开发环境 1.1.1. 安装pycharm 1.1.2. 安装pythe3.6 1.1.3. 安装Tensorflow ...

  5. 基于python语言的tensorflow的‘端到端’的字符型验证码识别源码整理(github源码分享)

    基于python语言的tensorflow的‘端到端’的字符型验证码识别 1   Abstract 验证码(CAPTCHA)的诞生本身是为了自动区分 自然人 和 机器人 的一套公开方法, 但是近几年的 ...

  6. TensorFlow 开发环境搭建--Pycharm

    今天动手开始搭建TensorFlow开发环境, 用PyCharm来跑MNIST中的例子.记录过程如下 下载安装 (1)首先安装AnaConda, AnaConda可以帮忙去管理安装包,帮忙创建虚拟环境 ...

  7. 【1】TensorFlow光速入门-tensorflow开发基本流程

    本文地址:https://www.cnblogs.com/tujia/p/13862339.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  8. Android应用安全开发之浅谈加密算法的坑

      <Android应用安全开发之浅谈加密算法的坑> 作者:阿里移动安全@伊樵,@舟海 阿里聚安全,一站式解决应用开发安全问题     Android开发中,难免会遇到需要加解密一些数据内 ...

  9. [转]Android样式的开发:shape篇

    转载自Keegan小钢原文链接:http://keeganlee.me/post/android/20150830 Android样式的开发:shape篇Android样式的开发:selector篇A ...

随机推荐

  1. Spring-IOC(DI)的三种注入方式

    spring为方便不同的需求,为我们提供了3中不同的注入方式分别是set.get方法注入,构造注入还有p命名空间注入,老规矩,直接上代码 首先创建实体类Student public class Stu ...

  2. jmeter压测遇到的问题

    一.今天压力测试时,开始12秒后出现了很多异常, 都是 java.net.NoRouteToHostException: Cannot assign requested address. 1.首先我这 ...

  3. [SDOI2006] 线性方程组

    洛谷 P2455 传送门 刚开始写了个消成上三角的,结果狂wa. 后来经过研究发现,消成上三角那种不能直接判断无解或无穷多解,需要其它的操作. 所以干脆学了个消成对角线的,写了一发A了. 其实两种消元 ...

  4. js弱类型转换的知识点

    本文属于转载知识点,以下是原博文作者:不死鸟哇的文章,文章链接:原文JavaScript里什么情况下a==!a为true呢? 今天群里有位同学问了这样一个问题,JavaScript在什么情况下会出现变 ...

  5. vue基础指令了解

    Vue了解 """ vue框架 vue是前台框架:Angular.React.Vue vue:结合其他框架优点.轻量级.中文API.数据驱动.双向绑定.MVVM设计模式. ...

  6. javascript学习内容

    http协议 犀牛书 MDN js单线程 let只在代码块内有效 es5只有全局作用域 const变量指向的内存地址不得改动,值不能保证不变 全局变量不加var node.js 更改连接到服务器的方式 ...

  7. 你每天跑这么多自动化用例,能发现BUG吗?

    阿里QA导读:为什么要度量测试有效性?这么多的CASE,花了大量时间和资源去运行,真能发现bug吗?CI做到90%的行覆盖率了,能发现问题吗?测试用例越来越多,删一些,会不会就发现不了问题了?怎么找出 ...

  8. loadrunner通过web的post请求方法测接口

    loadrunner通过web的post请求方法测接口 loginapi() 模拟APP发送请求给Cloud, Action() "Name=input","Value= ...

  9. 学习HEXO的历程

    前言: 简介 开始搭建 命令 API测试 逛github相关的帖子时,发现了hexo.正好想要做一个个人的博客,用来记录自己的各类感悟,所以花一些时间学习学习,以后博客可以放github,省得去注册c ...

  10. 少用 string.Format

    如果你使用的是 C# 6.0 及其以上版本的话我建议你使用新增的 内插字符串 这个功能.这个功能可以更好的帮助开发人员设置字符串格式.下面我们就来看一下为什么要少用 string.Format 而要多 ...