欢迎关注磐创博客资源汇总站:

http://docs.panchuang.net/

欢迎关注PyTorch官方中文教程站:

http://pytorch.panchuang.net/

计算机视觉–图像和视频数据分析是深度学习目前最火的应用领域之一。因此,在学习深度学习的同时尝试运用某些计算机视觉技术做些有趣的事情会很有意思,也会让你发现些令人吃惊的事实。长话短说,我的搭档(Maximiliane Uhlich)和我决定将深度学习应用于浪漫情侣的形象分类上,因为Maximiliane是一位关系研究员和情感治疗师。具体来说,我们想知道我们是否可以准确地判断图像或视频中描绘的情侣是否对他们的关系感到满意? 事实证明,我们可以!我们的最终模型(我们称之为DeepConnection)分类准确率接近97%,能够准确地区分幸福与不幸福的情侣。大家可以在我们的论文预览链接[^1]里阅读完整介绍,上图是我们为这个任务设计的框架草图。

在数据集收集方面,我们使用这个Python脚本[^2]进行网页数据抽取(webscraping)来获取幸福和不幸福的情侣数据。最后,我们整理出了大约包含1000张图像的训练集。这并不是特别多,所以我们使用数据增强与迁移学习来增强我们模型在数据集上的表现。数据增强–图像方向的微小变化,色调和色彩强度以及许多其他因素都会增强模型的泛化能力,从而避免学习一些不相关信息。 例如,如果数据中幸福夫妻的图像平均比不幸福夫妻的图像更亮,我们并不希望我们的模型映射这种关联。我们使用了强大的ImgAug库[^3]进行了相当多策略的数据扩充,以确保我们模型的鲁棒性。基本上对于每个批次的每个图像,我们至都至少应用多种数据增强技术。下图是一张图片应用了48种数据增强策略的示例。

我们决定使用ResNet模型作为DeepConnection的基础网络,在大型数据集ImageNet上预先训练。通过预训练,模型已经具有了一定的识别能力。我们所有的模型都借用PyTorch实现,我们使用Google Colab上的免费GPU资源进行训练和测试。这个基础模型本身已经具备了良好的分类能力,但我们决定更进一步,用空间金字塔池化层(SPP)[^4] 替换ResNet-34基础模型的最后一个自适应池模块。这里,处理后的图像数据被分成不同数量的正方形,并且仅传递最大值以进行进一步分析(最大池化)。这使得模型可以专注于重要的特征,使其对不同大小的图像具有鲁棒性,并且不受图像扰动的影响。之后,我们放置了一个均值变换(PMT)层[^5],用数学函数转换数据以引入非线性,使得DeepConnection可以从数据中捕获更复杂的关系。这两个模块均提高了我们的分类准确度,我们在单独的验证集上得到了大约97%准确率。SPP / PMT和后续分类层的代码如下所示:

  1. class SPP(nn.Module):
  2. def __init__(self):
  3. super(SPP, self).__init__()
  4. ## features incoming from ResNet-34 (after SPP/PMT)
  5. self.lin1 = nn.Linear(2*43520, 100)
  6. self.relu = nn.ReLU()
  7. self.bn1 = nn.BatchNorm1d(100)
  8. self.dp1 = nn.Dropout(0.5)
  9. self.lin2 = nn.Linear(100, 2)
  10. def forward(self, x):
  11. # SPP
  12. x = spatial_pyramid_pool(x, x.shape[0], [x.shape[2], x.shape[3]], [8, 4, 2, 1])
  13. # PMT
  14. x_1 = torch.sign(x)*torch.log(1 abs(x))
  15. x_2 = torch.sign(x)*(torch.log(1 abs(x)))**2
  16. x = torch.cat((x_1, x_2), dim = 1)
  17. # fully connected classification part
  18. x = self.lin1(x)
  19. x = self.bn1(self.relu(x))
  20. #1
  21. x1 = self.lin2(self.dp1(x))
  22. #2
  23. x2 = self.lin2(self.dp1(x))
  24. #3
  25. x3 = self.lin2(self.dp1(x))
  26. #4
  27. x4 = self.lin2(self.dp1(x))
  28. #5
  29. x5 = self.lin2(self.dp1(x))
  30. #6
  31. x6 = self.lin2(self.dp1(x))
  32. #7
  33. x7 = self.lin2(self.dp1(x))
  34. #8
  35. x8 = self.lin2(self.dp1(x))
  36. x = torch.mean(torch.stack([x1, x2, x3, x4, x5, x6, x7, x8]), dim = 0)
  37. return x

仔细观察代码可以看出,最终分类层上有八个变种。看似浪费了算力实际上恰恰相反。这个概念是最近提出的,叫做multi-sample dropout(多样本随机丢弃),它在训练期间显着加速了收敛[^6]。它基本上是防止模型学习虚假关系(过度拟合)和试图不丢弃丢失掩码中的信息之间的折衷。

我们在项目中对这个方法进行了其他一些调整优化,具体参看我们在GitHub放出的项目代码[7]以获取更多信息。简单地提一下:我们使用混合精度(使用Apex库[8]实现)训练模型,以大大降低内存使用率,使用早停(earlystopping)来防止过度拟合,并根据余弦函数进行学习率退火。

在达到令人满意的分类准确度(具有相应高的召回率和精确度)后,我们想知道我们是否可以从DeepConnection执行的分类中学到一些东西。因此,我们尝试模型解释性探索并使用梯度加权类激活映射技术(Grad-CAM)进行分析[^9]。基本地,Grad-CAM获取最终卷积层的输入梯度以确定显著区域,其可以被视为原始图像之上的上采样热图。具体实现与可视化结果如下:

  1. ## from https://github.com/eclique/pytorch-gradcam/blob/master/gradcam.ipynb
  2. def GradCAM(img, c, features_fn, classifier_fn):
  3. feats = modulelist_conv(img.cuda().half())
  4. feats = feats.cuda()
  5. _, N, H, W = feats.size()
  6. out = modulelist_fc(feats)
  7. c_score = out[0, c]
  8. grads = torch.autograd.grad(c_score, feats)
  9. w = grads[0][0].mean(-1).mean(-1)
  10. sal = torch.matmul(w, feats.view(N, H*W))
  11. sal = sal.view(H, W).cpu().detach().numpy()
  12. sal = np.maximum(sal, 0)
  13. return sal

我们在论文中对此进行了进一步讨论,并将其嵌入到了现有的心理学研究中,但DeepConnection似乎主要关注面部区域。从研究的角度来看,这很有意义,因为面部表情会传达沟通和情感。除了Grad-CAM获得的视觉感知之外,我们还想看看我们是否可以通过模型解释得出实际特征。为此,我们创建了激活状态图,以显示最终分类层的哪些神经元被哪些给定图像区域激活。

与其他模型相比,DeepConnection还学习到了代表不幸福的特征,并不仅仅将缺乏代表幸福的特征的分类为不幸福。但是,我们需要进一步的研究才能将这些特征实际映射到人类行为可解释性方面。我们还尝试过在未知的情侣视频帧上使用DeepConnection,效果非常好。

总体而言,该模型的稳健性是其强大优势之一。准确的分类同样适用于同性恋伴侣不同肤色人种除情侣外包含其他人的视频帧中不能完整显示情侣人脸的视频帧中等等。对于图像中存在其他人的情况,DeepConnection甚至可以识别其他人是否感到满意,但仍然将其预测集中在这对情侣身上。

除了进一步的模型解释之外,下一步的工作将是使用更大的训练数据集,从而训练更复杂的模型。使用DeepConnection作为情侣治疗师的助手将会很有意思,可以在会话期间或之后对情侣的当前关系状态进行实时反馈。此外,我建议您与女票/男票一起输入你们的合照,看看DeepConnection对你们的关系有何看法!希望这会是一个好的开始!

1: https://psyarxiv.com/df25j/

2: https://github.com/Bribak/DeepConnection

3: https://github.com/aleju/imgaug

4: https://arxiv.org/abs/1406.4729

5: https://www.sciencedirect.com/science/article/pii/S0031320318304503

6: https://arxiv.org/abs/1905.09788

7: https://github.com/Bribak/DeepConnection

8: https://github.com/NVIDIA/apex

9: https://arxiv.org/abs/1610.02391

使用PyTorch进行情侣幸福度测试指南的更多相关文章

  1. 腾讯GT的流畅度测试方案研究

    GT源码:https://github.com/TencentOpen/GT 一.流畅度模块的代码结构 流畅度插件总共就几个类,其实处理方式也比较简单粗暴,就是通过Choreographer输出的lo ...

  2. Android流畅度测试

    Android流畅度测试 测试方法一:系统自带-开发者模式 测试方法二:FPS Meter测试安卓帧数 H5页面加载速度:window.performance.timing 测试方法一:系统自带-开发 ...

  3. 《大话移动APP测试:Android与iOS应用测试指南》

    <大话移动app测试:android与ios应用测试指南> 基本信息 作者: 陈晔 出版社:清华大学出版社 ISBN:9787302368793 上架时间:2014-7-7 出版日期:20 ...

  4. 推荐——Monkey《大话 app 测试——Android、iOS 应用测试指南》

    <大话移动——Android与iOS应用测试指南> 京东可以预购啦!http://item.jd.com/11495028.html 当当网:http://product.dangdang ...

  5. app流畅度测试--使用手机自带功能

    1.进入开发者选项,在“监控”选项卡找到“GPU呈现模式分析”的选项 2.开启后,即可以条形图和线形图的方式显示系统的界面相应速度 3.那么要如何根据曲线判断系统是否流畅呢?实际上这个曲线表达的是GP ...

  6. OWASP固件安全性测试指南

    OWASP固件安全性测试指南 固件安全评估,英文名称 firmware security testing methodology 简称 FSTM.该指导方法主要是为了安全研究人员.软件开发人员.顾问. ...

  7. 基于XGBoost模型的幸福度预测——阿里天池学习赛

    加载数据 加载的是完整版的数据 happiness_train_complete.csv . import numpy as np import pandas as pd import matplot ...

  8. Web安全测试指南--认证

    认证: 5.1.1.敏感数据传输: 编号 Web_Authen_01_01 用例名称 敏感数据传输保密性测试 用例描述 测试敏感数据是否通过加密通道进行传输以防止信息泄漏. 严重级别 高 前置条件 1 ...

  9. pytorch: 准备、训练和测试自己的图片数据

    大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试.如果我们是自己的图片数据,又该怎么做呢? 一.我的数据 我在学习的时候,使用的是fashion-mnist.这个 ...

随机推荐

  1. Leetcode 20题 有效的括号(Valid Parentheses) Java语言求解

    题目描述: 给定一个只包括 '(',')','{','}','[',']' 的字符串,判断字符串是否有效. 有效字符串需满足: 左括号必须用相同类型的右括号闭合. 左括号必须以正确的顺序闭合. 注意空 ...

  2. 5——PHP逻辑运算符&&唯一的三元运算符

    */ * Copyright (c) 2016,烟台大学计算机与控制工程学院 * All rights reserved. * 文件名:text.cpp * 作者:常轩 * 微信公众号:Worldhe ...

  3. 达拉草201771010105《面向对象程序设计(java)》第十六周学习总结

    达拉草201771010105<面向对象程序设计(java)>第十六周学习总结 第一部分:理论知识 1.程序与进程的概念: (1)程序是一段静态的代码,它是应用程序执行的蓝 本. (2)进 ...

  4. Mariadb 修改root密码及跳过授权方式启动数据库

    默认情况下,yum方式新安装的 mariadb 的密码为空,在shell终端直接输入 mysql 就能登陆数据库. 如果是刚安装第一次使用,请使用 mysql_secure_installation ...

  5. 目标用户偏好指数Target Group Index分析

    目标用户偏好指数Target Group Index分析 TGI指数,全称Target Group Index,可以反映目标群体在特定研究范围内强势或者弱势. TGI指数计算公式 = 目标群体中具有某 ...

  6. 前端、HTML+CSS+JS编写规范(终极版)

    HTMLCSS文档规范 HTML和CSS文档必须采用UTF-8编码格式: HTML文档必须使用HTML5的标准文档格式: HTMLCSS编写规范 HTML和CSS的标签.属性.类名.ID都必须使用小写 ...

  7. 前端每日实战:125# 视频演示如何用纯 CSS 创作一个失落的人独自行走的动画

    效果预览 按下右侧的"点击预览"按钮可以在当前页面预览,点击链接可以全屏预览. https://codepen.io/comehope/pen/MqpOdR/ 可交互视频 此视频是 ...

  8. CSS3:TEXT-SHADOW|BOX-SHADOW(炫彩字体)

    2016年2月26日个人博客文章--迁移到segmentfault (1)text-shadow(文本阴影) 在介绍css3:text-shadow文本阴影之前,我们先来看看用它都能实现什么效果: 没 ...

  9. 前端javascript知识(二)

    documen.write和 innerHTML的区别 document.write只能重绘整个页面 innerHTML可以重绘页面的一部分 浏览器检测通过什么? (1) navigator.user ...

  10. 学习使用Guava Cache

    官方文档:https://github.com/google/guava/wiki/CachesExplained 目录 一.guava cache介绍 二.快速入门 2.1.引入依赖 2.2.第一个 ...