yolotv5和resnet152模型预测
我已经训练完成了yolov5检测和resnet152分类的模型,下面开始对一张图片进行检测分类。
首先用yolo算法对猫和狗进行检测,然后将检测到的目标进行裁剪,然后用resnet152对裁剪的图片进行分类。
首先我有以下这些训练好的模型
猫狗检测的,猫的分类,狗的分类
我的预测文件my_detect.py
import os
import sys
from pathlib import Path from tools_detect import draw_box_and_save_img, dataLoad, predict_classify, detect_img_2_classify_img, get_time_uuid FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative from models.common import DetectMultiBackend
from utils.general import (non_max_suppression)
from utils.plots import save_one_box import config as cfg conf_thres = cfg.conf_thres
iou_thres = cfg.iou_thres detect_size = cfg.detect_img_size
classify_size = cfg.classify_img_size def detect_img(img, device, detect_weights='', detect_class=[], save_dir=''):
# 选择计算设备
# device = select_device(device)
# 加载数据
imgsz = (detect_size, detect_size)
im0s, im = dataLoad(img, imgsz, device)
# print(im0)
# print(im)
# 加载模型
model = DetectMultiBackend(detect_weights, device=device)
stride, names, pt = model.stride, model.names, model.pt
# print((1, 3, *imgsz))
model.warmup(imgsz=(1, 3, *imgsz)) # warmup pred = model(im, augment=False, visualize=False)
# print(pred)
pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)
# print(pred)
im0 = im0s.copy()
# 画框,保存图片
# ret_bytes= None
ret_bytes = draw_box_and_save_img(pred, names, detect_class, save_dir, im0, im)
ret_li = list()
# print(pred)
im0_arc = int(im0.shape[0]) * int(im0.shape[1])
count = 1
for det in reversed(pred[0]):
# print(det)
# print(det)
# 目标太小跳过
xyxy_arc = (int(det[2]) - int(det[0])) * (int(det[3]) - int(det[1]))
# print(xyxy_arc)
if xyxy_arc / im0_arc < 0.01:
continue
# 裁剪图片
xyxy = det[:4]
im_crop = save_one_box(xyxy, im0, file=Path('im.jpg'), gain=1.1, pad=10, square=False, BGR=False, save=False)
# 将裁剪的图片转为分类的大小及tensor类型
im_crop = detect_img_2_classify_img(im_crop, classify_size, device) d = dict()
# print(det)
c = int(det[-1])
label = detect_class[c]
# 开始做具体分类
if label == detect_class[0]:
classify_predict = predict_classify(cfg.cat_weight, im_crop, device)
classify_label = cfg.cat_class[int(classify_predict)]
else:
classify_predict = predict_classify(cfg.dog_weight, im_crop, device)
classify_label = cfg.dog_class[int(classify_predict)]
# print(classify_label)
d['details'] = classify_label
conf = round(float(det[-2]), 2)
d['label'] = label+str(count)
d['conf'] = conf
ret_li.append(d)
count += 1 return ret_li, ret_bytes def start_predict(img, save_dir=''):
weights = cfg.detect_weight
detect_class = cfg.detect_class
device = cfg.device
ret_li, ret_bytes = detect_img(img, device, weights, detect_class, save_dir)
# print(ret_li)
return ret_li, ret_bytes if __name__ == '__main__':
name = get_time_uuid()
save_dir = f'./save/{name}.jpg'
# path = r'./test_img/hashiqi20230312_00010.jpg'
path = r'./test_img/hashiqi20230312_00116.jpg'
# path = r'./test_img/kejiquan20230312_00046.jpg'
f = open(path, 'rb')
img = f.read()
f.close()
# print(img)
# print(type(img))
img_ret_li, img_bytes = start_predict(img, save_dir=save_dir)
print(img_ret_li)
我的tools_detect.py文件
import datetime
import os
import random
import sys
import time
from pathlib import Path import torch
from PIL import Image
from torch import nn from utils.augmentations import letterbox FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative from utils.general import (cv2,
scale_boxes, xyxy2xywh)
from utils.plots import Annotator, colors
import numpy as np def bytes_to_ndarray(byte_img):
"""
图片二进制转numpy格式
"""
image = np.asarray(bytearray(byte_img), dtype="uint8")
image = cv2.imdecode(image, cv2.IMREAD_COLOR)
return image def ndarray_to_bytes(ndarray_img):
"""
图片numpy格式转二进制
"""
ret, buf = cv2.imencode(".jpg", ndarray_img)
img_bin = Image.fromarray(np.uint8(buf)).tobytes()
# print(type(img_bin))
return img_bin def get_time_uuid():
"""
:return: 20220525140635467912
:PS :并发较高时尾部随机数增加
"""
uid = str(datetime.datetime.fromtimestamp(time.time())).replace("-", "").replace(" ", "").replace(":","").replace(".", "") + str(random.randint(100, 999))
return uid def dataLoad(img, img_size, device, half=False):
image = bytes_to_ndarray(img)
# print(image.shape)
im = letterbox(image, img_size)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous im = torch.from_numpy(im).to(device)
im = im.half() if half else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim return image, im def draw_box_and_save_img(pred, names, class_names, save_dir, im0, im): save_path = save_dir
fontpath = "./simsun.ttc"
for i, det in enumerate(pred):
annotator = Annotator(im0, line_width=3, example=str(names), font=fontpath, pil=True)
if len(det):
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
count = 1
im0_arc = int(im0.shape[0]) * int(im0.shape[1])
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]
base_path = os.path.split(save_path)[0]
file_name = os.path.split(save_path)[1].split('.')[0]
txt_path = os.path.join(base_path, 'labels')
if not os.path.exists(txt_path):
os.mkdir(txt_path)
txt_path = os.path.join(txt_path, file_name)
for *xyxy, conf, cls in reversed(det):
# 目标太小跳过
xyxy_arc = (int(xyxy[2]) - int(xyxy[0])) * (int(xyxy[3]) - int(xyxy[1]))
# print(im0.shape, xyxy, xyxy_arc, im0_arc, xyxy_arc / im0_arc)
if xyxy_arc / im0_arc < 0.01:
continue
# print(im0.shape, xyxy)
c = int(cls) # integer class
label = f"{class_names[c]}{count} {round(float(conf), 2)}" # .encode('utf-8')
# print(xyxy)
annotator.box_label(xyxy, label, color=colors(c, True)) im0 = annotator.result()
count += 1
# print(im0) # print(type(im0))
# im0 为 numpy.ndarray类型 # Write to file
# print('+++++++++++')
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
# print(xywh)
line = (cls, *xywh) # label format
with open(f'{txt_path}.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
cv2.imwrite(save_path, im0) ret_bytes = ndarray_to_bytes(im0)
return ret_bytes def predict_classify(model_path, img, device):
# im = torch.nn.functional.interpolate(img, (160, 160), mode='bilinear', align_corners=True)
# print(device)
if torch.cuda.is_available():
model = torch.load(model_path)
else:
model = torch.load(model_path, map_location='cpu')
# print(help(model))
model.to(device)
model.eval()
predicts = model(img)
_, preds = torch.max(predicts, 1)
pred = torch.squeeze(preds)
# print(pred)
return pred def detect_img_2_classify_img(img, classify_size, device):
im_crop1 = img.copy()
im_crop1 = np.float32(im_crop1)
image = cv2.resize(im_crop1, (classify_size, classify_size))
image = image.transpose((2, 0, 1))
im = torch.from_numpy(image).unsqueeze(0)
im_crop = im.to(device)
return im_crop
我的config.py文件
import torch
import os base_path = r'.\weights' detect_weight = os.path.join(base_path, r'cat_dog_detect/best.pt')
detect_class = ['猫', '狗'] cat_weight = os.path.join(base_path, r'cat_predict/best.pt')
cat_class = ['东方短毛猫', '亚洲豹猫', '加菲猫', '安哥拉猫', '布偶猫', '德文卷毛猫', '折耳猫', '无毛猫', '暹罗猫', '森林猫', '橘猫', '奶牛猫', '狞猫', '狮子猫', '狸花猫', '玳瑁猫', '白猫', '蓝猫', '蓝白猫', '薮猫', '金渐层猫', '阿比西尼亚猫', '黑猫'] dog_weight = os.path.join(base_path, r'dog_predict/best.pt')
dog_class = ['中华田园犬', '博美犬', '吉娃娃', '哈士奇', '喜乐蒂', '巴哥犬', '德牧', '拉布拉多犬', '杜宾犬', '松狮犬', '柯基犬', '柴犬', '比格犬', '比熊', '法国斗牛犬', '秋田犬', '约克夏', '罗威纳犬', '腊肠犬', '萨摩耶', '西高地白梗犬', '贵宾犬', '边境牧羊犬', '金毛犬', '阿拉斯加犬', '雪纳瑞', '马尔济斯犬'] # device = 0
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
conf_thres = 0.5
iou_thres = 0.45 detect_img_size = 416
classify_img_size = 160
整体文件结构
其中models和utils文件夹都是yolov5源码的文件
运行my_detect.py的结果
yolotv5和resnet152模型预测的更多相关文章
- TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化
线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...
- 基于GPS数据建立隐式马尔可夫模型预测目的地
<Trip destination prediction based on multi-day GPS data>是一篇在2019年,由吉林交通大学团队发表在elsevier期刊上的一篇论 ...
- 修正剑桥模型预测-用python3.4
下面是预测结果: #!/usr/bin/env python # -*- coding:utf-8 -*- # __author__ = "blzhu" ""& ...
- 时间序列深度学习:seq2seq 模型预测太阳黑子
目录 时间序列深度学习:seq2seq 模型预测太阳黑子 学习路线 商业中的时间序列深度学习 商业中应用时间序列深度学习 深度学习时间序列预测:使用 keras 预测太阳黑子 递归神经网络 设置.预处 ...
- 时间序列深度学习:状态 LSTM 模型预测太阳黑子
目录 时间序列深度学习:状态 LSTM 模型预测太阳黑子 教程概览 商业应用 长短期记忆(LSTM)模型 太阳黑子数据集 构建 LSTM 模型预测太阳黑子 1 若干相关包 2 数据 3 探索性数据分析 ...
- 【R实践】时间序列分析之ARIMA模型预测___R篇
时间序列分析之ARIMA模型预测__R篇 之前一直用SAS做ARIMA模型预测,今天尝试用了一下R,发现灵活度更高,结果输出也更直观.现在记录一下如何用R分析ARIMA模型. 1. 处理数据 1.1. ...
- tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...
- NLP(十八)利用ALBERT提升模型预测速度的一次尝试
前沿 在文章NLP(十七)利用tensorflow-serving部署kashgari模型中,笔者介绍了如何利用tensorflow-serving部署来部署深度模型模型,在那篇文章中,笔者利用k ...
- [Tensorflow]模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
文章目录 [Tensorflow]模型持久化的原理,将CKPT转为pb文件,使用pb模型预测 一.模型持久化 1.持久化代码实现 convert_variables_to_constants固化模型结 ...
- 用R做时间序列分析之ARIMA模型预测
昨天刚刚把导入数据弄好,今天迫不及待试试怎么做预测,网上找的帖子跟着弄的. 第一步.对原始数据进行分析 一.ARIMA预测时间序列 指数平滑法对于预测来说是非常有帮助的,而且它对时间序列上面连续的值之 ...
随机推荐
- 咕咕list
做完以后会留在榜上一天,这样显得咕咕list长一些 CF666E Forensic Examination(done on 2023.2.6) dp选做
- TypeScript 学习笔记 — 类型兼容 (十)
目录 一.基本数据类型的兼容性 二.接口兼容性 三.函数的兼容性 四.类的兼容性 类的私有成员和受保护成员 五.泛型的兼容性 六.枚举的兼容性 标称类型简短介绍 TS 是结构类型系统(structur ...
- Java语言标识符的命名规范(超详细讲解)
前言 在上一篇文章中,壹哥带领大家开始编写了第一个 Java 案例,在我们的 cmd 命令窗口中输出了"Hello World"这句话.并且我还给大家留了一个小作业,你做出来了吗? ...
- Neo4j常用操作——Cypher查询语言
1. 删除数据库中以往的图,确保一个空白的环境进行操作: MATCH (n) DETACH DELETE n # 要想删除数据库的话直接删除文件即可 2. 创建一个人物节点: CREATE (n:Pe ...
- 「高频必考」Docker&K8S面试题和答案
先送福利:Go如何自动解压缩包?| 文末送书 Docker 如何在Docker容器内部访问主机上的服务? 可以通过设置主机网络模式,使用--net=host参数来访问主机上的服务.这样,容器和主机将共 ...
- Unity3D中的Attribute详解(五)
今天主要来讲一下Unity中带Menu的Attribute. 首先是AddComponentMenu.这是UnityEngine命名空间下的一个Attribute. 按照官方文档的说法,会在Compo ...
- 这可能是最全面的Spring面试题总结了
Spring是什么? Spring是一个轻量级的控制反转(IoC)和面向切面(AOP)的容器框架. Spring的优点 通过控制反转和依赖注入实现松耦合. 支持面向切面的编程,并且把应用业务逻辑和系统 ...
- 迁移学习《Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks》
论文信息 论文标题:Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Ne ...
- python入门教程之十三错误和异常
作为 Python 初学者,在刚学习 Python 编程时,经常会看到一些报错信息,在前面我们没有提及,这章节我们会专门介绍. Python 有两种错误很容易辨认:语法错误和异常. Python as ...
- [Linux/Java SE]查看JAR包内的类 | JAR 命令 | 反编译
1 查看JAR包内的类 另一个思路: 解压JAR包jar -xf <jarPath> 1-1 单JAR包 -t list table of contents for archive(列出存 ...