Reference code: https://github.com/pytorch/examples/tree/master/word_language_model

/word_language_model/data.py

This data-loading file is written very concisely and is worth studying.

import os
import torch

class Dictionary(object):
    # Maintains the vocabulary: word2idx maps each word in the corpus to an index.
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)

class Corpus(object):
    # Maintains the corpus. The tricky part: while building the dictionary it
    # also produces the text-to-index tensors. tokenize() first adds every term
    # in the file to the dictionary, then emits an index tensor for the file
    # based on that dictionary.

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        return ids
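
To make the id assignment concrete, here is a minimal sketch using the Dictionary class above (the sentences are made up for illustration):

d = Dictionary()
for line in ['a pen is', 'a pen']:
    # same per-line processing as tokenize(): split, then append '<eos>'
    for word in line.split() + ['<eos>']:
        d.add_word(word)

print(d.word2idx)          # {'a': 0, 'pen': 1, 'is': 2, '<eos>': 3}
print(len(d))              # 4
print(d.word2idx['pen'])   # 1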
  

/word_language_model/main.py

###############################################################################
# Load data
###############################################################################

corpus = data.Corpus(args.data)
# corpus.train is a one-dimensional tensor of word indices.

def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    # bsz is the batch size.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)

    # Evenly divide the data across the bsz batches.
    # After trimming the 1-D tensor to a multiple of bsz, reshape it into a
    # bsz * M matrix, then swap the two dimensions with t(). The final tensor
    # is M * bsz: each row is one time step of the batch, and each column is
    # one contiguous stretch of text -- like the columns of a traditional
    # Chinese book, except read left to right.
    #
    # The data layout works like this. Start from one long text, e.g.
    # "a pen is this day, that is a ...". Cut it into bsz = 20 rows, so each
    # row is one long stretch of text; at that point each *column* is one
    # batch position. Then transpose into the M * bsz layout, where each
    # *row* is one time step, like:
    #
    #     a     when ...
    #     pen   we ...
    #     is    look
    #     this  the
    #
    # One input sequence for the network is a vertical slice of this matrix:
    # with a sequence length of 35, take 35 rows from top to bottom and feed
    # them in; the next sequence starts at row 36. With a batch size of 20,
    # there are 20 sequences side by side, each 35 tokens long, so one input
    # is a 35 * 20 tensor. (See the sketch after the batchify calls below.)
    data = data.view(bsz, -1).t().contiguous()
    if args.cuda:
        data = data.cuda()
    return data

eval_batch_size = 10
train_data = batchify(corpus.train, args.batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)
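
To make the reshaping concrete, a toy sketch of what batchify produces (the 26-token "corpus" is made up; the args.cuda handling is omitted):

import torch

ids = torch.arange(26)                       # pretend corpus of 26 token ids
bsz = 4
nbatch = ids.size(0) // bsz                  # 6 full rows per column
trimmed = ids.narrow(0, 0, nbatch * bsz)     # drop the last 2 tokens
batched = trimmed.view(bsz, -1).t().contiguous()
print(batched.size())                        # torch.Size([6, 4])
print(batched[:, 0])                         # tensor([0, 1, 2, 3, 4, 5]):
                                             # each column is a contiguous stretch of text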

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied)
if args.cuda:
model.cuda()

criterion = nn.CrossEntropyLoss()
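
As a reminder of the shapes nn.CrossEntropyLoss expects — this is why the training code later flattens the output with output.view(-1, ntokens) — a small sketch with made-up sizes (the .item() call assumes a recent PyTorch where the loss is a 0-dim tensor):

import torch
import torch.nn as nn

# CrossEntropyLoss combines LogSoftmax and NLLLoss: it takes raw logits of
# shape (N, C) and integer class targets of shape (N,). Sizes are illustrative.
criterion = nn.CrossEntropyLoss()
logits = torch.randn(35 * 20, 10000)           # (seq_len * bsz, ntokens)
targets = torch.randint(0, 10000, (35 * 20,))  # (seq_len * bsz,)
loss = criterion(logits, targets)
print(loss.item())                             # roughly ln(10000) ≈ 9.2 for random logits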

###############################################################################
# Training code
###############################################################################

def repackage_hidden(h):
    """Wraps hidden states in new Variables, to detach them from their history."""
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackage_hidden(v) for v in h)
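
On PyTorch 0.4 and later, where Variable was merged into Tensor, the same trick is usually written with detach(); a sketch (assuming torch is imported), not the example's own code:

def repackage_hidden_modern(h):
    if isinstance(h, torch.Tensor):
        # detach() shares storage but cuts the autograd history, so
        # backpropagation stops at the batch boundary.
        return h.detach()
    return tuple(repackage_hidden_modern(v) for v in h)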

def get_batch(source, i, evaluation=False):
    # args.bptt is the sequence length; i is where the sequence starts and is
    # advanced in steps of bptt by the caller.
    seq_len = min(args.bptt, len(source) - 1 - i)
    # data has shape bptt * bsz, i.e. sequence length * batch size.
    data = Variable(source[i:i+seq_len], volatile=evaluation)
    # target is a flat 1-D vector with the same number of elements as data,
    # shifted one position ahead: the label for each token is the next token.
    target = Variable(source[i+1:i+1+seq_len].view(-1))
    return data, target
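
A shape-only sketch of what get_batch returns, reusing the toy 6x4 tensor from the batchify sketch (rebuilt here so the snippet stands alone; bptt is fixed at 3 purely for illustration):

import torch

batched = torch.arange(24).view(4, -1).t().contiguous()  # (6, 4), as from batchify
bptt, i = 3, 0
seq_len = min(bptt, len(batched) - 1 - i)
data = batched[i:i + seq_len]                      # (3, 4): 3 time steps, 4 sequences
target = batched[i + 1:i + 1 + seq_len].view(-1)   # (12,): data shifted one step ahead
print(data.size(), target.size())
print(data[0, 0], '->', target[0])                 # token 0 is followed by token 1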

def evaluate(data_source):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(eval_batch_size)
    for i in range(0, data_source.size(0) - 1, args.bptt):
        data, targets = get_batch(data_source, i, evaluation=True)
        output, hidden = model(data, hidden)
        output_flat = output.view(-1, ntokens)
        # Weight each chunk's loss by its length: the last chunk can be
        # shorter than bptt, so this keeps the per-token average exact.
        total_loss += len(data) * criterion(output_flat, targets).data
        hidden = repackage_hidden(hidden)
    return total_loss[0] / len(data_source)

def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        # Manual SGD step: p <- p - lr * grad, without an optimizer object.
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.data

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, lr,
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
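
The parameter loop above is plain SGD written out by hand. On a recent PyTorch the same step would normally go through an optimizer object; a self-contained sketch on a toy model (names here are illustrative, not the example's code):

import torch
import torch.nn as nn
import torch.optim as optim

toy = nn.Linear(4, 2)
opt = optim.SGD(toy.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = nn.CrossEntropyLoss()(toy(x), y)
opt.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(toy.parameters(), 0.25)  # clip_grad_norm was renamed with a trailing _
opt.step()   # p <- p - lr * grad, same effect as the manual loop above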

# Loop over epochs.
lr = args.lr
best_val_loss = None

# At any point you can hit Ctrl + C to break out of training early.
try:
    for epoch in range(1, args.epochs+1):
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(val_data)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                         val_loss, math.exp(val_loss)))
        print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            with open(args.save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            lr /= 4.0
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

# Load the best saved model.
with open(args.save, 'rb') as f:
    model = torch.load(f)

# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

/word_language_model/model.py

import torch.nn as nn
from torch.autograd import Variable

class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            except KeyError:
                raise ValueError("""An invalid option for `--model` was supplied,
                                 options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # Flatten (seq_len, bsz, nhid) to (seq_len * bsz, nhid) for the linear
        # decoder, then restore the (seq_len, bsz, ntoken) shape.
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        # I never quite understood this function at first: why grab the first
        # parameter and build new tensors from it -- does that parameter stop
        # learning? It doesn't: weight.new(...) only borrows the parameter's
        # dtype and device to allocate a fresh zero tensor of the right shape;
        # the parameter itself is untouched.
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
                    Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
        else:
            return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
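
To convince yourself that grabbing the first parameter is harmless, a small sketch (a hypothetical check; the LSTM sizes are made up):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=8, num_layers=2)
weight = next(lstm.parameters()).data
h0 = weight.new(2, 4, 8).zero_()            # same dtype/device as weight, all zeros
print(h0.shape, h0.dtype, weight.shape)     # h0 is (2, 4, 8); weight keeps its own shape
print(weight.data_ptr() != h0.data_ptr())   # True: separate storage, weight untouched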

