在本教程中,我们将使用Flask来部署PyTorch模型,并用讲解用于模型推断的 REST API。特别是,我们将部署一个预训练的DenseNet 121模

型来检测图像。

备注:

可在GitHub上获取本文用到的完整代码

这是在生产中部署PyTorch模型的系列教程中的第一篇。到目前为止,以这种方式使用Flask是开始为PyTorch模型提供服务的最简单方法,

但不适用于具有高性能要求的用例。因此:

1.定义API

我们将首先定义API端点、请求和响应类型。我们的API端点将位于/ predict,它接受带有包含图像的file参数的HTTP POST请求。响应

将是包含预测的JSON响应:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

2.依赖(包)

运行下面的命令来下载我们需要的依赖:

$ pip install Flask==1.0.3 torchvision-0.3.0

3.简单的Web服务器

以下是一个简单的Web服务器,摘自Flask文档

from flask import Flask
app = Flask(__name__) @app.route('/')
def hello():
return 'Hello World!'

将以上代码段保存在名为app.py的文件中,您现在可以通过输入以下内容来运行Flask开发服务器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

当您在web浏览器中访问http://localhost:5000/时,您会收到文本Hello World的问候!

我们将对以上代码片段进行一些更改,以使其适合我们的API定义。首先,我们将重命名predict方法。我们将端点路径更新为/predict

由于图像文件将通过HTTP POST请求发送,因此我们将对其进行更新,使其也仅接受POST请求:

@app.route('/predict', methods=['POST'])
def predict():
return 'Hello World!'

我们还将更改响应类型,以使其返回包含ImageNet类的id和name的JSON响应。更新后的app.py文件现在为:

from flask import Flask, jsonify
app = Flask(__name__) @app.route('/predict', methods=['POST'])
def predict():
return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

4.推理

在下一部分中,我们将重点介绍编写推理代码。这将涉及两部分,第一部分是准备图像,以便可以将其馈送到DenseNet;第二部分,我们将编

写代码以从模型中获取实际的预测。

4.1 准备图像

DenseNet模型要求图像为尺寸为224 x 224的 3 通道RGB图像。我们还将使用所需的均值和标准偏差值对图像张量进行归一化。你可以点击

这里来了解更多关于它的内容。

我们将使用来自torchvision库的transforms来建立转换管道,该转换管道可根据需要转换图像。您可以在此处

阅读有关转换的更多信息。

import io

import torchvision.transforms as transforms
from PIL import Image def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)

上面的方法以字节为单位获取图像数据,应用一系列变换并返回张量。要测试上述方法,请以字节模式读取图像文件(首先将…/_static/img/

sample_file.jpeg替换为计算机上文件的实际路径),然后查看是否获得了张量:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
tensor = transform_image(image_bytes=image_bytes)
print(tensor)
  • 输出结果:
tensor([[[[ 0.4508,  0.4166,  0.3994,  ..., -1.3473, -1.3302, -1.3473],
[ 0.5364, 0.4851, 0.4508, ..., -1.2959, -1.3130, -1.3302],
[ 0.7077, 0.6392, 0.6049, ..., -1.2959, -1.3302, -1.3644],
...,
[ 1.3755, 1.3927, 1.4098, ..., 1.1700, 1.3584, 1.6667],
[ 1.8893, 1.7694, 1.4440, ..., 1.2899, 1.4783, 1.5468],
[ 1.6324, 1.8379, 1.8379, ..., 1.4783, 1.7352, 1.4612]], [[ 0.5728, 0.5378, 0.5203, ..., -1.3704, -1.3529, -1.3529],
[ 0.6604, 0.6078, 0.5728, ..., -1.3004, -1.3179, -1.3354],
[ 0.8529, 0.7654, 0.7304, ..., -1.3004, -1.3354, -1.3704],
...,
[ 1.4657, 1.4657, 1.4832, ..., 1.3256, 1.5357, 1.8508],
[ 2.0084, 1.8683, 1.5182, ..., 1.4657, 1.6583, 1.7283],
[ 1.7458, 1.9384, 1.9209, ..., 1.6583, 1.9209, 1.6408]], [[ 0.7228, 0.6879, 0.6531, ..., -1.6476, -1.6302, -1.6476],
[ 0.8099, 0.7576, 0.7228, ..., -1.6476, -1.6476, -1.6650],
[ 1.0017, 0.9145, 0.8797, ..., -1.6476, -1.6650, -1.6999],
...,
[ 1.6291, 1.6291, 1.6465, ..., 1.6291, 1.8208, 2.1346],
[ 2.1868, 2.0300, 1.6814, ..., 1.7685, 1.9428, 2.0125],
[ 1.9254, 2.0997, 2.0823, ..., 1.9428, 2.2043, 1.9080]]]])

4.2 预测

现在将使用预训练的DenseNet 121模型来预测图像的类别。我们将使用torchvision库中的一个库,加载模型并进行推断。在此示例中,我们

将使用预训练的模型,但您可以对自己的模型使用相同的方法。在这个教程

中了解有关加载模型的更多信息。

from torchvision import models

# 确保使用`pretrained`作为`True`来使用预训练的权重:
model = models.densenet121(pretrained=True)
# 由于我们仅将模型用于推理,因此请切换到“eval”模式:
model.eval() def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
return y_hat

张量y_hat将包含预测的类的id的索引。但是,我们需要一个易于阅读的类名。为此,我们需要一个类id来命名映射。将该文件

下载为imagenet_class_index.json并记住它的保存位置(或者,如果您按照本教程中的确切步骤操作,请将其保存在tutorials/_static中)。

此文件包含ImageNet类的id到ImageNet类的name的映射。我们将加载此JSON文件并获取预测索引的类的name。

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]

在使用字典imagenet_class_index之前,首先我们将张量值转换为字符串值,因为字典imagenet_class_index中的keys是字符串。我们将

测试上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
print(get_prediction(image_bytes=image_bytes))
  • 输出结果:
['n02124075', 'Egyptian_cat']

你会得到这样的一个响应:

['n02124075', 'Egyptian_cat']

数组中的第一项是ImageNet类的id,第二项是人类可读的name。

注意:您是否注意到模型变量不是get_prediction方法的一部分?或者为什么模型是全局变量?就内存和计算而言,加载模型可能是

一项昂贵的操作。如果将模型加载到get_prediction方法中,则每次调用该方法时都会不必要地加载该模型。由于我们正在构建Web服务

器,因此每秒可能有成千上万的请求,因此我们不应该浪费时间为每个推断重复加载模型。因此,我们仅将模型加载到内存中一次。在生

产系统中,必须有效利用计算以能够大规模处理请求,因此通常应在处理请求之前加载模型。

5.将模型集成到我们的API服务器中

在最后一部分中,我们将模型添加到Flask API服务器中。由于我们的API服务器应该获取图像文件,因此我们将更新predict方法以从请求中

读取文件:

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
# 从请求中获得文件
file = request.files['file']
# 转化为字节
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})

app.py文件现已完成。以下是完整版本;将路径替换为保存文件的路径,它的运行应是如下:

import io
import json from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval() def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0) def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx] @app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name}) if __name__ == '__main__':
app.run()

让我们测试一下我们的web服务器,运行:

$ FLASK_ENV=development FLASK_APP=app.py flask run

我们可以使用requests库来发送一个POST请求到我们的app:

import requests

resp = requests.post("http://localhost:5000/predict",
files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

打印resp.json()会显示下面的结果:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

6.下一步工作

我们编写的服务器非常琐碎,可能无法完成生产应用程序所需的一切。因此,您可以采取一些措施来改善它:

  • 端点/predict假定请求中总会有一个图像文件。这可能不适用于所有请求。我们的用户可能发送带有其他参数的图像,或者根本不发送任何图像。
  • 用户也可以发送非图像类型的文件。由于我们没有处理错误,因此这将破坏我们的服务器。添加显式的错误处理路径来引发异常,这将使我们

    能够更好地处理错误的输入
  • 即使模型可以识别大量类别的图像,也可能无法识别所有图像。增强实现以处理模型无法识别图像中的任何情况的情况。
  • 我们在开发模式下运行Flask服务器,该服务器不适合在生产中进行部署。您可以查看教程

    以在生产环境中部署Flask服务器。
  • 您还可以通过创建一个带有表单的页面来添加UI,该表单可以拍摄图像并显示预测。查看类似项目的演示及其源代码
  • 在本教程中,我们仅展示了如何构建可以一次返回单个图像预测的服务。我们可以修改服务以能够一次返回多个图像的预测。此外,service-streamer

    库自动将对服务的请求排队,并将它们采样到可用于模型的min-batches中。您可以查看此教程
  • 最后,我们鼓励您在页面顶部查看链接到的有关部署PyTorch模型的其他教程。

欢迎关注磐创博客资源汇总站:

http://docs.panchuang.net/

欢迎关注PyTorch官方中文教程站:

http://pytorch.panchuang.net/

OpenCV中文官方文档:

http://woshicver.com/

通过带Flask的REST API在Python中部署PyTorch的更多相关文章

  1. Python+Flask搭建mock api server

    Python+Flask搭建mock api server 前言: 近期由于工作需要,需要一个Mock Server调用接口直接返回API结果: 假如可以先通过接口文档的定义,自己模拟出服务器返回结果 ...

  2. Designing a RESTful API with Python and Flask 201

    rest服务器的搭建 - CSDN博客 http://blog.csdn.net/zhanghaotian2011/article/details/8760794 REST的架构设计 REST(Rep ...

  3. 使用 Python 和 Flask 设计 RESTful API

    近些年来 REST (REpresentational State Transfer) 已经变成了 web services 和 web APIs 的标配. 在本文中我将向你展示如何简单地使用 Pyt ...

  4. 真正搞明白Python中Django和Flask框架的区别

    在谈Python中Django框架和Flask框架的区别之前,我们需要先探讨如下几个问题. 一.为什么要使用框架? 为了更好地阐述这个问题,我们把开发一个应用的过程进行类比,往往开发一个应用(web应 ...

  5. flask开发restful api系列(8)-再谈项目结构

    上一章,我们讲到,怎么用蓝图建造一个好的项目,今天我们继续深入.上一章中,我们所有的接口都写在view.py中,如果几十个,还稍微好管理一点,假如上百个,上千个,怎么找?所有接口堆在一起就显得杂乱无章 ...

  6. flask开发restful api系列(7)-蓝图与项目结构

    如果有几个原因可以让你爱上flask这个极其灵活的库,我想蓝图绝对应该算上一个,部署蓝图以后,你会发现整个程序结构非常清晰,模块之间相互不影响.蓝图对restful api的最明显效果就是版本控制:而 ...

  7. flask开发restful api

    flask开发restful api 如果有几个原因可以让你爱上flask这个极其灵活的库,我想蓝图绝对应该算上一个,部署蓝图以后,你会发现整个程序结构非常清晰,模块之间相互不影响.蓝图对restfu ...

  8. 从底层带你理解Python中的一些内部机制

    下面博文将带你创建一个字节码级别的追踪API以追踪Python的一些内部机制,比如类似YIELDVALUE.YIELDFROM操作码的实现,推式构造列表(List Comprehensions).生成 ...

  9. 学习参考《Flask Web开发:基于Python的Web应用开发实战(第2版)》中文PDF+源代码

    在学习python Web开发时,我们会选择使用Django.flask等框架. 在学习flask时,推荐学习看看<Flask Web开发:基于Python的Web应用开发实战(第2版)> ...

随机推荐

  1. 数据大爆炸:KDD 2016

    2016"> 饕餮盛宴 ACM SIGKDD国际会议(简称KDD)是数据挖掘领域的顶级国际会议,由ACM (计算机协会)的数据挖掘及知识发现专委会(SIGKDD)负责组织筹办.在现今如 ...

  2. Redis list实现原理 - 双向循环链表

    双向链表 双向表示每个节点知道自己的直接前驱和直接后继,每个节点需要三个域 查找方向可以是从左往右也可以是从右往左,但是要实现从右往左还需要终端节点的地址,所以通常会设计成双向的循环链表; 双向的循环 ...

  3. Blind Estimation and Detection of Space-Time Trellis Coded Transmissions over the Rayleigh Fading MIMO Channel

    目录 文章来源 摘要 基本概念 粒子滤波 时间序列模型 系统模型 通信系统 经典状态空间表示 论文所提出的状态空间表示 借鉴之处 文章来源 IEEE TRANSACTIONS ON COMMUNICA ...

  4. Spring Boot从入门到精通(六)集成Redis实现缓存机制

    Redis(Remote Dictionary Server ),即远程字典服务,是一个开源的使用ANSI C语言编写.支持网络.可基于内存亦可持久化的日志型.Key-Value数据库,并提供多种语言 ...

  5. JDK 1.8 新特性之Date-Time API

    来源:请点击查看 1.8之前的日期类: 线程不安全:java.util.Date 这个类线程不安全,而且所有日期类都是可变的. 时间处理麻烦:默认的开始日期从1900年,不支持国际化,不提供时区支持, ...

  6. SpringBoot入门系列(三)资源文件属性配置

    前面介绍了Spring的@Controller和@RestController控制器, 他们是如何响应客户端请求,如何返回json数据.不清楚的朋友可以看看之前的文章:https://www.cnbl ...

  7. 【30分钟学完】canvas动画|游戏基础(6):坐标旋转探究

    前言 本篇主要讲坐标旋转及其应用,这是编程动画必不可少的技术. 阅读本篇前请先打好前面的基础. 本人能力有限,欢迎牛人共同讨论,批评指正. 坐标旋转 模拟场景:已知一个中心点(centerX,cent ...

  8. Django进行数据迁移时,报错:(1064, "You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near '(6) NOT NULL)' at line 1")

    进行数据迁移时: 第一步: 命令:python manage.py makemigrations 在对应的应用里面的migrations文件夹中产生了一个0001_initial.py文件 第二步:执 ...

  9. 学以致用:手把手教你撸一个工具库并打包发布,顺便解决JS浮点数计算精度问题

    本文讲解的是怎么实现一个工具库并打包发布到npm给大家使用.本文实现的工具是一个分数计算器,大家考虑如下情况: \[ \sqrt{(((\frac{1}{3}+3.5)*\frac{2}{9}-\fr ...

  10. VUE一 基础语法

    一.模板语法 二.Class和Style绑定 三.条件渲染 四.vue事件处理器 五.vue组件