转载自:https://blog.csdn.net/qq1483661204/article/details/79039702

Learning a Similarity Metric Discriminatively, with Application to Face
Verification 这个siamese文章链接。
本文主要讲解siamese网络,并用tensorflwo实现,在mnist数据集中,siamese网络和其他网络的不同之处在于,首先他是两个输入,它输入的不是标签,而是是否是同一类别,如果是同一类别就是0,否则就是1,文章中是用这个网络来做人脸识别,网络结构图如下:

从图中可以看到,他又两个输入,分别是下x1和x2,左右两个的网咯结构是一样的,并且他们共享权重,最后得到两个输出,分别是Gw(x1)和Gw(x2),这个网络的很好理解,当输入是同一张图片的时候,我们希望他们呢之间的欧式距离很小,当不是一张图片时,我们的欧式距离很大。有了网路结构,接下来就是定义损失函数,这个很重要,而经过我们的分析,我们可以知道,损失函数的特点应该是这样的,
(1) 当我们输入同一张图片时,他们之间的欧式距离越小,损失是越小的,距离越大,损失越大
(2) 当我们的输入是不同的图片的时候,他们之间的距离越大,损失越大
怎么理解呢,很简单,我们就是最小化把相同类的数据之间距离,最大化不同类之间的距离。
然后文章中定义的损失函数如下:
首先是定义距离,使用l2范数,公式如下:

距离其实就是欧式距离,有了距离,我们的损失函数和距离的关系我上面说了,如何包证满足上面的要求呢,文章提出这样的损失函数:

其中我们的Ew就是距离,Lg和L1相当于是一个系数,这个损失函数和交叉熵其实挺像,为了让损失函数满足上面的关系,让Lg满足单调递减,LI满足单调递增就可以。另外一个条件是:同类图片之间的距离必须比不同类之间的距离小,
其他条件如下:

然后作者也给出了证明,最终损失函数为:

Q是一个常数,这个损失函数就满足上面的关系,然后我用tensoflow写了一个损失函数如下:

需要强调的是,这个地方同一类图片是0,不同类图片是1,然后我自己用tensorflow实现的这个损失函数如下:

def siamese_loss(out1,out2,y,Q=5):

    Q = tf.constant(Q, name="Q",dtype=tf.float32)
E_w = tf.sqrt(tf.reduce_sum(tf.square(out1-out2),1))
pos = tf.multiply(tf.multiply(y,2/Q),tf.square(E_w))
neg = tf.multiply(tf.multiply(1-y,2*Q),tf.exp(-2.77/Q*E_w))
loss = pos + neg
loss = tf.reduce_mean(loss)
return loss

这就是损失函数,其他的代码如下:

 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
tf.reset_default_graph()
mnist = input_data.read_data_sets('./data/mnist',one_hot=True)
print(mnist.validation.num_examples)
print(mnist.train.num_examples)
print(mnist.test.num_examples)
def siamese_loss(out1,out2,y,Q=5): Q = tf.constant(Q, name="Q",dtype=tf.float32)
E_w = tf.sqrt(tf.reduce_sum(tf.square(out1-out2),1))
pos = tf.multiply(tf.multiply(y,2/Q),tf.square(E_w))
neg = tf.multiply(tf.multiply(1-y,2*Q),tf.exp(-2.77/Q*E_w))
loss = pos + neg
loss = tf.reduce_mean(loss)
return loss def siamese(inputs,keep_prob):
with tf.name_scope('conv1') as scope:
w1 = tf.Variable(tf.truncated_normal(shape=[3,3,1,32],stddev=0.05),name='w1')
b1 = tf.Variable(tf.zeros(32),name='b1')
conv1 = tf.nn.conv2d(inputs,w1,strides=[1,1,1,1],padding='SAME',name='conv1')
with tf.name_scope('relu1') as scope:
relu1 = tf.nn.relu(tf.add(conv1,b1),name='relu1')
with tf.name_scope('conv2') as scope:
w2 = tf.Variable(tf.truncated_normal(shape=[3,3,32,64],stddev=0.05),name='w2')
b2 = tf.Variable(tf.zeros(64),name='b2')
conv2 = tf.nn.conv2d(relu1,w2,strides=[1,2,2,1],padding='SAME',name='conv2')
with tf.name_scope('relu2') as scope:
relu2 = tf.nn.relu(conv2+b2,name='relu2') with tf.name_scope('conv3') as scope: w3 = tf.Variable(tf.truncated_normal(shape=[3,3,64,128],mean=0,stddev=0.05),name='w3')
b3 = tf.Variable(tf.zeros(128),name='b3')
conv3 = tf.nn.conv2d(relu2,w3,strides=[1,2,2,1],padding='SAME')
with tf.name_scope('relu3') as scope:
relu3 = tf.nn.relu(conv3+b3,name='relu3') with tf.name_scope('fc1') as scope:
x_flat = tf.reshape(relu3,shape=[-1,7*7*128])
w_fc1=tf.Variable(tf.truncated_normal(shape=[7*7*128,1024],stddev=0.05,mean=0),name='w_fc1')
b_fc1 = tf.Variable(tf.zeros(1024),name='b_fc1')
fc1 = tf.add(tf.matmul(x_flat,w_fc1),b_fc1)
with tf.name_scope('relu_fc1') as scope:
relu_fc1 = tf.nn.relu(fc1,name='relu_fc1') with tf.name_scope('drop_1') as scope: drop_1 = tf.nn.dropout(relu_fc1,keep_prob=keep_prob,name='drop_1')
with tf.name_scope('bn_fc1') as scope:
bn_fc1 = tf.layers.batch_normalization(drop_1,name='bn_fc1')
with tf.name_scope('fc2') as scope:
w_fc2 = tf.Variable(tf.truncated_normal(shape=[1024,512],stddev=0.05,mean=0),name='w_fc2')
b_fc2 = tf.Variable(tf.zeros(512),name='b_fc2')
fc2 = tf.add(tf.matmul(bn_fc1,w_fc2),b_fc2)
with tf.name_scope('relu_fc2') as scope:
relu_fc2 = tf.nn.relu(fc2,name='relu_fc2')
with tf.name_scope('drop_2') as scope:
drop_2 = tf.nn.dropout(relu_fc2,keep_prob=keep_prob,name='drop_2')
with tf.name_scope('bn_fc2') as scope:
bn_fc2 = tf.layers.batch_normalization(drop_2,name='bn_fc2')
with tf.name_scope('fc3') as scope:
w_fc3 = tf.Variable(tf.truncated_normal(shape=[512,2],stddev=0.05,mean=0),name='w_fc3')
b_fc3 = tf.Variable(tf.zeros(2),name='b_fc3')
fc3 = tf.add(tf.matmul(bn_fc2,w_fc3),b_fc3)
return fc3 lr = 0.01
iterations = 20000
batch_size = 64 with tf.variable_scope('input_x1') as scope:
x1 = tf.placeholder(tf.float32, shape=[None, 784])
x_input_1 = tf.reshape(x1, [-1, 28, 28, 1])
with tf.variable_scope('input_x2') as scope:
x2 = tf.placeholder(tf.float32, shape=[None, 784])
x_input_2 = tf.reshape(x2, [-1, 28, 28, 1])
with tf.variable_scope('y') as scope:
y = tf.placeholder(tf.float32, shape=[batch_size]) with tf.name_scope('keep_prob') as scope:
keep_prob = tf.placeholder(tf.float32) with tf.variable_scope('siamese') as scope:
out1 = siamese(x_input_1,keep_prob)
scope.reuse_variables()
out2 = siamese(x_input_2,keep_prob)
with tf.variable_scope('metrics') as scope:
loss = siamese_loss(out1, out2, y)
optimizer = tf.train.AdamOptimizer(lr).minimize(loss) loss_summary = tf.summary.scalar('loss',loss)
merged_summary = tf.summary.merge_all() with tf.Session() as sess: writer = tf.summary.FileWriter('./graph/siamese',sess.graph)
sess.run(tf.global_variables_initializer()) for itera in range(iterations):
xs_1, ys_1 = mnist.train.next_batch(batch_size)
ys_1 = np.argmax(ys_1,axis=1)
xs_2, ys_2 = mnist.train.next_batch(batch_size)
ys_2 = np.argmax(ys_2,axis=1)
y_s = np.array(ys_1==ys_2,dtype=np.float32)
_,train_loss,summ = sess.run([optimizer,loss,merged_summary],feed_dict={x1:xs_1,x2:xs_2,y:y_s,keep_prob:0.6}) writer.add_summary(summ,itera)
if itera % 1000 == 1 :
print('iter {},train loss {}'.format(itera,train_loss))
embed = sess.run(out1,feed_dict={x1:mnist.test.images,keep_prob:0.6})
test_img = mnist.test.images.reshape([-1,28,28,1])
writer.close()

tensorflow实现siamese网络 (附代码)的更多相关文章

  1. SVM原理以及Tensorflow 实现SVM分类(附代码)

    1.1. SVM介绍 1.2. 工作原理 1.2.1. 几何间隔和函数间隔 1.2.2. 最大化间隔 - 1.2.2.0.0.1. \(L( {x}^*)\)对$ {x}^*$求导为0 - 1.2.2 ...

  2. siamese网络&&tripletnet

    siamese网络 - 之前记录过: https://www.cnblogs.com/ranjiewen/articles/7736089.html - 原始的siamese network: 输入一 ...

  3. 十图详解tensorflow数据读取机制(附代码)转知乎

    十图详解tensorflow数据读取机制(附代码) - 何之源的文章 - 知乎 https://zhuanlan.zhihu.com/p/27238630

  4. Siamese网络

    1.       对比损失函数(Contrastive Loss function) 孪生架构的目的不是对输入图像进行分类,而是区分它们.因此,分类损失函数(如交叉熵)不是最合适的选择,这种架构更适合 ...

  5. tensorflow笔记:多层LSTM代码分析

    tensorflow笔记:多层LSTM代码分析 标签(空格分隔): tensorflow笔记 tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) ten ...

  6. Pytorch 入门之Siamese网络

    首次体验Pytorch,本文参考于:github and PyTorch 中文网人脸相似度对比 本文主要熟悉Pytorch大致流程,修改了读取数据部分.没有采用原作者的ImageFolder方法:   ...

  7. tensorflow笔记:多层CNN代码分析

    tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...

  8. SpringCloud-使用熔断器防止服务雪崩-Ribbon和Feign方式(附代码下载)

    场景 SpringCloud-服务注册与实现-Eureka创建服务注册中心(附源码下载): https://blog.csdn.net/BADAO_LIUMANG_QIZHI/article/deta ...

  9. 小姐姐带你一起学:如何用Python实现7种机器学习算法(附代码)

    小姐姐带你一起学:如何用Python实现7种机器学习算法(附代码) Python 被称为是最接近 AI 的语言.最近一位名叫Anna-Lena Popkes的小姐姐在GitHub上分享了自己如何使用P ...

随机推荐

  1. 第六章 ZYNQ-MIZ701 GPIO使用之MIO

      6.0 本章难度系数★★☆☆☆☆☆ 6.1 GPIO简介 Zynq7000系列芯片有54个MIO(multiuse I/O),它们分配在 GPIO 的Bank0 和Bank1隶属于PS部分,这些I ...

  2. 创建web服务器

    用node创建本地web服务 1,创建本地文件server.js var http = require('http'); var url=require('url'); var fs=require( ...

  3. vue阻止右键默认行为

    vue阻止右键默认行为 <!--不阻止右键菜单(浏览器行为),右键执行函数show--> <input type="button" value="按钮& ...

  4. CSP-S2019题解

    格雷码 €€£:我不抄自己辣!JOJO! 这题比那个SCOI的炒鸡格雷码好多了,甚至告诉你构造方法,所以... void wk(uLL kk) { int j=0; for(uLL i=n-1;~i; ...

  5. ES5中一些重要的拓展

    1.对象的拓展 ①Object.create(obj, {age:{value:18, writable:true, configurable:true, enumerable:true}); 以指定 ...

  6. Win Server 2012 配置运行 .net core 环境

    今天拿到一台 全新的win 2012 服务器配置服务器环境 记录一下 首先装好IIS 打开服务器管理器  - 添加功能和角色     好 安装完IIS 看一下服务器有没有安装 core的运行环境(全新 ...

  7. bash shell脚本之使用expr运算

    bash shell中的数学运算 cat test7: #!/bin/bash # An example of using the expr command var1= var2= var3=`exp ...

  8. yii框架下使用redis

    1 首先获取到 yii2-redis-master.zip 压缩包 下载地址https://github.com/yiisoft/yii2-redis/archive/master.zip 2 把下载 ...

  9. MySQL无法启动:ERROR 2002 (HY000): Can't connect to local MySQL server through socket '/var/lib/mysql/mysql.sock' (2)

    1 详细异常 ct 11 17:31:51 bd02.getngo.com mysqld[20513]: 2019-10-11T09:31:51.187848Z 0 [Note] /usr/sbin/ ...

  10. 03.Zabbix应用服务监控

    一.Zabbix监控Nginx 1.1 实验环境 服务器系统 角色 IP CentOS 7.4 x86_64 Zabbix-Server 192.168.90.10 CentOS 7.4 x86_64 ...