CVPR 2018 的一篇少样本学习论文

Learning to Compare: Relation Network for Few-Shot Learning

源码地址:https://github.com/floodsung/LearningToCompare_FSL

在自己的破笔记本上跑了下这个源码,windows 系统,pycharm + Anaconda3 + pytorch-cpu 1.0.1

报了一堆bug, 总结如下:

procs_images.py里 ‘cp’报错

用procs_images.py处理 miniImangenet 数据集的时候:

报错信息:
/LearningToCompare_FSL-master/datas/miniImagenet/proc_images.py
'cp' �����ڲ����ⲿ���Ҳ���ǿ����еij������������ļ���

具体位置是

/datas/miniImagenet/procs_images.py  Line 48:
os.system('cp images/' + image_name + ' ' + cur_dir)

这个‘cp’是linux环境运行的。

用windows系统的话要改成:

os.rename('images/' + image_name, cur_dir + image_name)

除此之外,所有的 os.system('mkdir ' + filename)

也要改成 os.mkdir(filename),虽然不一定会报错。

cpu RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False.

我的torch版本是是cpu, 所以把所有 .cuda(GPU)删了,另外

使用torch.load时添加 ,map_location ='cpu'

以miniImagenet_train_few_shots.py 为例
Line 150:
feature_encoder.load_state_dict(torch.load(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
改成
feature_encoder.load_state_dict(torch.load(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"),map_location = 'cpu'
))
Line:153:
relation_network.load_state_dict(torch.load(str("./models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
改成
relation_network.load_state_dict(torch.load(str("./models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"),map_location = 'cpu'))

KeyError: '..\\datas\\omniglot_resized'

报错信息:
File "LearningToCompare_FSL-master/omniglot/omniglot_train_few_shot.py", line 163, in main
task = tg.OmniglotTask(metatrain_character_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 72, in <listcomp>
self.train_labels = [labels[self.get_class(x)] for x in self.train_roots]
KeyError: '..\\datas\\omniglot_resized'

关键的地方其实是在:

 task_generator.py, line 74:
  def get_class(self, sample):
return os.path.join(*sample.split('/')[:-1])

print (os.path.join(*sample.split('/')[:-1])) 结果是

..\datas\omniglot_resized

而labels是

{'../datas/omniglot_resized/Malay_(Jawi_-_Arabic)\\character25': 0, '../datas/omniglot_resized/Japanese_(hiragana)\\character15': 1, '…}

而 print(os.path.join(*sample.split('\\')[:-1]))  结果正是

../datas/omniglot_resized/Malay_(Jawi_-_Arabic)\character25

解决方法:把'/'改成'\\'即可 
def get_class(self, sample):
return os.path.join(*sample.split('\\')[:-1])

RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'

报错信息:
File "/LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 193, in main
torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1))
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'

解决方法:在前面加一句

 batch_labels = batch_labels.long()

RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'other'

报错信息:  
File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_test_few_shot.py", line 247, in <listcomp>
rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(batch_size)]
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'other'

解决方法:在前面加上

predict_labels = predict_labels.long()
test_labels = test_labels.long()

这两个好像是使用torch的数据格式问题

IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

报错信息:
File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in main
print("episode:",episode+1,"loss",loss.data[0])
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number 按要求改成
print("episode:", episode + 1, "loss", loss.item())
就可以了

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

报错信息:
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 107, in __getitem__
image = self.transform(image)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__
img = t(img)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 163, in __call__
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\functional.py", line 208, in normalize
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

这个是使用Omniglot数据集时的报错,主要原因在于

"\omniglot\task_generator.py", line 139:

def get_data_loader(task, num_per_class=1, split='train',shuffle=True,rotation=0):
normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
dataset = Omniglot(task,split=split,transform=transforms.Compose([Rotate(rotation),transforms.ToTensor(),normalize]))

使用 torch.transforms 中 normalize 用了 3 通道,而实际使用的数据集Omniglot 图片大小是 [1, 28, 28]

解决方法:


normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
改成
normalize = transforms.Normalize(mean=[0.92206], std=[0.08426])

UserWarning: nn.functional.sigmoid is deprecated.

类似的warning 还有

UserWarning : torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.

按要求改就行

torch.nn.utils.clip_grad_norm(feature_encoder.parameters(), 0.5)
改成
torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5) def forward里的
out = F.sigmoid(self.fc2(out))
改成
out = F.torch.sigmoid(self.fc2(out))

Learning to Compare: Relation Network 源码调试的更多相关文章

  1. Learning to Compare: Relation Network for Few-Shot Learning 论文笔记

    主要原理: 和Siamese Neural Networks一样,将分类问题转换成两个输入的相似性问题. 和Siamese Neural Networks不同的是: Relation Network中 ...

  2. 开启Tomcat 源码调试

    开启Tomcat 源码调试 因为工作的原因,需要了解Tomcat整个架构是如何设计的,正如要使用Spring MVC进行Web开发,需要了解Spring是如何设计的一样,有哪些主要的类,分别是用于干什 ...

  3. 在Eclipse中进行HotSpot的源码调试--转

    原文地址:http://www.linuxidc.com/Linux/2015-05/117250.htm 在阅读OpenJDK源码的过程中,经常需要运行.调试程序来帮助理解.我们现在已经可以编译出一 ...

  4. [原创]在Windows和Linux中搭建PostgreSQL源码调试环境

    张文升http://ode.cnblogs.comEmail:wensheng.zhang#foxmail.com 配图太多,完整pdf下载请点这里 本文使用Xming.Putty和VMWare几款工 ...

  5. SpringMVC DispatcherServlet 启动和加载过程(源码调试)

    在阅读本文前,最好先阅读以下内容(当然,如果对 Servlet 已经有所了解,则可跳过): http://www.cnblogs.com/cyhbyw/p/8682078.html http://ww ...

  6. 《k8s-1.13版本源码分析》-源码调试

    源码分析系列文章已经开源到github,地址如下: github:https://github.com/farmer-hutao/k8s-source-code-analysis gitbook:ht ...

  7. SpringBoot自动配置源码调试

    之前对SpringBoot的自动配置原理进行了较为详细的介绍(https://www.cnblogs.com/stm32stm32/p/10560933.html),接下来就对自动配置进行源码调试,探 ...

  8. HashMap源码调试——认识"put"操作

    前言:通常大家都知道HashMap的底层数据结构为数组加链表的形式,但其put操作具体是怎样执行的呢,本文通过调试HashMap的源码来阐述这一问题. 注:jdk版本:jdk1.7.0_51 1.pu ...

  9. .net源码调试 http://referencesource.microsoft.com/

    其实关于.net源码调试 网上的资料已经很多了,我以前转载的文章有 VS2010下如何调试Framework源代码(即FCL) 和 如何使你的应用程序调试进.NET Framework 4.5源代码内 ...

随机推荐

  1. DNS message解析

    案例吐个槽,命苦啊,要自己动手解包. 另外,这里的内容是半路找来的,如果有冲突,自行翻阅rfc1035.我还没校正过. The Structure 如下图: 所有的DNS message都包含了下面这 ...

  2. PgSql备份pg_dump与还原手记pg_restore(转)

    可以直接跳转至最后面的示例进行查看 真没有想到,以前一直是PostgreSQL使用者,突然需要库移植又成了头一招了!原来它与mysql命令行操作区别还挺大. 不用怕,但绝对要细心,因为数据库操作是网站 ...

  3. 在进程中执行新代码 execl、execle、execlp、execv、execve和execvp函数

    摘要:本文主要讲述怎样在进程中执行新代码,以及exec系列函数的基本用法. 在进程中执行新代码 用函数fork创建子进程后,假设希望在当前子进程中运行新的程序,能够调用exec函数运行还有一个程序.当 ...

  4. 如何更改Docker默认的images存储位置

    Docker的镜像以及一些数据都是在/var/lib/docker目录下,它占用的是Linux的系统分区,也就是下面的/dev/vda1,当有多个镜像时,/dev/vda1的空间可能不足,我们可以把d ...

  5. DRUPAL性能优化【转】

    1.启用memcache代替Mysql的缓存表处理缓存数据. 2.添加一个opcode缓存可以让 PHP能够重用前面编译过的代码,这样就会跳过解析和编译.常见的opcode缓存有Alternative ...

  6. [转]SIGPIPE信号

    我写了一个服务器程序,在Linux下测试,然后用C++写了客户端用千万级别数量的短链接进行压力测试.  但是服务器总是莫名退出,没有core文件. 最后问题确定为, 对一个对端已经关闭的socket调 ...

  7. Mysql各种类型字段长度

    1.数值类型 列类型 需要的存储量 TINYINT 1 字节 SMALLINT 2 个字节 MEDIUMINT 3 个字节 INT 4 个字节 INTEGER 4 个字节 BIGINT 8 个字节 F ...

  8. Linux学习一

    1.Linux的优缺点: 长处: 稳定的系统 免费或少许费用 安全性,漏洞的高速修补 多任务,多用户 用户与用户的规划 相对不耗资源的系统 适合须要小内核的嵌入式系统 整合度佳且多样的图形用户界面 缺 ...

  9. 通过 SysVinit、Systemd 和 Upstart 管理系统自启动进程和服务

    管理 Linux 自启动进程 Linux 系统的启动程序包括多个阶段,每个阶段由一个不同的图示块表示.下面的图示简要总结了启动过程以及所有包括的主要组件. Linux 启动过程 当你按下你机器上的电源 ...

  10. [JNA系列]Java调用Delphi编写的Dll之Delphi与JAVA基本数据类型对比

    Delphi与JAVA基本数据类型对比 类型 Delphi关键字 JAVA关键字 字节 备注 范围 整型 Shortint byte 1 有符号8位 -128..127 Byte 1 无符号8位 0 ...