import sys
import os
import argparse
import time
import random
import math import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable import cuda_functional as MF def read_corpus(path, eos="</s>"):
data = [ ]
with open(path) as fin:
for line in fin:
data += line.split() + [ eos ]
return data def create_batches(data_text, map_to_ids, batch_size, cuda=True):
data_ids = map_to_ids(data_text)
N = len(data_ids)
L = ((N-1) // batch_size) * batch_size
x = np.copy(data_ids[:L].reshape(batch_size,-1).T)
y = np.copy(data_ids[1:L+1].reshape(batch_size,-1).T)#x和y的结果基本相同
x, y = torch.from_numpy(x), torch.from_numpy(y)
x, y = x.contiguous(), y.contiguous()
if cuda:
x, y = x.cuda(), y.cuda()
return x, y class EmbeddingLayer(nn.Module):#为语料中每一个单词对应的其相应的词向量
def __init__(self, n_d, words, fix_emb=False):
super(EmbeddingLayer, self).__init__()
word2id = {}
for w in words:
if w not in word2id:
word2id[w] = len(word2id)#把文本映射到数字上。 self.word2id = word2id
self.n_V, self.n_d = len(word2id), n_d#n_V应该是指词库大小,n_d指hidden state size
self.embedding = nn.Embedding(self.n_V, n_d)#赋予每个单词相应的词向量 def forward(self, x):
return self.embedding(x) def map_to_ids(self, text):#映射
return np.asarray([self.word2id[x] for x in text],
dtype='int64'
) class Model(nn.Module):
def __init__(self, words, args):
super(Model, self).__init__()
self.args = args
self.n_d = args.d
self.depth = args.depth
self.drop = nn.Dropout(args.dropout)#防止过拟合的层,变分dropout
self.embedding_layer = EmbeddingLayer(self.n_d, words)
self.n_V = self.embedding_layer.n_V
if args.lstm:
self.rnn = nn.LSTM(self.n_d, self.n_d,#self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns
input_size=INPUT_SIZE,
hidden_size=64, # rnn hidden unit
num_layers=1, # number of rnn layer
batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)
) self.depth,
dropout = args.rnn_dropout
)
else:
self.rnn = MF.SRU(self.n_d, self.n_d, self.depth,
dropout = args.rnn_dropout,
rnn_dropout = args.rnn_dropout,
use_tanh = 0
)
self.output_layer = nn.Linear(self.n_d, self.n_V)
# tie weights
self.output_layer.weight = self.embedding_layer.embedding.weight
self.init_weights()
if not args.lstm:
self.rnn.set_bias(args.bias) def init_weights(self):#initial c
val_range = (3.0/self.n_d)**0.5
for p in self.parameters():
if p.dim() > 1: # matrix
p.data.uniform_(-val_range, val_range)
else:
p.data.zero_() def forward(self, x, hidden):
emb = self.drop(self.embedding_layer(x))
output, hidden = self.rnn(emb, hidden)#rnn的输入和输出都有两个,即输入和上一层的隐层的值
output = self.drop(output)
output = output.view(-1, output.size(2))#改变tensor的size,size(2)表示计算第三维的大小,如size 4x6x7,则.size(3)就等于7
output = self.output_layer(output)
return output, hidden def init_hidden(self, batch_size):
weight = next(self.parameters()).data
zeros = Variable(weight.new(self.depth, batch_size, self.n_d).zero_())
if self.args.lstm:
return (zeros, zeros)
else:
return zeros def print_pnorm(self):#输出范数,范数常常被用来度量某个向量空间(或矩阵)中的每个向量的长度或大小。正则化中就是用范数
norms = [ "{:.0f}".format(x.norm().data[0]) for x in self.parameters() ]
sys.stdout.write("\tp_norm: {}\n".format(
norms
)) def train_model(epoch, model, train):
model.train()
args = model.args unroll_size = args.unroll_size
batch_size = args.batch_size
N = (len(train[0])-1)//unroll_size + 1
lr = args.lr total_loss = 0.0
criterion = nn.CrossEntropyLoss(size_average=False)#每个小批次的损失将被相加。
hidden = model.init_hidden(batch_size)
for i in range(N):
x = train[0][i*unroll_size:(i+1)*unroll_size]
y = train[1][i*unroll_size:(i+1)*unroll_size].view(-1)#view(-1)是指按列展开
x, y = Variable(x), Variable(y)
hidden = (Variable(hidden[0].data), Variable(hidden[1].data)) if args.lstm \
else Variable(hidden.data) model.zero_grad()
output, hidden = model(x, hidden)
assert x.size(1) == batch_size
loss = criterion(output, y) / x.size(1)#.size(1)计算列数.size(0)计算行数,must be (1. nn output, 2. target), the target label is NOT one-hotted
loss.backward() torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)#nn.utils.clip_grad_norm()对网络进行梯度裁剪,因为RNN中容易出现梯度爆炸的问题。
for p in model.parameters():
if p.requires_grad:
if args.weight_decay > 0:
p.data.mul_(1.0-args.weight_decay)
p.data.add_(-lr, p.grad.data)
if math.isnan(loss.data[0]) or math.isinf(loss.data[0]):#如果发生梯度消失或梯度爆炸则退出程序
sys.exit(0) #math.isinf(x):如果x = ±inf(inf:infinity ,译为无穷)也就是±∞返回True
return #math.isnan(x):如果x = Non (not a number) 返回True; total_loss += loss.data[0] / x.size(0)
if i%10 == 0:
sys.stdout.write("\r{}".format(i))
sys.stdout.flush() return np.exp(total_loss/N) def eval_model(model, valid):
model.eval()
args = model.args
total_loss = 0.0
unroll_size = model.args.unroll_size
criterion = nn.CrossEntropyLoss(size_average=False)
hidden = model.init_hidden(1)
N = (len(valid[0])-1)//unroll_size + 1
for i in range(N):
x = valid[0][i*unroll_size:(i+1)*unroll_size]
y = valid[1][i*unroll_size:(i+1)*unroll_size].view(-1)
x, y = Variable(x, volatile=True), Variable(y)
hidden = (Variable(hidden[0].data), Variable(hidden[1].data)) if args.lstm \
else Variable(hidden.data)
output, hidden = model(x, hidden)
loss = criterion(output, y)
total_loss += loss.data[0]
avg_loss = total_loss / valid[1].numel()#numel()返回张量所含元素的个数
ppl = np.exp(avg_loss)
return ppl def main(args):
train = read_corpus(args.train)
dev = read_corpus(args.dev)
test = read_corpus(args.test) model = Model(train, args)
model.cuda()
sys.stdout.write("vocab size: {}\n".format(
model.embedding_layer.n_V
))
sys.stdout.write("num of parameters: {}\n".format(
sum(x.numel() for x in model.parameters() if x.requires_grad)
))
model.print_pnorm()
sys.stdout.write("\n") map_to_ids = model.embedding_layer.map_to_ids
train = create_batches(train, map_to_ids, args.batch_size)
dev = create_batches(dev, map_to_ids, 1)
test = create_batches(test, map_to_ids, 1) unchanged = 0
best_dev = 1e+8
for epoch in range(args.max_epoch):
start_time = time.time()#返回当前时间的时间戳(1970纪元后经过的浮点秒数)。
if args.lr_decay_epoch>0 and epoch>=args.lr_decay_epoch:
args.lr *= args.lr_decay
train_ppl = train_model(epoch, model, train)
dev_ppl = eval_model(model, dev)
sys.stdout.write("\rEpoch={} lr={:.4f} train_ppl={:.2f} dev_ppl={:.2f}"
"\t[{:.2f}m]\n".format(
epoch,
args.lr,
train_ppl,
dev_ppl,
(time.time()-start_time)/60.0
))
model.print_pnorm()
sys.stdout.flush() if dev_ppl < best_dev:
unchanged = 0
best_dev = dev_ppl
start_time = time.time()
test_ppl = eval_model(model, test)
sys.stdout.write("\t[eval] test_ppl={:.2f}\t[{:.2f}m]\n".format(
test_ppl,
(time.time()-start_time)/60.0
))
sys.stdout.flush()
else:
unchanged += 1
if unchanged >= 30: break
sys.stdout.write("\n") if __name__ == "__main__":
argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler='resolve')
argparser.add_argument("--lstm", action="store_true")
argparser.add_argument("--train", type=str, required=True, help="training file")
argparser.add_argument("--dev", type=str, required=True, help="dev file")
argparser.add_argument("--test", type=str, required=True, help="test file")
argparser.add_argument("--batch_size", "--batch", type=int, default=32)
argparser.add_argument("--unroll_size", type=int, default=35)
argparser.add_argument(" ", type=int, default=300)
argparser.add_argument("--d", type=int, default=910)
argparser.add_argument("--dropout", type=float, default=0.7,
help="dropout of word embeddings and softmax output"
)
argparser.add_argument("--rnn_dropout", type=float, default=0.2,
help="dropout of RNN layers"
)
argparser.add_argument("--bias", type=float, default=-3,
help="intial bias of highway gates",
)
argparser.add_argument("--depth", type=int, default=6)
argparser.add_argument("--lr", type=float, default=1.0)
argparser.add_argument("--lr_decay", type=float, default=0.98)
argparser.add_argument("--lr_decay_epoch", type=int, default=175)
argparser.add_argument("--weight_decay", type=float, default=1e-5)
argparser.add_argument("--clip_grad", type=float, default=5) args = argparser.parse_args()
print (args)

sru源码--language model的更多相关文章

  1. 关于sru源码class Model的parameters

    class Model(nn.Module): def __init__(self, words, args): super(Model, self).__init__() self.args = a ...

  2. yii2 源码分析 model类分析 (五)

    模型类是数据模型的基类.此类继承了组件类,实现了3个接口 先介绍一下模型类前面的大量注释说了什么: * 模型类是数据模型的基类.此类继承了组件类,实现了3个接口 * 实现了IteratorAggreg ...

  3. django源码分析---- Model类型&Field类型

    djiango在数据库这方式自己实现了orm(object relationship mapping 对象关系模型映射).这个主要是用到python 元类这一 项python中的高级技术来实现的. c ...

  4. [原创]在Windows和Linux中搭建PostgreSQL源码调试环境

    张文升http://ode.cnblogs.comEmail:wensheng.zhang#foxmail.com 配图太多,完整pdf下载请点这里 本文使用Xming.Putty和VMWare几款工 ...

  5. Mybatis Generator的model生成中文注释,支持oracle和mysql(通过修改源码的方式来实现)

    在看本篇之前,最好先看一下上一篇通过实现CommentGenerator接口的方法来实现中文注释的例子,因为很多操作和上一篇基本是一致的,所以本篇可能不那么详细. 首先说一下上篇通过实现Comment ...

  6. Backbone源码解析(二):Model(模型)模块

    Model(模型)模块在bk框架中的作用主要是存储处理数据,它对外和对内都有很多操作数据的接口和方法.它与视图(Views)模块精密联系着,通过set函数改变数据结构从而改变视图界面的变化.下面我们来 ...

  7. PureMVC(JS版)源码解析(十一):Model类

          这篇博文讲PureMVC三个核心类——Model类.Model类的构造函数及工厂函数[即getInstance()方法]和View类.Controller类是一样的,这里就不重复讲解了,只 ...

  8. ZRender源码分析2:Storage(Model层)

    回顾 上一篇请移步:zrender源码分析1:总体结构 本篇进行ZRender的MVC结构中的M进行分析 总体理解 上篇说到,Storage负责MVC层中的Model,也就是模型,对于zrender来 ...

  9. BIZ中model.getSql源码分析

    功能:根据model.xml文件中配置的sql,获取对应的动态sql结果. 实例代码:String sql1 = model.getSql(dao.dbMeta());String sql2 = mo ...

随机推荐

  1. java.io.FileNotFoundException: generatorConfig.xml (系统找不到指定的文件。)

    在使用MyBatis的逆向工程生成代码时,一直报错java.io.FileNotFoundException: generatorConfig.xml (系统找不到指定的文件.),如图 文件结构如下: ...

  2. PPT高手博客

    让PPT设计NEW一NEW——Lonely Fish http://lonelyfish1920.blog.163.com/ http://blog.sina.com.cn/s/blog_698717 ...

  3. Java并发编程中线程池源码分析及使用

    当Java处理高并发的时候,线程数量特别的多的时候,而且每个线程都是执行很短的时间就结束了,频繁创建线程和销毁线程需要占用很多系统的资源和时间,会降低系统的工作效率. 参考http://www.cnb ...

  4. 洛谷 P1357 花园 解题报告

    P1357 花园 题目描述 小\(L\)有一座环形花园,沿花园的顺时针方向,他把各个花圃编号为\(1~N(2<=N<=10^{15})\).他的环形花园每天都会换一个新花样,但他的花园都不 ...

  5. 21天实战caffe笔记_第一天

    1 深度学习术语 深度学习常用名词:有监督学习.无监督学习.训练数据集.测试数据集.过拟合.泛化.惩罚值(损失loss); 机器自动学习所需三份数据:训练集(机器学习的样例),验证集(机器学习阶段,用 ...

  6. C中有关引用和指针的异同

    参考于https://blog.csdn.net/wtzdedaima/article/details/78377201 C语言也学了蛮久的,其实一直都没有用到过或者碰到过引用的例子.前端时间再全面复 ...

  7. 解决Maven提示:Could not read settings.xml

    在Eclipse中配置maven时,提示错误:Could not read settings.xml.用户配置无法生效. 1.首先检查自己的settings.xml配置文件,发现在<!----& ...

  8. Ansible1: 简介与基本安装

    目录 Ansible特性 Ansible的基本组件 Ansible工作机制 Ansible的安装 Ansible是一个综合的强大的管理工具,他可以对多台主机安装操作系统,并为这些主机安装不同的应用程序 ...

  9. P3275 [SCOI2011]糖果 && 差分约束(二)

    学习完了差分约束是否有解, 现在我们学习求解最大解和最小解 首先我们回想一下是否有解的求解过程, 不难发现最后跑出来任意两点的最短路关系即为这两元素的最短路关系. 即: 最后的最短路蕴含了所有元素之间 ...

  10. 自动化工具制作PASCAL VOC 数据集

    自动化工具制作PASCAL VOC 数据集   1. VOC的格式 VOC主要有三个重要的文件夹:Annotations.ImageSets和JPEGImages JPEGImages 文件夹 该文件 ...