参考代码地址:https://github.com/pytorch/examples/tree/master/word_language_model

/word_language_model/data.py

这个data加载文件写的很简洁,值得学习

参考代码地址:https://github.com/pytorch/examples/tree/master/word_language_model

/word_language_model/data.py

这个data加载文件写的很简洁,值得学习

import os
import torch

class Dictionary(object):
#维护一个字典,存储着语料word2idx
def __init__(self):
self.word2idx = {}
self.idx2word = []

def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
return self.word2idx[word]

def __len__(self):
return len(self.idx2word)

class Corpus(object):
#维护着语料,这里比较tricky,在建立字典的过程中,同时也产出了文本到索引的文件。
#tokenize这个函数先把现在文件的所有term加到字典,然后根据字典产出一个索引文件。

def __init__(self, path):
self.dictionary = Dictionary()
self.train = self.tokenize(os.path.join(path, 'train.txt'))
self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
self.test = self.tokenize(os.path.join(path, 'test.txt'))

def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path)
# Add words to the dictionary
with open(path, 'r') as f:
tokens = 0
for line in f:
words = line.split() + ['<eos>']
tokens += len(words)
for word in words:
self.dictionary.add_word(word)

# Tokenize file content
with open(path, 'r') as f:
ids = torch.LongTensor(tokens)
token = 0
for line in f:
words = line.split() + ['<eos>']
for word in words:
ids[token] = self.dictionary.word2idx[word]
token += 1

return ids
  

/word_language_model/main.py

###############################################################################
# Load data
###############################################################################

corpus = data.Corpus(args.data)
#这个corpus.train是一个一维的向量
def batchify(data, bsz):
# Work out how cleanly we can divide the dataset into bsz parts.
nbatch = data.size(0) // bsz
#bsz就是batch size
# Trim off any extra elements that wouldn't cleanly fit (remainders).
data = data.narrow(0, 0, nbatch * bsz)

# Evenly divide the data across the bsz batches.
# 先把一维的向量变成batch size的整数倍,然后先把1*N的矩阵转化成 bsz * M的矩阵,然后通过t()将第一个维度和第二个维度替换
# 得到的最后的data向量是M*bsz的向量,就是每一行是一个batch,每一列就是一段文本,类似我们古代的书,但是是从左到右的排列

——————————————————————————

数据处理流程是这样的。原来一段长文本,如 a pen is this day, that is a ......

然后进行切割,一共分成20行,就变成了每一行是个长文本,一共20行,这时候每列是一个batch。进行transpose。就是M*bsz的格式,这个时候每一行是一个batch。格式类似:

a   when ...

pen    we..

is       look

this    the

这个格式的,进入网络的话是一个sequence序列,假如是35的话,就是从上倒下35行,取出来送入网络。取第个序列的话,直接从36行开始。。一个batch是20大小,20表示序列的数目,每个序列是35个字,所以是一个35*20的向量

——————————————————————————
data = data.view(bsz, -1).t().contiguous()
if args.cuda:
data = data.cuda()
return data

eval_batch_size = 10
train_data = batchify(corpus.train, args.batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied)
if args.cuda:
model.cuda()

criterion = nn.CrossEntropyLoss()

###############################################################################
# Training code
###############################################################################

def repackage_hidden(h):
"""Wraps hidden states in new Variables, to detach them from their history."""
if type(h) == Variable:
return Variable(h.data)
else:
return tuple(repackage_hidden(v) for v in h)

def get_batch(source, i, evaluation=False):
seq_len = min(args.bptt, len(source) - 1 - i)

#这里的bptt是sequence的长度,i是序列的开始,i是根据sequence的长度进行递增
data = Variable(source[i:i+seq_len], volatile=evaluation)
target = Variable(source[i+1:i+1+seq_len].view(-1))

#返回的data是bptt*bsz,即sequence长度*batch size

#返回的target的是一维的向量,大小和data是一样的,不过是data平移一位后的结果
return data, target

def evaluate(data_source):
# Turn on evaluation mode which disables dropout.
model.eval()
total_loss = 0
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(eval_batch_size)
for i in range(0, data_source.size(0) - 1, args.bptt):
data, targets = get_batch(data_source, i, evaluation=True)
output, hidden = model(data, hidden)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * criterion(output_flat, targets).data
hidden = repackage_hidden(hidden)
return total_loss[0] / len(data_source)

def train():
# Turn on training mode which enables dropout.
model.train()
total_loss = 0
start_time = time.time()
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(args.batch_size)
for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
data, targets = get_batch(train_data, i)
# Starting each batch, we detach the hidden state from how it was previously produced.
# If we didn't, the model would try backpropagating all the way to start of the dataset.
hidden = repackage_hidden(hidden)
model.zero_grad()
output, hidden = model(data, hidden)
loss = criterion(output.view(-1, ntokens), targets)
loss.backward()

# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
for p in model.parameters():
p.data.add_(-lr, p.grad.data)

total_loss += loss.data

if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss[0] / args.log_interval
elapsed = time.time() - start_time
print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
'loss {:5.2f} | ppl {:8.2f}'.format(
epoch, batch, len(train_data) // args.bptt, lr,
elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
total_loss = 0
start_time = time.time()

# Loop over epochs.
lr = args.lr
best_val_loss = None

# At any point you can hit Ctrl + C to break out of training early.
try:
for epoch in range(1, args.epochs+1):
epoch_start_time = time.time()
train()
val_loss = evaluate(val_data)
print('-' * 89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
val_loss, math.exp(val_loss)))
print('-' * 89)
# Save the model if the validation loss is the best we've seen so far.
if not best_val_loss or val_loss < best_val_loss:
with open(args.save, 'wb') as f:
torch.save(model, f)
best_val_loss = val_loss
else:
# Anneal the learning rate if no improvement has been seen in the validation dataset.
lr /= 4.0
except KeyboardInterrupt:
print('-' * 89)
print('Exiting from training early')

# Load the best saved model.
with open(args.save, 'rb') as f:
model = torch.load(f)

# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)

model

import torch.nn as nn
from torch.autograd import Variable

class RNNModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""

def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
super(RNNModel, self).__init__()
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, ninp)
if rnn_type in ['LSTM', 'GRU']:
self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
else:
try:
nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
except KeyError:
raise ValueError( """An invalid option for `--model` was supplied,
options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
self.decoder = nn.Linear(nhid, ntoken)

# Optionally tie weights as in:
# "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
# https://arxiv.org/abs/1608.05859
# and
# "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
# https://arxiv.org/abs/1611.01462
if tie_weights:
if nhid != ninp:
raise ValueError('When using the tied flag, nhid must be equal to emsize')
self.decoder.weight = self.encoder.weight

self.init_weights()

self.rnn_type = rnn_type
self.nhid = nhid
self.nlayers = nlayers

def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.fill_(0)
self.decoder.weight.data.uniform_(-initrange, initrange)

def forward(self, input, hidden):
emb = self.drop(self.encoder(input))
output, hidden = self.rnn(emb, hidden)
output = self.drop(output)
decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

def init_hidden(self, bsz):

#就是这个函数一直不太明白,为什么取出第一个参数,然后重新构造,难道第一个参数不学习了?
weight = next(self.parameters()).data
if self.rnn_type == 'LSTM':
return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
else:
return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())

												

pytoch word_language_model 代码阅读的更多相关文章

  1. 代码阅读分析工具Understand 2.0试用

    Understand 2.0是一款源代码阅读分析软件,功能强大.试用过一段时间后,感觉相当不错,确实可以大大提高代码阅读效率.由于Understand功能十分强大,本文不可能详尽地介绍它的所有功能,所 ...

  2. Android 上的代码阅读器 CoderBrowserHD 修改支持 go 语言代码

    我在Android上的代码阅读器用的是 https://github.com/zerob13/CoderBrowserHD 改造的版本,改造后的版本我放在 https://github.com/ghj ...

  3. Linux协议栈代码阅读笔记(二)网络接口的配置

    Linux协议栈代码阅读笔记(二)网络接口的配置 (基于linux-2.6.11) (一)用户态通过C库函数ioctl进行网络接口的配置 例如,知名的ifconfig程序,就是通过C库函数sys_io ...

  4. [置顶] Linux协议栈代码阅读笔记(一)

    Linux协议栈代码阅读笔记(一) (基于linux-2.6.21.7) (一)用户态通过诸如下面的C库函数访问协议栈服务 int socket(int domain, int type, int p ...

  5. 图形化代码阅读工具——Scitools Understand

    Scitools出品的Understand 2.0.用了很多年了,比Source Insight强大很多.以前的名字叫Understand for C/C++,Understand for Java, ...

  6. Python - 关于代码阅读的一些建议

    初始能力 让阅读思路保持清晰连贯,主力关注在流程架构和逻辑实现上,不被语法.技巧和业务流程等频繁地阻碍和打断. 建议基本满足以下条件,再开始进行代码阅读: 具备一定的语言基础:熟悉基础语法,常用的函数 ...

  7. MediaInfo代码阅读

      MediaInfo是一个用来分析媒体文件的开源工具. 支持的文件非常全面,基本上支持所有的媒体文件. 最近是在做HEVC开发,所以比较关注MediaInfo中关于HEVC的分析与处理. 从Meid ...

  8. Tools - 一些代码阅读的方法

    1 初始能力 让阅读思路清晰连贯,保持在程序的流程架构和逻辑实现上,不被语法.编程技巧和业务流程等频繁地阻碍和打断. 语言基础:熟悉基础语法,常用的函数.库.编程技巧等: 了解设计模式.构建工具.代码 ...

  9. Bleve代码阅读(二)——Index Mapping

    引言 Bleve是Golang实现的一个全文检索库,类似Lucene之于Java.在这里通过阅读其代码,来学习如何使用及定制检索功能.也是为了通过阅读代码,学习在具体环境下Golang的一些使用方式. ...

随机推荐

  1. await异步的,容易理解一点

    C# 5.0中引入了async 和 await.这两个关键字可以让你更方便的写出异步代码. 看个例子: public class MyClass { public MyClass() { Displa ...

  2. DOIS 2019 DevOps国际峰会北京站来袭~

    DevOps 国际峰会是国内唯一的国际性 DevOps 技术峰会,由 OSCAR 联盟指导.DevOps 时代社区与高效运维社区联合主办,共邀全球80余名顶级专家畅谈 DevOps 体系与方法.过程与 ...

  3. cocos2dx JS 图片精灵添加纹理缓存

    添加精灵图片缓存 : cc.spriteFrameCache.addSpriteFrames("res/pic.plist"); 从缓存中获取 : var frame = cc.s ...

  4. Axure RP 8过期,用户名和序列号(注册码)

    用户名:axureuser 序列号:8wFfIX7a8hHq6yAy6T8zCz5R0NBKeVxo9IKu+kgKh79FL6IyPD6lK7G6+tqEV4LG 用户名:aaa注册码:2GQrt5 ...

  5. zw·10倍速大数据与全内存计算

    zw·10倍速大数据与全内存计算 zw全内存10倍速计算blog,早就在博客园机器视觉栏目发过,大数据版的一直挂着,今天抽空补上. 在<零起点,python大数据与量化交易>目录中 htt ...

  6. c# 文件与流

    1.创建和删除目录 在c#中涉及到输入.输出(i/o)相关操作的API都被放在System.IO命名空间下,或者子命令System.IO.IsolatedStoorage中.对目录进行操作可以使用Di ...

  7. Android开发中使用Intent跳转到系统应用中的拨号界面、联系人界面、短信界面

    现在开发中的功能需要直接跳转到拨号.联系人.短信界面等等,查找了很多资料,自己整理了一下. 首先,我们先看拨号界面,代码如下: Intent intent =new Intent(); intent. ...

  8. mock基本使用

    **一.mock解决的问题** 开发时,后端还没完成数据输出,前端只好写静态模拟数据.数据太长了,将数据写在js文件里,完成后挨个改url.某些逻辑复杂的代码,加入或去除模拟数据时得小心翼翼.想要尽可 ...

  9. php window系统 xdebug+phpstorm 本地断点调试使用教程

    运行环境: phpStorm 2017.2 PHP 7.1.5 Xdebug 2.6.1 php.ini添加xdebug模块 你需要仔细分析和选择要下载的对应版本,否则无法调试.由于非常容易出错,建议 ...

  10. Linux tshark抓包

    使用tshark进行抓包 注:需要安装wireshar抓包工具 安装:yum -y install wireshark # 可以抓的包 命令:tshark # 抓取mysql查询 命令:tshark ...