PyTorch Study Notes: Implementing an n-gram Model
import torch
import torch.nn as nn
from torch.autograd import Variable  # legacy wrapper; since PyTorch 0.4 plain tensors suffice
import torch.nn.functional as F
import torch.optim as optim

CONTEXT_SIZE = 2  # number of preceding words used as context (the window size)
EMBEDDING_DIM = 10  # defined but not used below: the model is instantiated with n_dim=100
test_sentence = "When forty winters shall besiege thy brow,And dig deep trenches in thy beauty's field,Thy youth's proud livery so gazed on now,Will be a totter'd weed of small worth held:Then being asked, where all thy beauty lies,Where all the treasure of thy lusty days;To say, within thine own deep sunken eyes,Were an all-eating shame, and thriftless praise.How much more praise deserv'd thy beauty's use,If thou couldst answer 'This fair child of mineShall sum my count, and make my old excuse,'Proving his beauty by succession thine!This were to be new made when thou art old,And see thy blood warm when thou feel'st it cold.".split()

vocb = set(test_sentence)  # drop repeated words
word2id = {word: i for i, word in enumerate(vocb)}
id2word = {word2id[word]: word for word in word2id}
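One detail of the vocabulary above: string hashing is randomised per interpreter process, so the iteration order of a set of strings can differ from run to run, and `word2id` may assign different ids each time the script is executed. That only matters if the ids are saved and reused later, for example alongside a saved model. A minimal sketch of a reproducible variant, assuming that sorting the words is acceptable; `build_vocab` and the short sample text are hypothetical names introduced here for illustration:

```python
def build_vocab(tokens):
    """Map each distinct token to an id, in sorted order so the mapping is reproducible."""
    words = sorted(set(tokens))
    word2id = {word: i for i, word in enumerate(words)}
    id2word = {i: word for word, i in word2id.items()}
    return word2id, id2word


# Short made-up token list, not the sonnet used in the notes.
sample_tokens = "thy beauty lies where all thy beauty lies".split()
word2id_sorted, id2word_sorted = build_vocab(sample_tokens)
print(word2id_sorted)
```

Because the words are sorted before ids are assigned, the same text always yields the same mapping, and `id2word_sorted[word2id_sorted[w]]` returns `w` for every token.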
class NgramModel(nn.Module):
    """Feed-forward n-gram language model: predicts the next word from the preceding context words."""

    def __init__(self, vocb_size, context_size, n_dim):
        super().__init__()
        self.n_word = vocb_size
        self.embedding = nn.Embedding(self.n_word, n_dim)
        self.linear1 = nn.Linear(context_size * n_dim, 128)
        self.linear2 = nn.Linear(128, self.n_word)

    def forward(self, x):
        # Step 1: look up the embedding of each context word,
        # e.g. two word ids give a tensor of shape (2, 100) when n_dim is 100.
        emb = self.embedding(x)
        # Step 2: flatten the word embeddings into a single row of shape (1, 200).
        emb = emb.view(1, -1)
        # Step 3: linear layer, ReLU, then a second linear layer.
        out = self.linear1(emb)
        out = F.relu(out)
        out = self.linear2(out)
        # The output has one score per vocabulary word, so predicting the next
        # word can be viewed as a classification problem. log_softmax converts
        # the scores to log-probabilities, which is the input NLLLoss expects.
        log_prob = F.log_softmax(out, dim=1)
        return log_prob


ngrammodel = NgramModel(len(word2id), CONTEXT_SIZE, 100)
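To make the shape comments in `forward` concrete, the following sketch traces a single two-word context through the same sequence of operations using stand-alone layers. The vocabulary size of 50 and the ids 3 and 7 are arbitrary values chosen for illustration, and the layers are freshly initialised, so the probabilities themselves are meaningless; only the shapes matter here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes only; the vocabulary size is not taken from the sonnet.
vocab_size, context_size, n_dim = 50, 2, 100

embedding = nn.Embedding(vocab_size, n_dim)
linear1 = nn.Linear(context_size * n_dim, 128)
linear2 = nn.Linear(128, vocab_size)

context_ids = torch.LongTensor([3, 7])  # two arbitrary word ids

with torch.no_grad():  # shape check only, no gradients needed
    emb = embedding(context_ids)                       # (2, 100): one row per context word
    flat = emb.view(1, -1)                             # (1, 200): rows concatenated
    hidden = F.relu(linear1(flat))                     # (1, 128)
    log_prob = F.log_softmax(linear2(hidden), dim=1)   # (1, 50): one log-probability per word

print(emb.shape, flat.shape, hidden.shape, log_prob.shape)
# Exponentiating the log-probabilities gives a distribution that sums to 1.
print(log_prob.exp().sum().item())
```

Passing `dim=1` normalises across the vocabulary axis of the `(1, vocab_size)` output, which is the axis that `nn.NLLLoss` indexes with the target id.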
criterion = nn.NLLLoss()
optimizer = optim.SGD(ngrammodel.parameters(), lr=1e-3)

# Each sample pairs two consecutive words with the word that follows them.
trigram = [((test_sentence[i], test_sentence[i + 1]), test_sentence[i + 2])
           for i in range(len(test_sentence) - 2)]

for epoch in range(100):
    print('epoch: {}'.format(epoch + 1))
    print('*' * 10)
    running_loss = 0
    for data in trigram:
        # 'word' holds the two words preceding the target; 'label' is the word to predict
        word, label = data
        word = torch.LongTensor([word2id[e] for e in word])
        label = torch.LongTensor([word2id[label]])
        # forward
        out = ngrammodel(word)
        loss = criterion(out, label)
        running_loss += loss.item()
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # average loss per training sample
    print('loss: {:.6f}'.format(running_loss / len(trigram)))

# predict
word, label = trigram[3]
word = torch.LongTensor([word2id[i] for i in word])
out = ngrammodel(word)
_, predict_label = torch.max(out, 1)
predict_word = id2word[predict_label.item()]
print('real word is {}, predict word is {}'.format(label, predict_word))
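The model returns log-probabilities, and `nn.NLLLoss` picks out the negative log-probability of the target word, so the loss for a single sample is just `-log p(target)`. The arithmetic can be checked by hand without PyTorch. This is a minimal sketch with made-up scores for a four-word vocabulary; `log_softmax` and `nll` here are plain-Python helpers written for illustration, not the library functions:

```python
import math


def log_softmax(scores):
    """Return log(softmax(scores)), shifting by the max for numerical stability."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def nll(log_probs, target):
    """Negative log-likelihood of the target index."""
    return -log_probs[target]


scores = [2.0, 1.0, 0.1, -1.0]  # made-up scores for a 4-word vocabulary
log_probs = log_softmax(scores)
# argmax over the vocabulary, the same role torch.max(out, 1) plays above
predicted = max(range(len(log_probs)), key=lambda i: log_probs[i])

print('probabilities sum to', sum(math.exp(lp) for lp in log_probs))
print('predicted index:', predicted)
print('loss if target is 0:', nll(log_probs, 0))
print('loss if target is 3:', nll(log_probs, 3))
```

The loss is small when the target is the highest-scoring word and grows as the target's score falls behind; the gap between the two losses equals the gap between the two raw scores.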