本文是基于TensorRT 5.0.2基础上,关于其内部的end_to_end_tensorflow_mnist例子的分析和介绍。

1 引言

假设当前路径为:

TensorRT-5.0.2.6/samples

其对应当前例子文件目录树为:

# tree python

python
├── common.py
├── end_to_end_tensorflow_mnist
│   ├── model.py
│   ├── README.md
│   ├── requirements.txt
│   └── sample.py

2 基于tensorflow生成模型

其中只有2个文件:

  • model:该文件包含简单的训练模型代码
  • sample:该文件使用UFF mnist模型去创建一个TensorRT inference engine

首先介绍下model.py

# 该脚本包含一个简单的模型训练过程
import tensorflow as tf
import numpy as np '''main中第一步:获取数据集 '''
def process_dataset(): # 导入mnist数据集
# 手动下载aria2c -x 16 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
# 将mnist.npz移动到~/.keras/datasets/
# tf.keras.datasets.mnist.load_data会去读取~/.keras/datasets/mnist.npz,而不从网络下载
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 # Reshape
NUM_TRAIN = 60000
NUM_TEST = 10000
x_train = np.reshape(x_train, (NUM_TRAIN, 28, 28, 1))
x_test = np.reshape(x_test, (NUM_TEST, 28, 28, 1))
return x_train, y_train, x_test, y_test '''main中第二步:构建模型 '''
def create_model(): model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=[28,28, 1]))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
return model '''main中第五步:模型存储 '''
def save(model, filename): output_names = model.output.op.name
sess = tf.keras.backend.get_session() # freeze graph
frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_names]) # 移除训练的节点
frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph) # 保存模型
with open(filename, "wb") as ofile:
ofile.write(frozen_graph.SerializeToString()) def main(): ''' 1 - 获取数据'''
x_train, y_train, x_test, y_test = process_dataset() ''' 2 - 构建模型'''
model = create_model() ''' 3 - 模型训练'''
model.fit(x_train, y_train, epochs = 5, verbose = 1) ''' 4 - 模型评估'''
model.evaluate(x_test, y_test) ''' 5 - 模型存储'''
save(model, filename="models/lenet5.pb") if __name__ == '__main__':
main()

在获得

models/lenet5.pb

之后,执行下述命令,将其转换成uff文件,输出结果如

'''该converter会显示关于input/output nodes的信息,这样你就可以用来在解析的时候进行注册;
本例子中,我们基于tensorflow.keras的命名规则,事先已知input/output nodes名称了 ''' [root@30d4bceec4c4 end_to_end_tensorflow_mnist]# convert-to-uff models/lenet5.pb
Loading models/lenet5.pb

3 基于tensorflow的pb文件生成UFF并处理

# 该例子使用UFF MNIST 模型去创建一个TensorRT Inference Engine
from random import randint
from PIL import Image
import numpy as np import pycuda.driver as cuda
import pycuda.autoinit # 该import会让pycuda自动管理CUDA上下文的创建和清理工作 import tensorrt as trt import sys, os
# sys.path.insert(1, os.path.join(sys.path[0], ".."))
# import common # 这里将common中的GiB和find_sample_data,allocate_buffers,do_inference等函数移动到该py文件中,保证自包含。
def GiB(val):
'''以GB为单位,计算所需要的存储值,向左位移10bit表示KB,20bit表示MB '''
return val * 1 << 30 def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):
'''该函数就是一个参数解析函数。
Parses sample arguments.
Args:
description (str): Description of the sample.
subfolder (str): The subfolder containing data relevant to this sample
find_files (str): A list of filenames to find. Each filename will be replaced with an absolute path.
Returns:
str: Path of data directory.
Raises:
FileNotFoundError
'''
# 为了简洁,这里直接将路径硬编码到代码中。
data_root = kDEFAULT_DATA_ROOT = os.path.abspath("/TensorRT-5.0.2.6/python/data/") subfolder_path = os.path.join(data_root, subfolder)
if not os.path.exists(subfolder_path):
print("WARNING: " + subfolder_path + " does not exist. Using " + data_root + " instead.")
data_path = subfolder_path if os.path.exists(subfolder_path) else data_root if not (os.path.exists(data_path)):
raise FileNotFoundError(data_path + " does not exist.") for index, f in enumerate(find_files):
find_files[index] = os.path.abspath(os.path.join(data_path, f))
if not os.path.exists(find_files[index]):
raise FileNotFoundError(find_files[index] + " does not exist. ") if find_files:
return data_path, find_files
else:
return data_path
#----------------- TRT_LOGGER = trt.Logger(trt.Logger.WARNING) class ModelData(object):
MODEL_FILE = os.path.join(os.path.dirname(__file__), "models/lenet5.uff")
INPUT_NAME ="input_1"
INPUT_SHAPE = (1, 28, 28)
OUTPUT_NAME = "dense_1/Softmax" '''main中第二步:构建engine'''
def build_engine(model_file): with trt.Builder(TRT_LOGGER) as builder, \
builder.create_network() as network, \
trt.UffParser() as parser: builder.max_workspace_size = GiB(1) # 解析 Uff 网络
parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE)
parser.register_output(ModelData.OUTPUT_NAME)
parser.parse(model_file, network) # 构建并返回一个engine
return builder.build_cuda_engine(network) '''main中第三步 '''
def allocate_buffers(engine): inputs = []
outputs = []
bindings = []
stream = cuda.Stream() for binding in engine: size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
dtype = trt.nptype(engine.get_binding_dtype(binding)) # 分配host和device端的buffer
host_mem = cuda.pagelocked_empty(size, dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes) # 将device端的buffer追加到device的bindings.
bindings.append(int(device_mem)) # Append to the appropriate list.
if engine.binding_is_input(binding):
inputs.append(HostDeviceMem(host_mem, device_mem))
else:
outputs.append(HostDeviceMem(host_mem, device_mem)) return inputs, outputs, bindings, stream '''main中第四步 '''
# 从pagelocked_buffer.中读取测试样本
def load_normalized_test_case(data_path, pagelocked_buffer, case_num=randint(0, 9)): test_case_path = os.path.join(data_path, str(case_num) + ".pgm") # Flatten该图像成为一个1维数组,然后归一化,并copy到host端的 pagelocked内存中.
img = np.array(Image.open(test_case_path)).ravel()
np.copyto(pagelocked_buffer, 1.0 - img / 255.0) return case_num '''main中第五步:执行inference '''
# 该函数可以适应多个输入/输出;输入和输出格式为HostDeviceMem对象组成的列表
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1): # 将数据移动到GPU
[cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] # 执行inference.
context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle) # 将结果从 GPU写回到host端
[cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] # 同步stream
stream.synchronize() # 返回host端的输出结果
return [out.host for out in outputs] def main(): ''' 1 - 寻找模型文件'''
data_path = find_sample_data(
description="Runs an MNIST network using a UFF model file",
subfolder="mnist")
model_file = ModelData.MODEL_FILE ''' 2 - 基于build_engine函数构建engine'''
with build_engine(model_file) as engine: ''' 3 - 分配buffer并创建一个流'''
inputs, outputs, bindings, stream = allocate_buffers(engine) with engine.create_execution_context() as context: ''' 4 - 读取测试样本,并归一化'''
case_num = load_normalized_test_case(data_path, pagelocked_buffer=inputs[0].host) ''' 5 - 执行inference,do_inference函数会返回一个list类型,此处只有一个元素'''
[output] = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream) pred = np.argmax(output)
print("Test Case: " + str(case_num))
print("Prediction: " + str(pred)) if __name__ == '__main__':
main()

结果如:

TensorRT&Sample&Python[end_to_end_tensorflow_mnist]的更多相关文章

  1. TensorRT&Sample&Python[yolov3_onnx]

    本文是基于TensorRT 5.0.2基础上,关于其内部的yolov3_onnx例子的分析和介绍. 本例子展示一个完整的ONNX的pipline,在tensorrt 5.0的ONNX-TensorRT ...

  2. TensorRT&Sample&Python[uff_custom_plugin]

    本文是基于TensorRT 5.0.2基础上,关于其内部的uff_custom_plugin例子的分析和介绍. 本例子展示如何使用cpp基于tensorrt python绑定和UFF解析器进行编写pl ...

  3. TensorRT&Sample&Python[fc_plugin_caffe_mnist]

    本文是基于TensorRT 5.0.2基础上,关于其内部的fc_plugin_caffe_mnist例子的分析和介绍. 本例子相较于前面例子的不同在于,其还包含cpp代码,且此时依赖项还挺多.该例子展 ...

  4. TensorRT&Sample&Python[network_api_pytorch_mnist]

    本文是基于TensorRT 5.0.2基础上,关于其内部的network_api_pytorch_mnist例子的分析和介绍. 本例子直接基于pytorch进行训练,然后直接导出权重值为字典,此时并未 ...

  5. TensorRT&Sample&Python[introductory_parser_samples]

    本文是基于TensorRT 5.0.2基础上,关于其内部的introductory_parser_samples例子的分析和介绍. 1 引言 假设当前路径为: TensorRT-5.0.2.6/sam ...

  6. tensorRT使用python进行网络定义

  7. Python API vs C++ API of TensorRT

    Python API vs C++ API of TensorRT 本质上,C++ API和Python API应该在支持您的需求方面接近相同.pythonapi的主要优点是数据预处理和后处理都很容易 ...

  8. (原)pytorch中使用TensorRT

    转载请注明出处: https://www.cnblogs.com/darkknightzh/p/11332155.html 代码网址: https://github.com/darkknightzh/ ...

  9. 基于TensorRT 3的自动驾驶快速INT8推理

    基于TensorRT 3的自动驾驶快速INT8推理 Fast INT8 Inference for Autonomous Vehicles with TensorRT 3 自主驾驶需要安全性,需要一种 ...

随机推荐

  1. mysql数据库备份并且实现远程复制

    一.实现ssh 远程登陆 机器环境: 192.167.33.108 clent 用户:crawler 192.167.33.77 server 用户:crawler 1.客户端 生成密钥 /home/ ...

  2. Springboot 系列(三)Spring Boot 自动配置原理

    注意:本 Spring Boot 系列文章基于 Spring Boot 版本 v2.1.1.RELEASE 进行学习分析,版本不同可能会有细微差别. 前言 关于配置文件可以配置的内容,在 Spring ...

  3. Runloop详解

    RunLoop是iOS和OSX开发中非常基础的一个概念,这篇文章将从源码以及应用入手,介绍RunLoop的概念以及底层实现原理.本人看了一下RunLoop的英语源码,以及借鉴部分优秀博客,感谢!读完这 ...

  4. Java开发笔记(十七)各得其所的多路分支

    前面提到条件语句的标准格式为“if (条件) { /* 条件成立时的操作代码 */ } else { /* 条件不成立时的操作代码 */ }”,乍看之下仿佛只有两个分支,一个是条件成立时的分支,另一个 ...

  5. Acer宏碁笔记本触摸板失效解决方法

    打开windows设置,找到鼠标设置 之后,选择触摸板设置,将其开启 PS: 由于我是安装完驱动之后,才发现有这个触摸板设置的 如果找不到这个触摸板设置的哈,应该就是驱动安装完之后就有了 驱动的话去官 ...

  6. Java基础:HashMap假死锁问题的测试、分析和总结

    前言 前两天在公司的内部博客看到一个同事分享的线上服务挂掉CPU100%的文章,让我联想到HashMap在不恰当使用情况下的死循环问题,这里做个整理和总结,也顺便复习下HashMap. 直接上测试代码 ...

  7. localhost和127.0.01 区别

    笔者最近调试程序时遇到的一个问题,localhost不能访问但127.0.0.1可以访问. 一.原理 我估计大多数人都不会去想localhost到底与127.0.0.1有什么不同,就比如我,有时候用h ...

  8. (二)阿里云ECS Linux服务器外网无法连接MySQL解决方法(报错2003- Can't connect MySQL Server on 'x.x.x.x'(10038))(自己亲身遇到的问题是防火墙的问题已经解决)

    我的服务器买的是阿里云ECS linux系统.为了更好的操作数据库,我希望可以用navicat for mysql管理我的数据库. 当我按照正常的模式去链接mysql的时候, 报错提示: - Can' ...

  9. vue项目利用apicloud打包成apk过程

    最近公司要求我们用apicloud做一个app,正好利用这个机会学习下app的制作过程~ 页面的开发过程跟我们平时开发一样,利用vue把页面全部完成,最后进行npm run build将项目打包. 接 ...

  10. vue+axios 前端实现的常用拦截

    一.路由拦截使用 首先在定义路由的时候就需要多添加一个自定义字段requireAuth,用于判断该路由的访问是否需要登录.如果用户已经登录,则顺利进入路由,否则就进入登录页面,路由配置如下: cons ...