一开始写这篇随笔的时候还没有了解到 Dateloader有一个 collate_fn 的参数,通过定义一个collate_fn 函数,其实很多batch补齐到当前batch最长的操作可以放在collate_fn 里面去,这样代码在训练和模型中就可以更加简洁。有时间再整理一下这个吧。

_________________________________________

使用的主要部分包括:Dateset、 Dateloader、MSELoss、PackedSequence、pack_padded_sequence、pad_packed_sequence

模型包含LSTM模块。

参考了下面两篇博文,总结了一下。对PackedSequence相关的理解可以先看这两篇。本文主要是把这些应用从数据准备到loss计算都串起来大致提供了一下代码思路,权当给自己的提醒备份吧。或者看完下面两篇,但是不知道具体怎么操作的朋友们一个参考。

http://www.cnblogs.com/lindaxin/p/8052043.html#commentform

https://blog.csdn.net/lssc4205/article/details/79474735

使用Dateset构建数据集的时候,在__getitem__函数中把所有数据先补齐到 全局最长序列的长度。

  1. def __getitem__(self, index):
  2. '''
  3. get original data
  4. 此处省略获取原始数据的代码
  5. input_data,output_data
  6. 数据shape是 seq_length * feature_dim
  7. '''
  8. # 当前seq_length小于所有数据中的最长数据长度,则补0到同一长度。
  9. ori_length = input_data.shape[0]
  10. if ori_length < self.max_len:
  11. npi = np.zeros(self.input_feature_dim, dtype=np.float32)
  12. npi = np.tile(npi, (self.max_len - ori_length,1))
  13. input_data = np.row_stack((input_data, npi))
  14. npo = np.zeros(self.output_feature_dim, dtype=np.float32)
  15. npo = np.tile(npo, (self.max_len - ori_length,1))
  16. output_data = np.row_stack((output_data, npo))
  17. return input_data, output_data, ori_length, input_data_path

在模型中,forward的实现中,需要在LSTM之前使用pack_padded_sequence、在LSTM之后使用pad_packed_sequence,中间还涉及到顺序的还原之类的操作。

  1. def forward(self, input_x, length_list, hidden=None):
  2. if hidden is None:
  3. # 这里没用 配置中的batch_size,而是直接在input_x中取batch_size是为了防止last_batch的batch_size不是配置中的那个,引发bug
  4. h_0 = input_x.data.new(self.directional*self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float()
  5. c_0 = input_x.data.new(self.directional*self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float()
  6. else:
  7. h_0, c_0 = hidden
  8. '''
  9. 省略模型其他部分,直接进去LSTM前后的操作
  10. '''
  11. _, idx_sort = torch.sort(length_list, dim=0, descending=True)
  12. _, idx_unsort = otrch.sort(idx_sort, dim=0)
  13.  
  14. input_x = input_x.index_select(0, Variable(idx_sort))
  15. length_list = list(length_list[idx_sort])
  16. pack = nn_utils.rnn.pack_padded_sequence(input_x, length_list, batch_first=self.batch_first)
  17. output, hidden = self.BiLSTM(pack, (h0, c0))
  18. un_padded = nn_utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
  19. un_padded = un_padded[0].index_select(0, Variable(idx_unsort))
  20. # 此时的un_padded已经完成了还原,并且补0完成,而且这时的补0到的序列长度是当前batch的最长长度,而不是Dateset中的全局最长长度!
    # 所以在main train函数中也要对label的seq做处理
  21. return un_padded

main train中,要对label做相应的截断处理,因为模型返回的长度已经是补齐到当前batch的最长序列长度了,而dateset返回的label是补齐到全局最长序列长度。算loss的时候,MSELoss的reduce参数要设置成false,让loss函数返回一个loss矩阵,再构造一个01掩膜矩阵mask,矩阵相乘求和得到真的loss(达到填充0的位置不参与loss的目的)

  1. def train(**kwargs):
      train_data = my_dataset()
      train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
      model = getattr(models, opt.model)(batchsize=opt.batch_size)
      criterion = torch.nn.MSELoss(reduce=False)
      lr = opt.lf
      optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
      for epoch in range(opt.start_epoch, opt.max_epoch):
        for ii, (data, label, length_list,_) in tqdm(enumerate(train_dataloader)):
          cur_batch_max_len = length_list.max()
          data = Variable(data)
          target = Variable(label)
  2.  
  3.       optimizer.zero_grad()
          score = model(data, length_list)
          loss_mat = criterion(score, target)
          list_int = list(length_list)
          mask_mat = Variable(t.ones(len(list_int),cur_batch_max_len,opt.output_feature_dim))
          num_element = 0
          for idx_sample in range(len(list_int)):
            num_element += list_int[idx_sample] * opt.output_feature_dim
            if list_int[idx_sample] != cur_batch_max_len:
              mask_mat[idx_sample, list[idx_sample]:] = 0.0
  4.  
  5.       loss = (loss_mat * mask_mat).sum() / num_element
          loss.backward()
          optimizer.step()
  6.  

pytorch 对变长序列的处理的更多相关文章

  1. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  2. keras: 在构建LSTM模型时,使用变长序列的方法

    众所周知,LSTM的一大优势就是其能够处理变长序列.而在使用keras搭建模型时,如果直接使用LSTM层作为网络输入的第一层,需要指定输入的大小.如果需要使用变长序列,那么,只需要在LSTM层前加一个 ...

  3. 0-3为变长序列建模modeling variable length sequences

    在本节中,我们会讨论序列的长度是变化的,也是一个变量 we would like the length of sequence,n,to alse be a random variable 一个简单的 ...

  4. Python技法1:变长和定长序列拆分

    Python中的任何序列(可迭代的对象)都可以通过赋值操作进行拆分,包括但不限于元组.列表.字符串.文件.迭代器.生成器等. 元组拆分 元组拆分是最为常见的一种拆分,示例如下: p = (4, 5) ...

  5. C++中的变长参数

    新参与的项目中,为了使用共享内存和自定义内存池,我们自己定义了MemNew函数,且在函数内部对于非pod类型自动执行构造函数.在需要的地方调用自定义的MemNew函数.这样就带来一个问题,使用stl的 ...

  6. Scala 变长参数

    如果Scala定义变长参数 def sum(i Int*), 那么调用sum时,可以直接输入sum(1,2,3,4,5) 但是不可以sum(1 to 5) 必须要将1 to 5 强制为seq sum( ...

  7. 报文格式:xml 、定长报文、变长报文

    目前接触到的报文格式有三种:xml .定长报文.变长报文 . 此处只做简单介绍,日后应该会深入学习到三者之间如何解析,再继续更新.——2016.9.23 XML XML 被设计用来传输和存储数据. H ...

  8. GCC 中零长数组与变长数组

    前两天看程序,发现在某个函数中有下面这段程序: int n; //define a variable n int array[n]; //define an array with length n 在 ...

  9. 删除变长列字段后使用DBCC CLEANTABLE回收空间

    标签:SQL Server Reclaim space 收缩表 原创作品,允许转载,转载时请务必以超链接形式标明文章 原始出处 .作者信息和本声明.否则将追究法律责任.http://lzf328.bl ...

随机推荐

  1. 一个汇编的HelloWorld!

    花了一下午时间,感觉最坑的是,书写代码的个数和编译器的坑比较多,还各种版本的编译器! 会让人“眼花缭乱”! 主要代码 将文件保存为*.asm include io32.inc .data ;数据 sr ...

  2. 一、JSP九大内置对象 二、JAVAEE三层架构和MVC设计模式 三、Ajax

    一.JSP九大内置对象###<1>概念 不需要预先申明和定义,可以直接在jsp代码中直接使用 在JSP转换成Servlet之后,九大对象在Servlet中的service方法中对其进行定义 ...

  3. SDN第三次上机

    1.创建以下拓扑(可采用任意方式) 2.利用OVS命令下发流表,实现VLAN功能 3.利用OVS命令查看流表 4.验证性测试 5.Wireshark抓包验证

  4. 解决Windows Server2008 R2中IE开网页时弹出阻止框

    使用Windows Server2008,用IE打开网站时会弹出“Internet Explorer增强安全配置正在阻止来自下列网站的此应用程序中的内容”的对话框.如下图所示: 2011-10-14_ ...

  5. python第三十二课——栈

    栈:满足特点 --> 先进后出,类似于我们生活中的子弹夹 [注意] 对于栈结构而言:python中没有为其封装特定的函数,我们可以使用list(列表)来模拟栈的特点 使用list对象来模拟栈结构 ...

  6. linux,添加新硬盘的方法

    一.物理机添加一块新的硬盘方法(目的是把后加的磁盘直接加在现有的上面,不用再分区挂载)1.首先要确定现有系统在那块盘上  [root@localhost ~]# df -lhFilesystem    ...

  7. JAVA框架Struts2(二)

    一:Struts2执行流程: 1)编写页面,点击超链接,请求提交到服务器端. 2)请求先经过Struts2核心过滤器(StrutsprepareAndexectuterfilter). 3)过滤器的功 ...

  8. HDU 2159 FATE(有选择物品总个数限制的完全背包,经典!!)

    FATE Time Limit:1000MS     Memory Limit:32768KB     64bit IO Format:%I64d & %I64u Submit Status ...

  9. 'utf-8' codec can't decode byte 0xbc in position 1182: invalid start byte

    2.如果是字符集出现错误,建议多选择几种字符集测试一下: 选择的经验是: 如果是爬取到的网页文件,可以查看网页文件的meta标签下的charset属性值.例如: <meta charset=&q ...

  10. day61

    Vue 八.重要指令 v-bind <!-- 值a --> <div v-bind:class='"a"'></div> <!-- 变量a ...