通过类来实现多session 运行
#xilerihua
import tensorflow as tf
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import sys #objectlocation
import six.moves.urllib as urllib
import tarfile
import matplotlib
matplotlib.use('Agg')
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from utils import label_map_util
from utils import visualization_utils as vis_util
import time class multi():
"""初始化所有模型"""
def __init__(self):
# 加载faster_rcnn 计算图
self.faster_graph = tf.Graph()
with self.faster_graph.as_default():
self.od_graph_def2 = tf.GraphDef()
with tf.gfile.GFile(r'E:/Project/TaoBaoLocation_new/research/object_detection/faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb', 'rb') as fid:
self.serialized_graph = fid.read()
self.od_graph_def2.ParseFromString(self.serialized_graph)
tf.import_graph_def(self.od_graph_def2, name='')
self.faster_sess = tf.Session(graph=self.faster_graph) # 加载inception_v3计算图
self.inception_graph = tf.Graph()
with self.inception_graph.as_default():
self.od_graph_def2 = tf.GraphDef()
with tf.gfile.GFile(r'E:/Project/XiLeRiHuaReg/inception_v3_model/output_graph.pb', 'rb') as fid:
self.serialized_graph = fid.read()
self.od_graph_def2.ParseFromString(self.serialized_graph)
tf.import_graph_def(self.od_graph_def2, name='')
self.inception_sess = tf.Session(graph=self.inception_graph) def get_result(self, type, image_path):
if type == '2':
#xilerihua
lines = tf.gfile.GFile('E:/Project/XiLeRiHuaReg/inception_v3_model/output_labels.txt').readlines()
uid_to_human = {}
for uid, line in enumerate(lines):
line = line.strip('\n')
uid_to_human[uid] = line def id_to_string(node_id):
if node_id not in uid_to_human:
return ''
return uid_to_human[node_id] softmax_tensor = self.inception_sess.graph.get_tensor_by_name('final_result:0') image_data = tf.gfile.GFile(image_path, 'rb').read()
predictions = self.inception_sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions) # image_path = os.path.join(sys.argv[2]) top_k = predictions.argsort()[::-1][:1] # 取前k个,此处取最相似的那个 for node_id in top_k: # 只取第一个
human_string = id_to_string(node_id)
score = predictions[node_id] human_kanji = {
'baby wipes': '婴儿湿巾',
'bath towel': '洗澡巾',
'convenient toothpick box': '便捷牙具盒',
'dish rack': '沥水架',
'hooks4': '挂钩粘钩4个装',
'kitchen towel': '厨房方巾',
'towel': '毛巾',
'macaron basin': '马卡龙家用多用盆',
'multi functional dental box': '多功能牙具盒',
'paring knife': '削皮刀',
'pineapple towel set': '菠萝纹毛巾浴巾套装',
'rubbish bag': '垃圾袋',
'sponge': '清洁海绵',
'stainless hook': '不锈钢多用挂钩',
'storage boxes': '三格储物盒',
'towel set': '毛巾浴巾套装',
'usb cable': '数据线',
'liquor': '劲酒'
}
thres = 0.6
if score < thres:
print('不在17个范围之内')
elif human_kanji[human_string] == '劲酒':
print('不在17个范围之内')
else:
print(human_kanji[human_string]) if type == '1': # List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') NUM_CLASSES = 90 ##################### Loading label map
# print('Loading label map...')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
use_display_name=True)
category_index = label_map_util.create_category_index(categories) ##################### Helper code
def load_image_into_numpy_array(image):
(im_width, im_height) = image.size
return np.array(image.getdata()).reshape(
(im_height, im_width, 3)).astype(np.uint8) ##################### Detection
# 测试图片的路径,可以根据自己的实际情况修改
# TEST_IMAGE_PATH = 'test_images/image1.jpg'
TEST_IMAGE_PATH = image_path
# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8) # with tf.Session(graph=self.faster_graph) as self.faster_sess:
# print(TEST_IMAGE_PATH)
image = Image.open(TEST_IMAGE_PATH)
image_np = load_image_into_numpy_array(image)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = self.faster_graph.get_tensor_by_name('image_tensor:0')
boxes = self.faster_graph.get_tensor_by_name('detection_boxes:0')
scores = self.faster_graph.get_tensor_by_name('detection_scores:0')
classes = self.faster_graph.get_tensor_by_name('detection_classes:0')
num_detections = self.faster_graph.get_tensor_by_name('num_detections:0') # Actual detection.
(boxes, scores, classes, num_detections) = self.faster_sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded}) scores = np.squeeze(scores)
scores = scores.reshape((100, 1))
boxes = np.squeeze(boxes)
res = np.hstack((boxes, scores)) # 筛选>thres的box
thres = 0.55
reserve_boxes_0 = []
for b in res:
if b[-1]>thres:
reserve_boxes_0.append(b.tolist()) # print('reserve_boxes_0:',reserve_boxes_0) #转换坐标
reserve_boxes=[]
w = image_np.shape[1] # 1,3乘 1024
h = image_np.shape[0] # 0,2乘 636
# print('h:',h,'w:',w) for box in reserve_boxes_0:
# print([int(float(box[0]*h)),int(float(box[2]*h)),int(float(box[1]*w)),int(float(box[3]*w))],'tran')
# reserve_boxes.append([int(float(box[0]*h)),int(float(box[2]*h)),int(float(box[1]*w)),int(float(box[3]*w))])
reserve_boxes.append([int(float(box[1]*w)),int(float(box[0]*h)),int(float(box[3]*w)),int(float(box[2]*h))]) # print('reserve_boxes:',reserve_boxes) #没有找到一个框的情况
if len(reserve_boxes)==0:#为0的情况,裁剪返回图片坐标
w_subtract = int(image_np.shape[1] / 10)
h_subtract = int(image_np.shape[0] / 10)
print(w_subtract, h_subtract, image_np.shape[1] - w_subtract, image_np.shape[0] - h_subtract)
else:
# 保留最靠近中间的那个框的情况
# print('w:',image_np.shape[1],'h:',image_np.shape[0])
# 1.计算图片的中心点
# y:im.shape[0],x:im.shape[1]
x_center, y_center = image_np.shape[1] / 2, image_np.shape[0] / 2
# print(x_center,y_center) # 2 计算找出来的框到中心点的距离
dis_l = []
for b in reserve_boxes:
b_xcenter, b_ycenter = int((b[0] + b[2]) / 2), int((b[1] + b[3]) / 2)
distance = np.sqrt((x_center - b_xcenter) ** 2 + (y_center - b_ycenter) ** 2)
dis_l.append(distance)
# print('b_xcenter,b_ycenter:',b_xcenter,b_ycenter,distance) # 拿到最靠中心的box的index
center_index = dis_l.index(min(dis_l))
det = reserve_boxes[center_index]
print(det[0],det[1],det[2],det[3]) #可视化1
# cv2.rectangle(image_np, (det[0], det[1]), (det[2], det[3]), thickness=2, color=(0, 0, 255))
# cv2.imshow('res',image_np)
# cv2.waitKey(0)
# cv2.destroyAllWindows() #初始化
multi = multi() for i in range(5):
start_t=time.time()
multi.get_result("1","1.jpg")
end_t=time.time()
print('t1:',end_t-start_t)
multi.get_result("2","1.jpg")
start_t3=time.time()
print('t2:',start_t3-end_t)
通过类来实现多session 运行的更多相关文章
- C++ //多态 //静态多态:函数重载 和 运算符重载 属于静态多态 ,复用函数名 //动态多态:派生类和虚函数实现运行时多态
1 //多态 2 //静态多态:函数重载 和 运算符重载 属于静态多态 ,复用函数名 3 //动态多态:派生类和虚函数实现运行时多态 4 5 //静态多态和动态多态的区别 6 //静态多态的函数地址早 ...
- 教你在Java的普通类中轻松获取Session以及request中保存的值
曾经有多少人因为不知如何在业务类中获取自己在Action或页面上保存在Session中值,当然也包括我,但是本人已经学到一种办法可以解决这个问题,来分享下,希望对你有多多少少的帮助! 如何在Java的 ...
- [转载]tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定
tf.ConfigProto()函数用在创建session的时候,用来对session进行参数配置: config = tf.ConfigProto(allow_soft_placement=True ...
- 深度解剖session运行原理
已经大半年没有更新博客了,一方面有比博客更重要的事情要做,另外一方面也没有时间来整理知识,所以希望在接下来的日子里面能够多多的写博客来与大家交流 什么是session session的官方定义是:Se ...
- [转] spring的普通类中如何取session和request对像
在使用spring时,经常需要在普通类中获取session,request等对像.比如一些AOP拦截器类,在有使用struts2时,因为struts2有一个接口使用org.apache.struts2 ...
- jeecg中的一个上下文工具类获取request,session
通过调用其中的方法可以获取到request和session,调用方式如下: HttpServletRequest request = ContextHolderUtils.getRequest();H ...
- spring的普通类中如何取session和request对像
在使用spring时,经常需要在普通类中获取session,request等对像. 比如一些AOP拦截器类,在有使用struts2时,因为struts2有一个接口使用org.apache.struts ...
- tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定
tf.ConfigProto()函数用在创建session的时候,用来对session进行参数配置: config = tf.ConfigProto(allow_soft_placement=True ...
- 使用tf.ConfigProto()配置Session运行参数和GPU设备指定
参考链接:https://blog.csdn.net/dcrmg/article/details/79091941 tf.ConfigProto()函数用在创建session的时候,用来对sessio ...
随机推荐
- Excel催化剂开源第15波-VSTO开发之DataTable数据导出至单元格区域
上篇提到如何从Excel界面上拿到用户的数据,另外反方向的怎样输出给用户数据,也是关键之处. VSTO最大的优势是,这双向的过程中,全程有用户的交互操作. 而一般IT型的程序,都是脱离用户的操作,只能 ...
- 不同版本2.5的Servlet web.xml 头信息
<?xml version="1.0" encoding="UTF-8"?> <web-app version="2.5" ...
- 【git】Git的使用
一.安装git 1.windows下安装一个Git 2.lInux下yum(apt-get) install git 二.使用git连接github 使用git连接github时,需要将linux下产 ...
- 《VR入门系列教程》之1---预热篇
序 初识虚拟现实技术,非常倾心,奋力习之,阅<Learning Virtual Reality>一书之后觉得甚好,但不愿独乐乐,于是翻译之,与大家共同学习.本人学艺不精,难免有翻 ...
- 第二篇:"空空如也"的博客应用
文中涉及的示例代码,已同步更新到 HelloGitHub-Team 仓库 建立博客应用 我们已经建立了 django 博客的项目工程,并且成功地运行了它.不过到目前为止这一切都还只是 django 为 ...
- C++ 八数码问题宽搜
C++ 八数码问题宽搜 题目描述 样例输入 (none) 样例输出 H--F--A AC代码 #include <iostream> #include <stdio.h> #i ...
- springBoot综合开发
作者:纯洁的微笑出处:www.ityouknow.com 版权所有,欢迎保留原文链接进行转载:) 上篇文章介绍了Spring boot初级教程:spring boot(一):入门篇,方便大家快速入门. ...
- PIVOT内置函数实现行转列
PIVOT用于将列值旋转为列名(即行转列),PIVOT的一般语法是:PIVOT(聚合函数(列) FOR 列 in (…) )AS P 完整语法: table_source PIVOT( 聚合函数(va ...
- Hack The Box Web Pentest 2019
[20 Points] Emdee five for life [by L4mpje] 问题描述: Can you encrypt fast enough? 初始页面,不管怎么样点击Submit都会显 ...
- RocketMQ中Producer消息的发送
上篇博客介绍过Producer的启动,这里涉及到相关内容就不再累赘了 [RocketMQ中Producer的启动源码分析] Producer发送消息,首先需要生成Message实例: public c ...