深度学习(PYTORCH)-3.sphereface-pytorch.lfw_eval.py详解
pytorch版本sphereface的原作者地址:https://github.com/clcarwin/sphereface_pytorch
由于接触深度学习不久,所以花了较长时间来阅读源码,以下对项目中的lfw_eval.py文件做了详细解释
(不知是版本问题还是作者code有误,原代码存在很多的bug,需要自行一一纠正,另:由于在windows下运行,故而去掉了gpu加速以及多线程)
#-*- coding:utf-8 -*-
from __future__ import print_function import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
torch.backends.cudnn.bencmark = True import os,sys,cv2,random,datetime
import argparse
import numpy as np
import zipfile from dataset import ImageDataset
from matlab_cp2tform import get_similarity_transform_for_cv2
import net_sphere
from matplotlib import pyplot as plt #图像对齐和裁剪
def alignment(src_img,src_pts):
#使用标准人脸坐标对图像进行仿射
ref_pts = [ [30.2946, 51.6963],[65.5318, 51.5014],
[48.0252, 71.7366],[33.5493, 92.3655],[62.7299, 92.2041] ]
crop_size = (96, 112)
src_pts = np.array(src_pts).reshape(5,2) s = np.array(src_pts).astype(np.float32)
r = np.array(ref_pts).astype(np.float32) tfm = get_similarity_transform_for_cv2(s, r)
face_img = cv2.warpAffine(src_img, tfm, crop_size)
return face_img #k-fold cross validation(k-折叠交叉验证)
#将n份数据分为n_folds份,以次将第i份作为测试集,其余部分作为训练集
def KFold(n=200, n_folds=10, shuffle=False):
folds = []
base = list(range(n))
for i in range(n_folds):
test = base[(i*n//n_folds):((i+1)*n//n_folds)]
train = list(set(base)-set(test))
folds.append([train,test])
return folds #求解当前阈值时的准确率
def eval_acc(threshold, diff):
y_true = []
y_predict = []
for d in diff:
same = 1 if float(d[2]) > threshold else 0
y_predict.append(same)
y_true.append(int(d[3]))
y_true = np.array(y_true)
y_predict = np.array(y_predict)
accuracy = 1.0*np.count_nonzero(y_true==y_predict)/len(y_true)
return accuracy #eval_acc和find_best_threshold共同工作,来求试图找到最佳阈值,
#
def find_best_threshold(thresholds, predicts):
#threshould 阈值
best_threshold = best_acc = 0
for threshold in thresholds:
accuracy = eval_acc(threshold, predicts)
if accuracy >= best_acc:
best_acc = accuracy
best_threshold = threshold
return best_threshold #命令行参数
parser = argparse.ArgumentParser(description='PyTorch sphereface lfw')
parser.add_argument('--net','-n', default='sphere20a', type=str)
parser.add_argument('--lfw', default='../DataSet/lfw.zip', type=str)
parser.add_argument('--model','-m', default='./sphere20a_20171020.pth', type=str)
args = parser.parse_args() predicts=[] #加载网络
net = getattr(net_sphere,args.net)()
#加载模型
net.load_state_dict(torch.load(args.model))
#
net.eval()
#
net.feature = True #加载图片数据
zfile = zipfile.ZipFile(args.lfw) #加载landmark,每张照片包括五个特征点,共五组坐标
landmark = {}
with open('data/lfw_landmark.txt') as f:
landmark_lines = f.readlines()
#对每一行进行处理
for line in landmark_lines:
l = line.replace('\n','').split('\t')
#将每一组数据转化为字典形式
landmark[l[0]] = [int(k) for k in l[1:]] #加载pairs
with open('data/pairs.txt') as f:
pairs_lines = f.readlines()[1:] #range表示测试的图片对数
for i in range(600):
print(str(i)+" start")
p = pairs_lines[i].replace('\n','').split('\t')
# pairs.txt一共有6000行,存在两种形式,
# 分别表示进行对比的两张照片,形式1是同一个人,形式2是不同人:
# name 数字1 数字2
# name 数字1 name数字2
if 3==len(p):
sameflag = 1
#形式例如:Woody_Allen/Woody_Allen_0002.jpg
name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1]))
name2 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[2]))
if 4==len(p):
sameflag = 0
name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1]))
name2 = p[2]+'/'+p[2]+'_'+'{:04}.jpg'.format(int(p[3])) #分别加载两张照片,并对其进行图像对齐
org_img1=cv2.imdecode(np.frombuffer(zfile.read("lfw/lfw/"+name1),np.uint8),1)
org_img2=cv2.imdecode(np.frombuffer(zfile.read("lfw/lfw/"+name2),np.uint8),1)
img1 = alignment(org_img1,landmark[name1])
img2 = alignment(org_img2,landmark[name2])
#1.对输出图像使用cv2进行展示
# cv2.imshow("org_img1", org_img1)
# cv2.imshow("org_img2", org_img2)
# cv2.imshow("img1",img1)
# cv2.imshow("img2", img2)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
#2.对输出图像使用matplotlib进行展示
fig_new=plt.figure()
img_list=[[org_img1,221],[org_img2,222],[img1,223],[img2,224]]
for p,q in img_list:
ax=fig_new.add_subplot(q)
p = p[:, :, (2, 1, 0)]
ax.imshow(p)
plt.show() #cv.flip图像翻转,第二个参数:1:水平翻转,0:垂直翻转,-1:水平垂直翻转
imglist = [img1,cv2.flip(img1,1),img2,cv2.flip(img2,1)]
#分别对图片进行
for m in range(len(imglist)):
imglist[m] = imglist[m].transpose(2, 0, 1).reshape((1,3,112,96))
imglist[m] = (imglist[m]-127.5)/128.0 # p.vstack: 垂直(按照行顺序)的把数组给堆叠起来
#******举例******
# import numpy as np
# a = [1, 2, 3]
# b = [4, 5, 6]
# print(np.vstack((a, b)))
#
# 输出:
# [[1 2 3]
# [4 5 6]]
img = np.vstack(imglist)
#将numpy形式转化为variable形式
img = Variable(torch.from_numpy(img).float(),volatile=True)
output = net(img)
#得到计算结果,f1和f2均为512维向量形式
f = output.data
f1,f2 = f[0],f[2]
#计算二者的余弦相似度,后面加上常量是为了防止分母为0
#关于余弦相似度请自行百度或google
#这里给出一个简单说明的链接:http://blog.csdn.net/huangfei711/article/details/78469614
#a*b/|a||b|
cosdistance = f1.dot(f2)/(f1.norm()*f2.norm()+1e-5)
predicts.append('{}\t{}\t{}\t{}\n'.format(name1,name2,cosdistance,sameflag))
print(str(i) + " end") #准确率
accuracy = []
#(最佳)阈值
thd = []
#k-fold cross validation(k-折叠交叉验证)
#folds的形式为[[train,test],[train,test].....]
folds = KFold(n=600, n_folds=10, shuffle=False)
#取数组为-1到1,步长为0.005
thresholds = np.arange(-1.0, 1.0, 0.005)
# 此处为原作者code,疑似有误,已做修改
# predicts = np.array(map(lambda line:frd.append(line.strip('\n').split()), predicts))
predicts = np.array([k.strip('\n').split() for k in predicts])
for idx, (train, test) in enumerate(folds):
# predicts[train/test]形式为:
# [['Doris_Roberts/Doris_Roberts_0001.jpg'
# 'Doris_Roberts/Doris_Roberts_0003.jpg' '0.6532696413605743' '1'],.....]
#寻找最佳阈值
best_thresh = find_best_threshold(thresholds, predicts[train])
#通过上面的得到的最佳阈值来对test数据集进行测试得到准确率
accuracy.append(eval_acc(best_thresh, predicts[test]))
#thd阈值
thd.append(best_thresh)
#np.mean:计算均值,np.std:计算标准差
#输出结果分别为:准确率均值,准确率标准差,阈值均值
print('LFWACC={:.4f} std={:.4f} thd={:.4f}'.format(np.mean(accuracy), np.std(accuracy), np.mean(thd)))
#例如结果为 LFWACC=0.9800 std=0.0600 thd=0.3490
#则说明准确率为98%,准确率标准差为0.06,阈值的均值为0.3490
#因此我们可以认为余弦相似度大于0.3490的两张图片里是同一个人
深度学习(PYTORCH)-3.sphereface-pytorch.lfw_eval.py详解的更多相关文章
- 基于OpenCL的深度学习工具:AMD MLP及其使用详解
基于OpenCL的深度学习工具:AMD MLP及其使用详解 http://www.csdn.net/article/2015-08-05/2825390 发表于2015-08-05 16:33| 59 ...
- 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 深度学习笔记——PCA原理与数学推倒详解
PCA目的:这里举个例子,如果假设我有m个点,{x(1),...,x(m)},那么我要将它们存在我的内存中,或者要对着m个点进行一次机器学习,但是这m个点的维度太大了,如果要进行机器学习的话参数太多, ...
- 深度学习识别CIFAR10:pytorch训练LeNet、AlexNet、VGG19实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com VGGNet在2014年ImageNet图像分类任务竞赛中有出色的表现.网络结构如下图所示: 同样的, ...
- 深度学习识别CIFAR10:pytorch训练LeNet、AlexNet、VGG19实现及比较(二)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com AlexNet在2012年ImageNet图像分类任务竞赛中获得冠军.网络结构如下图所示: 对CIFA ...
- 深度学习框架Keras与Pytorch对比
对于许多科学家.工程师和开发人员来说,TensorFlow是他们的第一个深度学习框架.TensorFlow 1.0于2017年2月发布,可以说,它对用户不太友好. 在过去的几年里,两个主要的深度学习库 ...
- 深度学习调用TensorFlow、PyTorch等框架
深度学习调用TensorFlow.PyTorch等框架 一.开发目标目标 提供统一接口的库,它可以从C++和Python中的多个框架中运行深度学习模型.欧米诺使研究人员能够在自己选择的框架内轻松建立模 ...
- Linux防火墙iptables学习笔记(三)iptables命令详解和举例[转载]
Linux防火墙iptables学习笔记(三)iptables命令详解和举例 2008-10-16 23:45:46 转载 网上看到这个配置讲解得还比较易懂,就转过来了,大家一起看下,希望对您工作能 ...
- ASP.NET MVC 5 学习教程:Details 和 Delete 方法详解
原文 ASP.NET MVC 5 学习教程:Details 和 Delete 方法详解 在教程的这一部分,我们将研究一下自动生成的 Details 和Delete 方法. Details 方法 打开M ...
随机推荐
- 自动化定位——通过XPath定位元素
XPath是一种XML文档中定位元素的语言.该定位方式也是比较常用的定位方式 1通过属性定位元素 find_element_by_xpath("//标签名[@属性='属性值']") ...
- Flask Vue.js全栈开发
Flask Vue.js全栈开发的 最新完整代码 及使用方式 本系列的最新代码及使用方式将持续更新到: http://www.madmalls.com/blog/post/latest-code/ 1 ...
- python基本概念
python环境以及python的搭建的基本知识 python解释器 python语言的本质 通过解释器将脚本翻译成机器能识别的二进制码,交予机器执行 pycharm ide:集成开发环境 集成编译器 ...
- springboot 默认异常处理
SpringBoot默认有自定义异常处理的体系,在做SpringBoot项目的时候,如果是抛出了运行时异常,springBoot并会对异常进行处理,返回如下异常信息: { "timestam ...
- 【分享】用Canvas实现画板功能
前言 PC端测试:QQ浏览器全屏绘画完成.缩小时内容会被清空,切换背景颜色内容会被重置,其他暂无发现: 手机端测试:微信内置浏览器不通过:Safari 浏览器使用画笔时没固定页面会有抖动效果,使用橡皮 ...
- thymeleaf下拉框从后台动态获取集合数据并回显选中
今天遇到从后台集合中取出对象在前台页面下拉列表展示: <select name="signature" lay-search="" class=" ...
- nginx配置文件详解----第一篇【访问与错误日志】
error_log错误日志 access_log访问日志 log_format指令 语法: log_format name string …;默认值: log_format combined “ ...
- require 4种引入方式的区别
以下四种引入方式的区别: 自己创建的包里面封装了一些方法,只是把aa文件夹放在了node_modules文件夹里,所以在引用时,不需要写上相对路径,也不能在网上下载 这是网上别人封装好了的包,下载好了 ...
- typescript初入门
1.通过npm安装 typescript 进入终端窗口安装typescript: npm install -g typescript 查看typescript版本号: tsc -v 2.编译代码:t ...
- pip更新
python -m ** install -U **