前言

​ 当一个TensorFlow模型训练出来的时候,为了投入到实际应用,所以就需要部署到服务器上。由于我本次所做的项目是一个javaweb的图像识别项目。所有我就想去寻找一下java调用TensorFlow训练模型的办法。

由于TensorFlow很久没更新的缘故,网上的博客大都是18/19年的,并且是基于TensorFlow1.0的,对于现在使用的TensorFlow2.0不太友好。

下面我简述一下TensorFlow1.0时期的方法:

1.动态模型生成不便

需要将训练的.h5模型转换成.pb模型,并且需要自己定义.pb模型的输入输出参数。(pb模型是一种基于动态图的模型)

pb的生成代码冗长、而且对初学者真滴不太友好

相比之下.h5模型的生成代码就一行

此外,这个生成pb模型的代码是否能照搬使用,还是一个问题,并且还可能报一些奇奇怪怪的错误。

2.maven导包不便

查阅资料发现java上的TensorFlow的jar包都是TensorFlow1.0的

现状:

并且maven官网上的TensorFlow2.0的api已经改名成了tensorflow-core-api,并且网上相关方面的教程十分难找。由于网上都是导入的1.0的包,自己导入2.0的包之后,详细的调用教程可以说是没有。从上面也可以看出来TensorFlow对java的调用也不怎么重视了。所以这又给学习的途中徒增了很多困难。

全新思路

思路一

用java直接调用训练好的模型很困难,那么我们想办法让java调用python脚本,让python脚本去调用.h5模型会不会更简单呢?

代码如下

  1. package com.guard.service;
  2. import java.io.BufferedReader;
  3. import java.io.IOException;
  4. import java.io.InputStreamReader;
  5. public class api_service {
  6. public String recognize(String path){
  7. //此处的path是图片路径
  8. Process proc;
  9. String res = null;
  10. try {
  11. System.out.println("接受到的参数"+path);
  12. String[] cmd = new String[] { "python", "E:\\machine_learning\\predict.py", path};
  13. proc = Runtime.getRuntime().exec(cmd);
  14. BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream()));
  15. String line = null;
  16. while ((line = in.readLine()) != null) {
  17. System.out.println(line);
  18. res = line;
  19. }
  20. in.close();
  21. proc.waitFor();
  22. } catch (IOException e) {
  23. e.printStackTrace();
  24. } catch (InterruptedException e) {
  25. e.printStackTrace();
  26. }
  27. System.out.println(res+">>>>>>>>>>>");
  28. return res;
  29. }
  30. }

但是我们可以看出,这个其实是用java在win上跑了这样一个指令

虽然这个确实是一个好办法,但是这个路径参数需要事先知道服务器上的路径,并且在协作开发的时候,每个人的路径和环境就不同,虽然该方法能用,但是我认为还不够好。

思路二

我们可以直接用python的flask框架,直接生成一个api接口,就可以远程直接调用TensorFlow训练好的模型进行结果预测。

个人认为,这种方法相较于用java调用命令行,这种方法还是更加直观的

并且flask仅仅需要加个@app.route的注解就能实现,可谓是十分方便

下面是模型调用代码

model.py

  1. import glob
  2. import sys
  3. import os
  4. import cv2
  5. import numpy as np
  6. import tensorflow as tf
  7. import image_processing
  8. def model_ues(path):
  9. # 缩放图片大小为100*100
  10. w = 100
  11. h = 100
  12. # 测试图像的地址 (改为自己的)
  13. # path_test = "resource/test24.jpg"
  14. api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda"
  15. path_test = image_processing.download_img(path,api_token)
  16. # 创建保存图像的空列表
  17. imgs = []
  18. img = cv2.imread(path_test)
  19. img = cv2.resize(img, (w, h))
  20. # 将每张经过处理的图像数据保存在之前创建的imgs空列表当中
  21. imgs.append(img)
  22. imgs = np.asarray(imgs, np.float32)
  23. # print("shape of data:",imgs.shape)
  24. # 导入模型
  25. model = tf.keras.models.load_model(r"resource/rice_0.93.h5")
  26. # 创建图像标签列表
  27. rice_dict = {0: 'Rice blast', 1: 'Rice fleck',
  28. 2: 'Rice koji disease', 3: 'Sheath blight'}
  29. # 将图像导入模型进行预测
  30. prediction = model.predict_classes(imgs)
  31. # prediction = np.argmax(model.predict(imgs), axis=-1)
  32. # 绘制预测图像
  33. for i in range(np.size(prediction)):
  34. # 打印每张图像的预测结果
  35. print(rice_dict[prediction[i]])
  36. return rice_dict[prediction[0]]

为了实现图片外链接受,下面是图片下载脚本

image_processing.py

  1. # coding: utf8
  2. import requests
  3. import random
  4. def download_img(img_url, api_token):
  5. print (img_url)
  6. header = {"Authorization": "Bearer " + api_token} # 设置http header,视情况加需要的条目,这里的token是用来鉴权的一种方式
  7. r = requests.get(img_url, headers=header, stream=True)
  8. print(r.status_code) # 返回状态码
  9. file_img = 'resource/img.png'
  10. # file_img = 'resource/'
  11. print(file_img)
  12. if r.status_code == 200:
  13. open(file_img, 'wb').write(r.content) # 将内容写入图片
  14. print("done")
  15. del r
  16. return file_img
  17. # if __name__ == '__main__':
  18. # # 下载要的图片
  19. # img_url = "https://z3.ax1x.com/2021/07/27/W5l6Qe.png"
  20. # api_token = "fklasjfljasdlkfjlasjflasjfljhasdljflsdjflkjsadljfljsda"
  21. # download_img(img_url, api_token)

主程序脚本

app.py

  1. from flask import Flask,render_template, url_for, request, json,jsonify
  2. import model
  3. app = Flask(__name__)
  4. #设置编码
  5. app.config['JSON_AS_ASCII'] = False
  6. @app.route('/test')
  7. def hello_world():
  8. return "hello world"
  9. @app.route('/predict', methods=['GET', 'POST'])
  10. def form_data():
  11. my_path = request.form['path']
  12. print(my_path)
  13. str = model.model_ues(my_path)
  14. print("http://127.0.0.1:5000/predict")
  15. return jsonify({'result':str,'msg':'200'})
  16. if __name__ == '__main__':
  17. app.run()

数据解析

虽然我们能够通过postman进行测试接受到回传的结果,但是我们要怎么用java实现呢??

1.使用postman生成大致代码框架(postman生成的代码可能不能直接运行)

这里我选用的是java-okhttp的方法,但其实使用Unirest写出来的代码更加简洁易懂。

  1. public class Get_result {
  2. public String getResult(String path) throws IOException {
  3. // String path = "https://i.loli.net/2021/07/29/badDNR2OCironUf.jpg";
  4. OkHttpClient client = new OkHttpClient().newBuilder()
  5. .build();
  6. MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
  7. RequestBody body = RequestBody.create(mediaType, "path="+path);
  8. Request request = new Request.Builder()
  9. .url("http://127.0.0.1:8000/predict")
  10. .method("POST", body)
  11. .addHeader("Content-Type", "application/x-www-form-urlencoded")
  12. .build();
  13. Response response = client.newCall(request).execute();
  14. String result = response.body().string();
  15. System.out.println(result);
  16. }
  17. }
  1. {
  2. "msg": "200",
  3. "result": "Rice fleck"
  4. }

获取到json数据之后,就需要对json数据进行解析

java上的解析原理是,先按照json编写一个类,之后用Gson对接受到的数据按照这个类进行规范化

(这里可以用GsonFormatPlus插件来自动生成这个实体类)

  1. //Rice_result.java---为该json的实体类
  2. package com.guard.tool;
  3. import lombok.Data;
  4. import lombok.NoArgsConstructor;
  5. @NoArgsConstructor
  6. @Data
  7. public class Rice_result {
  8. private String msg;
  9. private String result;
  10. }

下面是数据解析代码(和上面的okhttp获取json数据的代码连起来看)

  1. //json数据解析
  2. Gson gson = new Gson();
  3. java.lang.reflect.Type type = new TypeToken<Rice_result>(){}.getType();
  4. Rice_result rice_result = gson.fromJson(result, type);
  5. System.out.println(rice_result);
  6. if("200".equals(rice_result.getMsg())){
  7. // System.out.println(rice_result.getResult());
  8. return Rice_result.convertdata(rice_result.getResult());
  9. }else {
  10. // System.out.println("获取结果出错!!");
  11. return "获取结果出错!!";
  12. }

这样的话就可以进行json数据的解析了。

图链制作

由于需要使用java发送post请求给flask的预测端口,那么就需要把本地上传的数据做成图链,把图链作为数据传给flask的预测端口,从而来接收结果。

由于前端js的知识大多遗忘,这里就选用了用java来发送一个post请求,获得回传的信息。

这里我使用的是sm.ms的图床(该图床无需登录,且速度快,算得上是一个好的选择)

  1. //sm.ms的使用方法,建议看官方文档
  2. package com.guard.tool;
  3. import com.google.gson.Gson;
  4. import com.google.gson.reflect.TypeToken;
  5. import okhttp3.*;
  6. import java.io.File;
  7. import java.io.IOException;
  8. public class CloudUpload {
  9. public String toUrl(String path) throws IOException {
  10. // String file_path = "E:/machine_learning/test8.jpg";
  11. String file_path = path;
  12. OkHttpClient client = new OkHttpClient().newBuilder()
  13. .build();
  14. MediaType mediaType = MediaType.parse("multipart/form-data");
  15. RequestBody body = new MultipartBody.Builder().setType(MultipartBody.FORM)
  16. .addFormDataPart("smfile",file_path,
  17. RequestBody.create(MediaType.parse("application/octet-stream"),
  18. new File(file_path)))
  19. .addFormDataPart("format","json")
  20. .build();
  21. Request request = new Request.Builder()
  22. .url("https://sm.ms/api/v2/upload")
  23. .method("POST", body)
  24. .addHeader("Content-Type", "multipart/form-data")
  25. .addHeader("Authorization", "TlxzRSaVJj0o7HFZOd9sgdf4Jl60RA00")
  26. //这里的user-agent和Cookie需要自己打开网站,到网站的页面去拿取
  27. .addHeader("user-agent","Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36")
  28. .addHeader("Cookie", "SMMSrememberme=42417%3A10e8e9cb5281082b493fdee73381aeb2dca0bd3d; PHPSESSID=1gjog2em3ogof23vrqi79vd41m; SM_FC=runWNk3mPIiL8mzl%2FrlEfzM940LRKjLm182cm2qDrm4%3D")
  29. .build();
  30. Response response = client.newCall(request).execute();
  31. String result = response.body().string();
  32. System.out.println(result);
  33. // String result = response.body().string();
  34. Gson gson = new Gson();
  35. java.lang.reflect.Type type = new TypeToken<Image_data>(){}.getType();
  36. Image_data imge_data = gson.fromJson(result, type);
  37. System.out.println(imge_data);
  38. if (imge_data.getSuccess()){
  39. System.out.println(imge_data.getData().getUrl());
  40. return imge_data.getData().getUrl();
  41. }
  42. else{
  43. System.out.println("图片已经上传过一次!!");
  44. System.out.println(imge_data.getImages());
  45. return imge_data.getImages();
  46. }
  47. }
  48. }

回传的json结果--这个就需要使用上面的插件来进行处理

  1. {
  2. "success": true,
  3. "code": "success",
  4. "message": "Upload success.",
  5. "data": {
  6. "file_id": 0,
  7. "width": 192,
  8. "height": 454,
  9. "filename": "test25.jpg",
  10. "storename": "xICPNzFsfth5uJk.png",
  11. "size": 124993,
  12. "path": "/2021/08/01/xICPNzFsfth5uJk.png",
  13. "hash": "2exIdQGvBru46RKMyNjg3DhCTO",
  14. "url": "https://i.loli.net/2021/08/01/xICPNzFsfth5uJk.png",
  15. "delete": "https://sm.ms/delete/2exIdQGvBru46RKMyNjg3DhCTO",
  16. "page": "https://sm.ms/image/xICPNzFsfth5uJk"
  17. },
  18. "RequestId": "9BFE9DEB-8370-44C8-A8AF-AAB2DB753A18"
  19. }

总结

以上就是我这次在小组编写<基于CNN图像分类的水稻病虫害识别>这个项目中的收获。在此记录下学习路上踩过的一些坑和一些解决方法。

TensorFlow模型部署到服务器---TensorFlow2.0的更多相关文章

  1. 【tensorflow-转载】tensorflow模型部署系列

    参考 1. tensorflow模型部署系列: 完

  2. 移动端目标识别(1)——使用TensorFlow Lite将tensorflow模型部署到移动端(ssd)之TensorFlow Lite简介

    平时工作就是做深度学习,但是深度学习没有落地就是比较虚,目前在移动端或嵌入式端应用的比较实际,也了解到目前主要有 caffe2,腾讯ncnn,tensorflow,因为工作用tensorflow比较多 ...

  3. 移动端目标识别(2)——使用TENSORFLOW LITE将TENSORFLOW模型部署到移动端(SSD)之TF Lite Developer Guide

    TF Lite开发人员指南 目录: 1 选择一个模型 使用一个预训练模型 使用自己的数据集重新训练inception-V3,MovileNet 训练自己的模型 2 转换模型格式 转换tf.GraphD ...

  4. 移动端目标识别(3)——使用TensorFlow Lite将tensorflow模型部署到移动端(ssd)之Running on mobile with TensorFlow Lite (写的很乱,回头更新一个简洁的版本)

    承接移动端目标识别(2) 使用TensorFlow Lite在移动设备上运行         在本节中,我们将向您展示如何使用TensorFlow Lite获得更小的模型,并允许您利用针对移动设备优化 ...

  5. 将训练好的Tensorflow模型部署到web应用中

    做一个简易web使用Flask是最好的选择,不仅上手快,使用也很便利.Django很强大也很好用,但一次就会创建一个项目的所需的文件,我觉得对于测试一个模型在web端有没有效果没必要用它. flask ...

  6. 吴裕雄--天生自然python TensorFlow图片数据处理:解决TensorFlow2.0 module ‘tensorflow’ has no attribute ‘python_io’

    tf.python_io出错 TensorFlow 2.0 中使用 Python_io 暂时使用如下指令: tf.compat.v1.python_io.TFRecordWriter(filename ...

  7. 一文上手TensorFlow2.0(一)

    目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU) Te ...

  8. 学习笔记TF022:产品环境模型部署、Docker镜像、Bazel工作区、导出模型、服务器、客户端

    产品环境模型部署,创建简单Web APP,用户上传图像,运行Inception模型,实现图像自动分类. 搭建TensorFlow服务开发环境.安装Docker,https://docs.docker. ...

  9. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

随机推荐

  1. Redis 性能问题分析

    在一些网络服务的系统中,Redis 的性能,可能是比 MySQL 等硬盘数据库的性能更重要的课题.比如微博,把热点微博[1],最新的用户关系,都存储在 Redis 中,大量的查询击中 Redis,而不 ...

  2. Jenkins 进阶篇 - 节点配置

    当我们使用 Jenkins 构建的项目达到一定规模后,一个 Jenkins 服务可能承受不了负载,会导致很多的构建任务堆积,严重的话还会拖垮这台服务器,导致上面的服务无法使用.例如我们公司目前在 Je ...

  3. Tkinter 吐槽之一:多线程与 UI 交互

    背景 最近想简单粗暴的用 Python 写一个 GUI 的小程序.因为 Tkinter 是 Python 自带的 GUI 解决方案,为了部署方便,就直接选择了 Tkinter. 本来觉得 GUI 发展 ...

  4. CCF CSP认证考试在线评测系统

    关于 CCF CSP 认证考试在线评测系统 CCF CSP 认证考试简介 CCF 是中国计算机学会的简称.CCF 计算机软件能力认证(简称 CCF CSP 认证考试)是 CCF 于 2014 年推出, ...

  5. js笔记22

    1.在拖拽元素的时候,如果元素的内部加了文字或者图片,拖拽效果会失灵? 浏览器会给文字和图片一个默认行为,当文字和图片被选中的时候,会有一个拖拽的效果,即使我们没有人为给他添加.所以当我们点击这个元素 ...

  6. Linux中系统时间同步ntpdate简介

    Linux服务器运行久时,系统时间就会存在一定的误差,一般情况下可以使用date命令进行时间设置,但在做数据库集群分片等操作时对多台机器的时间差是有要求的,此时就需要使用ntpdate进行时间同步.所 ...

  7. salesforce零基础学习(一百零四)Salesforce Optimizer

    本篇参考: https://admin.salesforce.com/blog/2017/analyzing-org-salesforce-optimizer-webinar-recap 假设你在做一 ...

  8. Gym 100008E Harmonious Matrices 高斯消元

    POJ 1222 高斯消元更稳 看这个就懂了 #include <bits/stdc++.h> using namespace std; const int maxn = 2000; in ...

  9. nginx限流模块(防范DDOS攻击)

    Nginx限流模式(防范DDOS攻击) nginx中俩个限流模块: 1.ngx_http_limit_req_module(按请求速率限流) 2.ngx_http_limit_conn_module( ...

  10. php+redis实现全页缓存系统

    php redis 实现全页缓存系统之前的一个项目说的一个功能,需要在后台预先存入某个页面信息放到数据库,比如app的注册协议,用户协议,这种.然后在写成一个php页面,app在调用接口的时候访问这个 ...