主题列表:juejin, github, smartblue, cyanosis, channing-cyan, fancy, hydrogen, condensed-night-purple, greenwillow, v-green, vue-pro, healer-readable

贡献主题:https://github.com/xitu/juejin-markdown-themes

theme: smartblue

highlight:

在上一篇文章中已经讲解了Siamese Net的原理,和这种网络架构的关键——损失函数contrastive loss。现在我们来用pytorch来做一个简单的案例。经过这个案例,我个人的收获有到了以下的几点:

  • Siamese Net适合小数据集;
  • 目前Siamese Net用在分类任务(如果有朋友知道如何用在分割或者其他任务可以私信我,WX:cyx645016617)
  • Siamese Net的可解释性较好。

1 准备数据

  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torch.utils.data import Dataset,DataLoader
  8. from sklearn.model_selection import train_test_split
  9. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  1. data_train = pd.read_csv('../input/fashion-mnist_train.csv')
  2. data_train.head()

这个数据文件是csv格式,第一列是类别,之后的784列其实好似28x28的像素值。

划分训练集和验证集,然后把数据转换成28x28的图片

  1. X_full = data_train.iloc[:,1:]
  2. y_full = data_train.iloc[:,:1]
  3. x_train, x_test, y_train, y_test = train_test_split(X_full, y_full, test_size = 0.05)
  4. x_train = x_train.values.reshape(-1, 28, 28, 1).astype('float32') / 255.
  5. x_test = x_test.values.reshape(-1, 28, 28, 1).astype('float32') / 255.
  6. y_train.label.unique()
  7. >>> array([8, 9, 7, 6, 4, 2, 3, 1, 5, 0])

可以看到这个Fashion MNIST数据集中也是跟MNIST类似,划分了10个不同的类别。

  • 0 T-shirt/top
  • 1 Trouser
  • 2 Pullover
  • 3 Dress
  • 4 Coat
  • 5 Sandal
  • 6 Shirt
  • 7 Sneaker
  • 8 Bag
  • 9 Ankle boot
  1. np.bincount(y_train.label.values),np.bincount(y_test.label.values)
  2. >>> (array([4230, 4195, 4135, 4218, 4174, 4172, 4193, 4250, 4238, 4195]),
  3. array([1770, 1805, 1865, 1782, 1826, 1828, 1807, 1750, 1762, 1805]))

可以看到,每个类别的数据还是非常均衡的。

2 构建Dataset和可视化

  1. class mydataset(Dataset):
  2. def __init__(self,x_data,y_data):
  3. self.x_data = x_data
  4. self.y_data = y_data.label.values
  5. def __len__(self):
  6. return len(self.x_data)
  7. def __getitem__(self,idx):
  8. img1 = self.x_data[idx]
  9. y1 = self.y_data[idx]
  10. if np.random.rand() < 0.5:
  11. idx2 = np.random.choice(np.arange(len(self.y_data))[self.y_data==y1],1)
  12. else:
  13. idx2 = np.random.choice(np.arange(len(self.y_data))[self.y_data!=y1],1)
  14. img2 = self.x_data[idx2[0]]
  15. y2 = self.y_data[idx2[0]]
  16. label = 0 if y1==y2 else 1
  17. return img1,img2,label

关于torch.utils.data.Dataset的构建结构,我就不再赘述了,在之前的《小白学PyTorch》系列中已经讲解的很清楚啦。上面的逻辑就是,给定一个idx,然后我们先判断,这个数据是找两个同类别的图片还是两个不同类别的图片。50%的概率选择两个同类别的图片,然后最后输出的时候,输出这两个图片,然后再输出一个label,这个label为0的时候表示两个图片的类别是相同的,1表示两个图片的类别是不同的。这样就可以进行模型训练和损失函数的计算了。

  1. train_dataset = mydataset(x_train,y_train)
  2. train_dataloader = DataLoader(dataset = train_dataset,batch_size=8)
  3. val_dataset = mydataset(x_test,y_test)
  4. val_dataloader = DataLoader(dataset = val_dataset,batch_size=8)
  1. for idx,(img1,img2,target) in enumerate(train_dataloader):
  2. fig, axs = plt.subplots(2, img1.shape[0], figsize = (12, 6))
  3. for idx,(ax1,ax2) in enumerate(axs.T):
  4. ax1.imshow(img1[idx,:,:,0].numpy(),cmap='gray')
  5. ax1.set_title('image A')
  6. ax2.imshow(img2[idx,:,:,0].numpy(),cmap='gray')
  7. ax2.set_title('{}'.format('same' if target[idx]==0 else 'different'))
  8. break

这一段的代码就是对一个batch的数据进行一个可视化:

到目前位置应该没有什么问题把,有问题可以联系我讨论交流,WX:cyx645016617.我个人认为从交流中可以快速解决问题和进步。

3 构建模型

  1. class siamese(nn.Module):
  2. def __init__(self,z_dimensions=2):
  3. super(siamese,self).__init__()
  4. self.feature_net = nn.Sequential(
  5. nn.Conv2d(1,4,kernel_size=3,padding=1,stride=1),
  6. nn.ReLU(inplace=True),
  7. nn.BatchNorm2d(4),
  8. nn.Conv2d(4,4,kernel_size=3,padding=1,stride=1),
  9. nn.ReLU(inplace=True),
  10. nn.BatchNorm2d(4),
  11. nn.MaxPool2d(2),
  12. nn.Conv2d(4,8,kernel_size=3,padding=1,stride=1),
  13. nn.ReLU(inplace=True),
  14. nn.BatchNorm2d(8),
  15. nn.Conv2d(8,8,kernel_size=3,padding=1,stride=1),
  16. nn.ReLU(inplace=True),
  17. nn.BatchNorm2d(8),
  18. nn.MaxPool2d(2),
  19. nn.Conv2d(8,1,kernel_size=3,padding=1,stride=1),
  20. nn.ReLU(inplace=True)
  21. )
  22. self.linear = nn.Linear(49,z_dimensions)
  23. def forward(self,x):
  24. x = self.feature_net(x)
  25. x = x.view(x.shape[0],-1)
  26. x = self.linear(x)
  27. return x

一个非常简单的卷积网络,输出的向量的维度就是z-dimensions的大小。

  1. def contrastive_loss(pred1,pred2,target):
  2. MARGIN = 2
  3. euclidean_dis = F.pairwise_distance(pred1,pred2)
  4. target = target.view(-1)
  5. loss = (1-target)*torch.pow(euclidean_dis,2) + target * torch.pow(torch.clamp(MARGIN-euclidean_dis,min=0),2)
  6. return loss

然后构建了一个contrastive loss的损失函数计算。

4 训练

  1. model = siamese(z_dimensions=8).to(device)
  2. # model.load_state_dict(torch.load('../working/saimese.pth'))
  3. optimizor = torch.optim.Adam(model.parameters(),lr=0.001)
  1. for e in range(10):
  2. history = []
  3. for idx,(img1,img2,target) in enumerate(train_dataloader):
  4. img1 = img1.to(device)
  5. img2 = img2.to(device)
  6. target = target.to(device)
  7. pred1 = model(img1)
  8. pred2 = model(img2)
  9. loss = contrastive_loss(pred1,pred2,target)
  10. optimizor.zero_grad()
  11. loss.backward()
  12. optimizor.step()
  13. loss = loss.detach().cpu().numpy()
  14. history.append(loss)
  15. train_loss = np.mean(history)
  16. history = []
  17. with torch.no_grad():
  18. for idx,(img1,img2,target) in enumerate(val_dataloader):
  19. img1 = img1.to(device)
  20. img2 = img2.to(device)
  21. target = target.to(device)
  22. pred1 = model(img1)
  23. pred2 = model(img2)
  24. loss = contrastive_loss(pred1,pred2,target)
  25. loss = loss.detach().cpu().numpy()
  26. history.append(loss)
  27. val_loss = np.mean(history)
  28. print(f'train_loss:{train_loss},val_loss:{val_loss}')

这里为了加快训练,我把batch-size增加到了128个,其他的并没有改变:

这是运行的10个epoch的结果,不要忘记把模型保存一下:

  1. torch.save(model.state_dict(),'saimese.pth')

差不多是这个样子,然后看一看验证集的可视化效果,这里使用的是t-sne高位特征可视化的方法,其内核是PCA降维:

  1. from sklearn import manifold
  2. '''X是特征,不包含target;X_tsne是已经降维之后的特征'''
  3. tsne = manifold.TSNE(n_components=2, init='pca', random_state=501)
  4. X_tsne = tsne.fit_transform(X)
  5. print("Org data dimension is {}. \
  6. Embedded data dimension is {}".format(X.shape[-1], X_tsne.shape[-1]))
  7. x_min, x_max = X_tsne.min(0), X_tsne.max(0)
  8. X_norm = (X_tsne - x_min) / (x_max - x_min) # 归一化
  9. plt.figure(figsize=(8, 8))
  10. for i in range(10):
  11. plt.scatter(X_norm[y==i][:,0],X_norm[y==i][:,1],alpha=0.3,label=f'{i}')
  12. plt.legend()

输入图像为:

可以看得出来,不同类别之间划分的是比较好的,可以看到不同类别之间的距离还是比较大的,比较明显的,甚至可以放下公众号的名字。这里使用的隐变量是8。

这里有一个问题,我内心已有答案不知大家的想法如何,假如我把z潜变量的维度直接改成2,这样就不需要使用tsne和pca的方法来降低维度就可以直接可视化,但是这样的话可视化的效果并不比从8降维到2来可视化的效果好,这是为什么呢?

提示:一方面在于维度过小导致信息的缺失,但是这个解释站不住脚,因为PCA其实等价于一个退化的线形层,所以PCA同样会造成这种缺失;我认为关键应该是损失函数中的欧式距离的计算,如果维度高,那么欧式距离就会偏大,这样需要相应的调整MARGIN的数值。

孪生网络入门(下) Siamese Net分类服装MNIST数据集(pytorch)的更多相关文章

  1. 孪生网络入门(上) Siamese Net及其损失函数

    最近在多个关键词(小数据集,无监督半监督,图像分割,SOTA模型)的范畴内,都看到了这样的一个概念,孪生网络,所以今天有空大概翻看了一下相关的经典论文和博文,之后做了一个简单的案例来强化理解.如果需要 ...

  2. 机器学习-MNIST数据集使用二分类

    一.二分类训练MNIST数据集练习 %matplotlib inlineimport matplotlibimport numpy as npimport matplotlib.pyplot as p ...

  3. Pytorch 入门之Siamese网络

    首次体验Pytorch,本文参考于:github and PyTorch 中文网人脸相似度对比 本文主要熟悉Pytorch大致流程,修改了读取数据部分.没有采用原作者的ImageFolder方法:   ...

  4. 孪生网络(Siamese Network)在句子语义相似度计算中的应用

    1,概述 在NLP中孪生网络基本是用来计算句子间的语义相似度的.其结构如下 在计算句子语义相似度的时候,都是以句子对的形式输入到网络中,孪生网络就是定义两个网络结构分别来表征句子对中的句子,然后通过曼 ...

  5. Linux网络栈下两层实现

    http://www.cnblogs.com/zmkeil/archive/2013/04/18/3029339.html 1.1简介 VLAN是网络栈的一个附加功能,且位于下两层.首先来学习Linu ...

  6. 源码分析——迁移学习Inception V3网络重训练实现图片分类

    1. 前言 近些年来,随着以卷积神经网络(CNN)为代表的深度学习在图像识别领域的突破,越来越多的图像识别算法不断涌现.在去年,我们初步成功尝试了图像识别在测试领域的应用:将网站样式错乱问题.无线领域 ...

  7. EcShop调用显示指定分类下的子分类方法

    ECSHOP首页默认的只有全部分类,还有循环大类以及下面小类的代码,貌似没有可以调用显示指定大类下的子分类代码.于是就有这个文章的产生了,下面由夏日博客来总结下网站建设过程中ECSHOP此类问题的网络 ...

  8. Pytorch入门下 —— 其他

    本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...

  9. 主机WIFI网络环境下,Linux虚拟机网络设置

    在主机使用WIFI网络环境下,怎么样进行虚拟机静态ip设置和连接互联网呢,原理什么太麻烦,另类的网络共享而已: 1.其实简单将网络连接模式设置成NAT模式即可. 2.虚拟网络编辑器依旧是桥接模式,选择 ...

随机推荐

  1. 面试题:对NotNull字段插入Null值 有啥现象?

    Hi,大家好!我是白日梦. 今天我要跟你分享的话题是:"对NotNull字段插入Null值有啥现象?" 一. 推荐阅读 首发地址:https://mp.weixin.qq.com/ ...

  2. Spider_实践_beautifulsoup静态网页爬取所有网页链接

    # 获取百度网站首页上的所有a标签里的 href属性值: # import requests # from bs4 import BeautifulSoup # # html = requests.g ...

  3. mdp文件-Chapter2-NVT.mdp

    这是mdp文件系列的第二篇,介绍nvt平衡中要使用的mdp文件. 先上代码,nvt.mdp 1 title = OPLS Lysozyme NVT equilibration 2 define = - ...

  4. 对udp dns的一次思考

    目前昨天查一个线上问题:""dns服务器在我们的设备, 有大量的终端到设备上请求解析域名,但是一直是单线程,dns报文处理不过来", 然而设备是多核,dns服务器一直不能 ...

  5. android下vulkan与opengles纹理互通

    先放demo源码地址:https://github.com/xxxzhou/aoce 06_mediaplayer 效果图: 主要几个点: 用ffmpeg打开rtmp流. 使用vulkan Compu ...

  6. web安全原理-文件包含漏洞

    前言 起来吃完早饭就开始刷攻防世界的题,一个简单的文件包含题我竟然都做不出来我服了  拿出买的书开始从头学习总结文件包含漏洞! 一.文件包含漏洞 文件包含漏洞 文件包含函数的参数没有经过过滤或者严格的 ...

  7. 给力啊!这篇Spring Bean的依赖注入方式笔记总结真的到位,没见过写的这么细的

    1. Bean的依赖注入概念 依赖注入(Dependency Injection):它是 Spring 框架核心 IOC 的具体实现.在编写程序时,通过控制反转,把对象的创建交给了 Spring,但是 ...

  8. 深度分析:java8的新特性lambda和stream流,看完你学会了吗?

    1. lambda表达式 1.1 什么是lambda 以java为例,可以对一个java变量赋一个值,比如int a = 1,而对于一个方法,一块代码也是赋予给一个变量的,对于这块代码,或者说被赋给变 ...

  9. .Net orm 开源项目 FreeSql 2.0.0(满意的答卷)

    写在开头 2018年11月头脑发热到今天,一晃已经两年,当初从舒服区走向一个巨大的坑,回头一看后背一凉. 两年时间从无到有,经历数不清的日夜奋斗(有人问花了多长时间投入,答案:全职x2 + 两年无休息 ...

  10. 企业BI智能大屏,除了页面炫酷,还能带来什么?

    当我们一谈到可视化大屏,超大画面.超强科技感.酷炫的呈现效果就会出现在我们的脑海中. 所谓数据可视化,就是通过图表.图形.地图等视觉元素,将数据中所蕴含的信息的趋势.异常和模式展现出来.与传统报表相比 ...