Learning Memory-guided Normality: Code Study Notes

Core of the memory module

The core of the memory module is the following definition of the Memory class.

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Memory(nn.Module):
        def __init__(self, memory_size, feature_dim, key_dim, temp_update, temp_gather):
            super(Memory, self).__init__()
            # Constants
            self.memory_size = memory_size
            self.feature_dim = feature_dim
            self.key_dim = key_dim
            self.temp_update = temp_update
            self.temp_gather = temp_gather

        def hard_neg_mem(self, mem, i):
            # NOTE: self.keys_var is never defined in __init__; this helper appears unused
            similarity = torch.matmul(mem, torch.t(self.keys_var))
            similarity[:, i] = -1
            _, max_idx = torch.topk(similarity, 1, dim=1)
            return self.keys_var[max_idx]

        def random_pick_memory(self, mem, max_indices):
            m, d = mem.size()
            output = []
            for i in range(m):
                flattened_indices = (max_indices == i).nonzero()
                a, _ = flattened_indices.size()
                if a != 0:
                    number = np.random.choice(a, 1)
                    output.append(flattened_indices[number, 0])
                else:
                    output.append(-1)
            return torch.tensor(output)

        def get_update_query(self, mem, max_indices, update_indices, score, query, train):
            m, d = mem.size()
            if train:
                query_update = torch.zeros((m, d)).cuda()
                # random_update = torch.zeros((m,d)).cuda()
                for i in range(m):
                    idx = torch.nonzero(max_indices.squeeze(1) == i)
                    a, _ = idx.size()
                    if a != 0:
                        query_update[i] = torch.sum(((score[idx, i] / torch.max(score[:, i])) * query[idx].squeeze(1)), dim=0)
                    else:
                        query_update[i] = 0
                return query_update
            else:
                query_update = torch.zeros((m, d)).cuda()
                for i in range(m):
                    idx = torch.nonzero(max_indices.squeeze(1) == i)
                    a, _ = idx.size()
                    if a != 0:
                        query_update[i] = torch.sum(((score[idx, i] / torch.max(score[:, i])) * query[idx].squeeze(1)), dim=0)
                    else:
                        query_update[i] = 0
                return query_update

        def get_score(self, mem, query):
            bs, h, w, d = query.size()
            m, d = mem.size()
            score = torch.matmul(query, torch.t(mem))  # b X h X w X m
            score = score.view(bs * h * w, m)          # (b X h X w) X m
            score_query = F.softmax(score, dim=0)
            score_memory = F.softmax(score, dim=1)
            return score_query, score_memory

        def forward(self, query, keys, train=True):
            batch_size, dims, h, w = query.size()  # b X d X h X w
            query = F.normalize(query, dim=1)
            query = query.permute(0, 2, 3, 1)      # b X h X w X d
            # train
            if train:
                # losses
                separateness_loss, compactness_loss = self.gather_loss(query, keys, train)
                # read
                updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
                # update
                updated_memory = self.update(query, keys, train)
                return updated_query, updated_memory, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss
            # test
            else:
                # loss
                compactness_loss, query_re, top1_keys, keys_ind = self.gather_loss(query, keys, train)
                # read
                updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
                # update
                updated_memory = keys
                return updated_query, updated_memory, softmax_score_query, softmax_score_memory, query_re, top1_keys, keys_ind, compactness_loss

        def update(self, query, keys, train):
            batch_size, h, w, dims = query.size()  # b X h X w X d
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
            _, updating_indices = torch.topk(softmax_score_query, 1, dim=0)
            if train:
                query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape, train)
                updated_memory = F.normalize(query_update + keys, dim=1)
            else:
                query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape, train)
                updated_memory = F.normalize(query_update + keys, dim=1)
            return updated_memory.detach()

        def pointwise_gather_loss(self, query_reshape, keys, gathering_indices, train):
            n, dims = query_reshape.size()  # (b X h X w) X d
            loss_mse = torch.nn.MSELoss(reduction='none')
            pointwise_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
            return pointwise_loss

        def gather_loss(self, query, keys, train):
            batch_size, h, w, dims = query.size()  # b X h X w X d
            if train:
                loss = torch.nn.TripletMarginLoss(margin=1.0)
                loss_mse = torch.nn.MSELoss()
                softmax_score_query, softmax_score_memory = self.get_score(keys, query)
                query_reshape = query.contiguous().view(batch_size * h * w, dims)
                _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1)
                # 1st, 2nd closest memories
                pos = keys[gathering_indices[:, 0]]
                neg = keys[gathering_indices[:, 1]]
                top1_loss = loss_mse(query_reshape, pos.detach())
                gathering_loss = loss(query_reshape, pos.detach(), neg.detach())
                return gathering_loss, top1_loss
            else:
                loss_mse = torch.nn.MSELoss()
                softmax_score_query, softmax_score_memory = self.get_score(keys, query)
                query_reshape = query.contiguous().view(batch_size * h * w, dims)
                _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
                gathering_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
                return gathering_loss, query_reshape, keys[gathering_indices].squeeze(1).detach(), gathering_indices[:, 0]

        def read(self, query, updated_memory):
            batch_size, h, w, dims = query.size()  # b X h X w X d
            softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory)  # (b X h X w) X d
            updated_query = torch.cat((query_reshape, concat_memory), dim=1)  # (b X h X w) X 2d
            updated_query = updated_query.view(batch_size, h, w, 2 * dims)
            updated_query = updated_query.permute(0, 3, 1, 2)
            return updated_query, softmax_score_query, softmax_score_memory
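
Before walking through each step, here is a minimal usage sketch of the class, just to pin down the tensor shapes. All sizes and temperature values are illustrative assumptions, not values from the paper or repo configs; since the class hard-codes .cuda(), a GPU is assumed.

    import torch
    import torch.nn.functional as F

    m, d, b, h, w = 10, 512, 4, 32, 32  # illustrative sizes (assumptions)

    memory = Memory(memory_size=m, feature_dim=d, key_dim=d, temp_update=0.1, temp_gather=0.1)
    keys = F.normalize(torch.rand((m, d)), dim=1).cuda()  # m memory items of dimension d
    query = torch.rand((b, d, h, w)).cuda()               # encoder feature map

    updated_query, updated_memory, *rest = memory(query, keys, train=True)
    print(updated_query.shape)   # torch.Size([4, 1024, 32, 32]) -- read() doubles the channels
    print(updated_memory.shape)  # torch.Size([10, 512])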

The update step

The update step calls get_update_query(self, mem, max_indices, update_indices, score, query, train) to compute

\(query\_update = \sum_{k \in U_t^m} v_t^{\prime k,m} q_t^k\)

where \(U_t^m\) is the set of queries whose nearest memory item is the m-th one. It then computes \(p^m \leftarrow f(p^m + query\_update)\). The paper describes \(f\) as the L2 norm, i.e. each updated memory item is renormalized to unit length.
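
To make the role of \(f\) concrete, here is a tiny sketch with made-up numbers showing that F.normalize(..., dim=1), as used in update(), performs exactly this row-wise L2 normalization:

    import torch
    import torch.nn.functional as F

    p = torch.tensor([[3.0, 4.0],
                      [1.0, 0.0]])            # two memory items, d = 2 (made-up)
    query_update = torch.tensor([[1.0, 0.0],
                                 [0.0, 1.0]])

    updated = F.normalize(p + query_update, dim=1)  # f(p^m + query_update)
    print(updated)
    # tensor([[0.7071, 0.7071],
    #         [0.7071, 0.7071]])
    # each row has unit L2 norm: [4,4]/sqrt(32) and [1,1]/sqrt(2)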

Let's look at the definition of get_update_query:

    def get_update_query(self, mem, max_indices, update_indices, score, query, train):
        m, d = mem.size()
        # NOTE: the train and test branches below are identical in the released code
        if train:
            query_update = torch.zeros((m, d)).cuda()
            # random_update = torch.zeros((m,d)).cuda()
            for i in range(m):
                idx = torch.nonzero(max_indices.squeeze(1) == i)
                a, _ = idx.size()
                if a != 0:
                    query_update[i] = torch.sum(((score[idx, i] / torch.max(score[:, i])) * query[idx].squeeze(1)), dim=0)
                else:
                    query_update[i] = 0
            return query_update
        else:
            query_update = torch.zeros((m, d)).cuda()
            for i in range(m):
                idx = torch.nonzero(max_indices.squeeze(1) == i)
                a, _ = idx.size()
                if a != 0:
                    query_update[i] = torch.sum(((score[idx, i] / torch.max(score[:, i])) * query[idx].squeeze(1)), dim=0)
                else:
                    query_update[i] = 0
            return query_update

In this definition, the key point is how \(v_t^{\prime k,m}\) is computed: the code implements it as score[idx, i] / torch.max(score[:, i]). Going one step further, we need to see how \(v_t^{k,m}\) itself is obtained. Like \(w_t^{k,m}\), it is a matching weight; the paper computes both through the get_score function, defined as follows:

    def get_score(self, mem, query):
        # computes the matching weights w_t^{k,m} and v_t^{k,m}
        bs, h, w, d = query.size()
        m, d = mem.size()
        score = torch.matmul(query, torch.t(mem))  # b X h X w X m
        score = score.view(bs * h * w, m)          # (b X h X w) X m
        score_query = F.softmax(score, dim=0)   # softmax over queries      -> v_t^{k,m}
        score_memory = F.softmax(score, dim=1)  # softmax over memory items -> w_t^{k,m}
        return score_query, score_memory

This implements the weight computation from the paper: softmax over dim=1 normalizes each query's scores across the memory items (the reading weight \(w_t^{k,m}\)), while softmax over dim=0 normalizes each memory item's scores across all queries (the updating weight \(v_t^{k,m}\)).
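
A tiny numeric sketch (made-up scores) of the two softmax directions:

    import torch
    import torch.nn.functional as F

    score = torch.tensor([[2.0, 0.0],
                          [0.0, 1.0],
                          [1.0, 1.0]])  # 3 queries x 2 memory items (made-up)

    w = F.softmax(score, dim=1)  # rows sum to 1: each query distributed over items
    v = F.softmax(score, dim=0)  # columns sum to 1: each item distributed over queries
    print(w.sum(dim=1))  # tensor([1., 1., 1.])
    print(v.sum(dim=0))  # tensor([1., 1.])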

The read step

    def read(self, query, updated_memory):
        # the read step
        batch_size, h, w, dims = query.size()  # b X h X w X d
        softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)
        # weighted average of the memory items under the reading weights
        concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory)  # (b X h X w) X d
        # concatenate the retrieved memory feature with the original query
        updated_query = torch.cat((query_reshape, concat_memory), dim=1)  # (b X h X w) X 2d
        updated_query = updated_query.view(batch_size, h, w, 2 * dims)
        updated_query = updated_query.permute(0, 3, 1, 2)
        return updated_query, softmax_score_query, softmax_score_memory

The key steps are annotated in the code above: reading is a soft, weighted retrieval over all memory items, and the retrieved feature is concatenated with the query, doubling the channel dimension.
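
A quick shape check for read alone (reusing the memory instance from the earlier sketch; all sizes are again illustrative assumptions, and note that read expects the query already permuted to b X h X w X d):

    import torch

    b, h, w, d, m = 2, 8, 8, 16, 10   # illustrative sizes
    query = torch.rand(b, h, w, d)     # already in b X h X w X d layout
    items = torch.rand(m, d)

    out, score_q, score_m = memory.read(query, items)
    print(out.shape)  # torch.Size([2, 32, 8, 8]) -- channels doubled to 2d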

The forward pass

    separateness_loss, compactness_loss = self.gather_loss(query, keys, train)
    # read
    updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
    # update
    updated_memory = self.update(query, keys, train)
    return updated_query, updated_memory, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss

In the training branch, forward first computes the separateness and compactness losses via gather_loss, then calls read and update in turn.

The overall training objective also deserves a note: \(L = L_{rec} + \lambda_c L_{compact} + \lambda_s L_{separate}\), a reconstruction (or prediction) term plus the weighted compactness and separateness terms.
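
For reference, the two memory terms take roughly the following form in the paper (notation adapted; \(p\) denotes the nearest memory item to the query \(q_t^k\), \(n\) the second-nearest, and \(\alpha\) the margin, which the code sets to 1.0 via TripletMarginLoss):

    \[
    L_{compact} = \sum_{t}^{T} \sum_{k}^{K} \left\| q_t^k - p \right\|_2^2,
    \qquad
    L_{separate} = \sum_{t}^{T} \sum_{k}^{K} \left[ \left\| q_t^k - p \right\|_2 - \left\| q_t^k - n \right\|_2 + \alpha \right]_+
    \]

In the code, \(L_{compact}\) corresponds to the MSE top1_loss and \(L_{separate}\) to the triplet gathering_loss.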

Both memory terms are implemented in the gather_loss function:

    def gather_loss(self, query, keys, train):
        batch_size, h, w, dims = query.size()  # b X h X w X d
        if train:
            loss = torch.nn.TripletMarginLoss(margin=1.0)
            # triplet loss: implements the feature separateness loss
            loss_mse = torch.nn.MSELoss()
            # MSE loss: implements the feature compactness loss
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1)
            # 1st, 2nd closest memories
            pos = keys[gathering_indices[:, 0]]
            neg = keys[gathering_indices[:, 1]]
            top1_loss = loss_mse(query_reshape, pos.detach())
            gathering_loss = loss(query_reshape, pos.detach(), neg.detach())
            return gathering_loss, top1_loss
        else:
            loss_mse = torch.nn.MSELoss()
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
            gathering_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
            return gathering_loss, query_reshape, keys[gathering_indices].squeeze(1).detach(), gathering_indices[:, 0]
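
Putting it together, here is a hedged sketch of how the total objective might be assembled in one training step. The wrapper model, its return signature, imgs, target, m_items, optimizer, and the \(\lambda\) weights are all assumptions about the surrounding training script, which is not shown in this section:

    import torch

    lam_c, lam_s = 0.1, 0.1  # assumed loss weights lambda_c, lambda_s

    # one training step (sketch): `model` is assumed to wrap encoder/decoder plus Memory
    outputs, m_items, _, _, separateness_loss, compactness_loss = model(imgs, m_items, train=True)

    loss_rec = torch.mean((outputs - target) ** 2)  # L_rec: reconstruction/prediction term
    loss = loss_rec + lam_c * compactness_loss + lam_s * separateness_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()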
