PyTorch随手记

Note:

1. 模型操作

假设我们有一个用self.arcnn = nn.Sequential(...)定义并训练好的ARCNN模型。我们想迁移过来,冻结前几层再训练。分两步:

  1. print(model.state_dict())查看名称,如'arcnn.12.bias', 'arcnn.12.weight'等。

  2. model.arcnn[0].weight.requires_grad = Falsemodel.arcnn[0].bias.requires_grad = False,让第一层冻结。

2. 网络设计

卷积图示

GitHub

填充(padding)

PyTorch和TensorFlow的填充规则是不同的。因此必须查阅官方文档

如果y = F.pad(x, (1,2,3,4)),意思是:在\(x\)的最后一个维度上(一般是W),左边填一圈零,右边填两圈0(默认为0);在\(x\)的倒数第二个维度上(一般是H),上面填3圈零,下面填4圈零。

升采样

其中有一个参数align_corners。例子参见官方教程里的Example

这里有一个图例:

全连接层

假设我们经过多层卷积,得到了\((128, 32, 4, 4)\)的通道,即batch size为128,32张特征图,通道尺寸为\(4 \times 4\)。我们希望基于此得到2分类。那么可以如下操作:

self.l1 = nn.Linear(32 * 4 * 4, 128)
self.l2 = nn.Linear(128, 32)
self.l3 = nn.Linear(32, 2) x = x.view(-1, 32 * 4 * 4)
x = self.l1(x)
x = self.l2(x)
x = self.l3(x)

关于交叉熵和softmax,参见损失函数。

3. 损失函数

交叉熵

loss_func = F.cross_entropy

batch_pred_t = model(batch_cmp_t)
batch_pred = batch_pred_t.detach().cpu()
acc = cal_acc(batch_pred, batch_label) def cal_acc(batch_pred, batch_label): batch_pred = [torch.argmax(batch_pred[ite_patch]) for ite_patch in range(batch_size)] acc = 0
for ite_patch in range(batch_size):
if pred[ite_patch] == batch_label[ite_patch]:
acc += 1
acc /= batch_size return acc

注意:

  • cross_entropy函数结合了nn.LogSoftmax()nn.NLLLoss()

  • 第二个参数是target。假设batch size是32,那么就是一个32维向量(张量),值为从0开始的正确标签。

  • 第一个参数是input,可以没有被softmax归一化。假设batch size是32,一共有5个分类,那么就是一个\(32 \times 5\)的张量。

4. 系统或环境交互

模型加载

自动搜索空余显存最多的GPU,然后将模型加载到该GPU上:

os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
memory_gpu=[int(x.split()[2]) for x in open('tmp','r').readlines()]
dev = torch.device("cuda:" + str(np.argmax(memory_gpu)))
print(dev) model.load_state_dict(torch.load(os.path.join(dir_model, "model_" + str(index_model) + ".pt"), map_location=dev))
model.to(dev)

5. 犯过的错误

损失异常

  • CNN最后一层使用了非线性激活函数ReLU,导致输出在0附近浮动。

测试显存过大

在测试程序中指定了torch.no_grad(),然而显存还是过大。后来改成with torch.no_grad():包裹测试程序,成功了。

Note | PyTorch的更多相关文章

  1. Note | PyTorch官方教程学习笔记

    目录 1. 快速入门PYTORCH 1.1. 什么是PyTorch 1.1.1. 基础概念 1.1.2. 与NumPy之间的桥梁 1.2. Autograd: Automatic Differenti ...

  2. 理解PyTorch的自动微分机制

    参考Getting Started with PyTorch Part 1: Understanding how Automatic Differentiation works 非常好的文章,讲解的非 ...

  3. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  4. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  5. pytorch对可变长度序列的处理

    主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils. ...

  6. pytorch .detach() .detach_() 和 .data用于切断反向传播

    参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource 当我们再训 ...

  7. 一文看懂Transformer内部原理(含PyTorch实现)

    Transformer注解及PyTorch实现 原文:http://nlp.seas.harvard.edu/2018/04/03/attention.html 作者:Alexander Rush 转 ...

  8. [转] 理解CheckPoint及其在Tensorflow & Keras & Pytorch中的使用

    作者用游戏的暂停与继续聊明白了checkpoint的作用,在三种主流框架中演示实际使用场景,手动点赞. 转自:https://blog.floydhub.com/checkpointing-tutor ...

  9. pytorch做seq2seq注意力模型的翻译

    以下是对pytorch 1.0版本 的seq2seq+注意力模型做法语--英语翻译的理解(这个代码在pytorch0.4上也可以正常跑): # -*- coding: utf-8 -*- " ...

随机推荐

  1. Thymeleaf入门与基础语法

    1.简介 Thymeleaf是用来开发Web和独立环境项目的现代服务器端Java模板引擎. Thymeleaf的主要目标是为您的开发工作流程带来优雅的自然模板 - HTML.可以在直接浏览器中正确显示 ...

  2. 前端笔记之Vue(四)UI组件库&Vuex&虚拟服务器初识

    一.日历组件 new Date()的月份是从0开始的. 下面表达式是:2018年6月1日 new Date(2018, 5, 1); 下面表达式是:2018年5月1日 new Date(2018, 4 ...

  3. 正则表达式中的.*?和python中re.S参数的详解

    本章的内容主要是为讲解在正则表达式中常用的.*?和re.S! 在正则表达式中有贪婪匹配和最小匹配:如下为贪婪匹配(.*) import re match = re.search(r'PY.*', 'P ...

  4. wpf 单例模式和异常处理 (原发布 csdn 2017-04-12 20:34:12)

    第一次写博客,如有错误,请大家及时告知,本人立即改之. 如果您有好的想法或者建议,我随时与我联系. 如果发现代码有误导时,请与我联系,我立即改之. 好了不多说,直接贴代码. 一般的错误,使用下面三个就 ...

  5. NET EF 连接Oracle 的配置方法记录

    主要记录下如何在EF 中连接Oracle s数据库,很傻瓜式,非常简单,但是不知道的童鞋,也会搞得很难受,我自己就是 1.创一个控制台程序,并且添加  Oracle.ManagedDataAccess ...

  6. MySQL学习——查询表里的数据

    MySQL学习——查询表里的数据 摘要:本文主要学习了使用DQL语句查询表里数据的方法. 数据查询 语法 select [distinct] 列1 [as '别名1'], ..., 列n [as '别 ...

  7. 2019年上半年收集到的人工智能LSTM干货文章

    2019年上半年收集到的人工智能LSTM干货文章 门控神经网络:LSTM 和 GRU 简要说明 LSTM-CNN-Attention算法系列之一:LSTM提取时间特征 对时间序列分类的LSTM全卷积网 ...

  8. DOJO之gridx

    GridX简介 Gridx是IBM公司的职员对Dojo中的Grid进行进一步扩展的组件,但是它是重新开发了Grid而不是继承Grid. 虽然同样都是基于Dojo store, 但与DataGrid/E ...

  9. [b0032] python 归纳 (十七)_线程同步_信号量Semaphore

    代码: # -*- coding: utf-8 -*- """ 多线程并发同步 ,使用信号量threading.Semaphore 逻辑: 多个线程,对同一个共享变量 , ...

  10. REST API接口测试

    背景介绍 为什么要做借口测试? 很多系统关联都是基于接口来实现的,接口测试可以将复杂的系统关联进行简化. 接口功能比较单一,能够比较好的进行测试覆盖,也相对容易实现自动化持续集成. 接口相当于界面功能 ...