PyTorch Study Notes: Implementing an n-gram Model
import torch
import torch.nn as nn
from torch.autograd import Variable  # a no-op wrapper since PyTorch 0.4; plain tensors work directly
import torch.nn.functional as F
import torch.optim as optim

CONTEXT_SIZE = 2  # number of preceding context words, i.e. the window size
EMBEDDING_DIM = 10
test_sentence = """When forty winters shall besiege thy brow,
And dig deep trenches in thy beauty's field,
Thy youth's proud livery so gazed on now,
Will be a totter'd weed of small worth held:
Then being asked, where all thy beauty lies,
Where all the treasure of thy lusty days;
To say, within thine own deep sunken eyes,
Were an all-eating shame, and thriftless praise.
How much more praise deserv'd thy beauty's use,
If thou couldst answer 'This fair child of mine
Shall sum my count, and make my old excuse,'
Proving his beauty by succession thine!
This were to be new made when thou art old,
And see thy blood warm when thou feel'st it cold.""".split()

vocb = sorted(set(test_sentence))  # remove repeated words; sorting keeps the ids the same on every run
word2id = {word: i for i, word in enumerate(vocb)}
id2word = {word2id[word]: word for word in word2id}
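To see what the vocabulary and the training pairs look like, here is a small standalone sketch on a toy sentence (the helper names `build_vocab` and `make_trigrams` are illustrative, not from the original). Python randomizes string hashing by default, so iterating over a `set` of words can give a different order on each run; sorting the vocabulary keeps the word ids stable. Zipping three shifted views of the token list produces the same `((w[i], w[i+1]), w[i+2])` pairs as the `trigram` list comprehension in the training code.

```python
def build_vocab(tokens):
    """Map each distinct token to an integer id; sorting makes the ids the same on every run."""
    vocab = sorted(set(tokens))
    word2id = {word: i for i, word in enumerate(vocab)}
    id2word = {i: word for word, i in word2id.items()}
    return word2id, id2word

def make_trigrams(tokens):
    """Build ((w[i], w[i+1]), w[i+2]) pairs from three shifted views of the token list."""
    return [((a, b), c) for a, b, c in zip(tokens, tokens[1:], tokens[2:])]

tokens = "to be or not to be".split()
word2id, id2word = build_vocab(tokens)
trigrams = make_trigrams(tokens)

print(word2id)        # {'be': 0, 'not': 1, 'or': 2, 'to': 3}
print(trigrams[0])    # (('to', 'be'), 'or')
print(len(trigrams))  # 4, i.e. len(tokens) - 2
```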
# define model
class NgramModel(nn.Module):
    def __init__(self, vocb_size, context_size, n_dim):
        super().__init__()
        self.n_word = vocb_size
        self.embedding = nn.Embedding(self.n_word, n_dim)
        self.linear1 = nn.Linear(context_size * n_dim, 128)
        self.linear2 = nn.Linear(128, self.n_word)

    def forward(self, x):
        # step 1: look up the embeddings of the context words,
        # e.g. two word ids with n_dim=100 give a tensor of shape (2, 100)
        emb = self.embedding(x)
        # step 2: flatten the word embeddings into a single row, e.g. (1, 200)
        emb = emb.view(1, -1)
        # step 3: linear layer, then ReLU, then a second linear layer
        out = self.linear1(emb)
        out = F.relu(out)
        out = self.linear2(out)
        # the output has one score per vocabulary word, so predicting the next word is a classification problem;
        # log_softmax turns the scores into log-probabilities, which is the input nn.NLLLoss expects
        log_prob = F.log_softmax(out, dim=1)
        return log_prob


ngrammodel = NgramModel(len(word2id), CONTEXT_SIZE, 100)
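The model returns log-probabilities, and `nn.NLLLoss` takes the negative log-probability of the correct word; together they compute the cross-entropy, which is why PyTorch's `nn.CrossEntropyLoss` applied to raw scores is equivalent to `log_softmax` followed by `NLLLoss`. The plain-Python sketch below (function names and scores are made up for illustration) checks that equivalence and shows that the argmax of the log-probabilities picks the same word as the argmax of the raw scores, so `torch.max` on the model output is a valid way to predict.

```python
import math

def log_softmax(scores):
    # subtract the max before exponentiating (log-sum-exp trick) to avoid overflow
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum_exp for s in scores]

def nll_loss(log_probs, target):
    # negative log-likelihood of the correct class for a single example
    return -log_probs[target]

scores = [2.0, 1.0, 0.1]  # made-up scores for a 3-word vocabulary
log_probs = log_softmax(scores)
loss = nll_loss(log_probs, target=0)

# cross-entropy computed directly from the softmax probabilities
total = sum(math.exp(t) for t in scores)
probs = [math.exp(s) / total for s in scores]
cross_entropy = -math.log(probs[0])

print(f"NLL of log_softmax: {loss:.4f}, cross-entropy: {cross_entropy:.4f}")  # both 0.4170
```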
criterion = nn.NLLLoss()
optimizer = optim.SGD(ngrammodel.parameters(), lr=1e-3)

# each training example is ((word_i, word_i+1), word_i+2)
trigram = [((test_sentence[i], test_sentence[i + 1]), test_sentence[i + 2])
           for i in range(len(test_sentence) - 2)]

for epoch in range(100):
    print('epoch: {}'.format(epoch + 1))
    print('*' * 10)
    running_loss = 0
    for data in trigram:
        # 'word' holds the two context words before the target; 'label' is the word to predict
        word, label = data
        word = torch.LongTensor([word2id[e] for e in word])
        label = torch.LongTensor([word2id[label]])
        # forward
        out = ngrammodel(word)
        loss = criterion(out, label)
        running_loss += loss.item()  # loss.data[0] raises an IndexError on 0-dim tensors in recent PyTorch
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # average over the number of training examples, not the vocabulary size
    print('loss: {:.6f}'.format(running_loss / len(trigram)))
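The loop above takes one optimizer step per trigram, because `forward` flattens with `emb.view(1, -1)` and therefore only accepts a single context. A minimal sketch of a batched variant, assuming the whole (small) training set fits in one batch: flattening with `emb.view(x.size(0), -1)` keeps one row per example, and `nn.NLLLoss` averages the loss over the batch by default. `BatchedNgramModel`, the toy sentence, the embedding size, and the learning rate are illustrative choices, not values from the original.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class BatchedNgramModel(nn.Module):
    """Hypothetical batched version of NgramModel: same layers, batch-aware flattening."""
    def __init__(self, vocab_size, context_size, n_dim, hidden=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_dim)
        self.linear1 = nn.Linear(context_size * n_dim, hidden)
        self.linear2 = nn.Linear(hidden, vocab_size)

    def forward(self, x):
        emb = self.embedding(x)           # (batch, context_size) -> (batch, context_size, n_dim)
        emb = emb.view(x.size(0), -1)     # one flattened row per example
        out = self.linear2(F.relu(self.linear1(emb)))
        return F.log_softmax(out, dim=1)  # (batch, vocab_size) log-probabilities

tokens = "to be or not to be that is the question".split()
word2id = {w: i for i, w in enumerate(sorted(set(tokens)))}
# all (context, target) pairs stacked into two int64 tensors
contexts = torch.tensor([[word2id[tokens[i]], word2id[tokens[i + 1]]] for i in range(len(tokens) - 2)])
targets = torch.tensor([word2id[tokens[i + 2]] for i in range(len(tokens) - 2)])

torch.manual_seed(0)
model = BatchedNgramModel(len(word2id), context_size=2, n_dim=16)
criterion = nn.NLLLoss()  # mean over the batch by default
optimizer = optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(200):
    optimizer.zero_grad()
    loss = criterion(model(contexts), targets)  # one forward/backward pass for all trigrams
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

For a larger corpus, the same model could be fed mini-batches, for example through `torch.utils.data.DataLoader`, instead of the full set at once.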
# predict
word, label = trigram[3]
word = torch.LongTensor([word2id[i] for i in word])
with torch.no_grad():  # gradients are not needed at inference time
    out = ngrammodel(word)
_, predict_label = torch.max(out, 1)
predict_word = id2word[predict_label.item()]  # .item() turns the 1-element index tensor into an int
print('real word is {}, predict word is {}'.format(label, predict_word))
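`torch.max(out, 1)` returns only the single most likely word. To inspect the runner-up candidates as well, `torch.topk` can be applied to the same `(1, vocab_size)` log-probability output, and exponentiating a log-probability recovers the probability. The helper name `top_k_words` and the three-word toy vocabulary below are hypothetical.

```python
import torch

def top_k_words(log_prob, id2word, k=3):
    """Return the k most likely (word, probability) pairs from a (1, vocab_size) log-prob tensor."""
    values, indices = torch.topk(log_prob, k, dim=1)  # sorted, largest first
    return [(id2word[i.item()], v.exp().item()) for v, i in zip(values[0], indices[0])]

# toy model output for a 3-word vocabulary with probabilities 0.1, 0.6, 0.3
log_prob = torch.log(torch.tensor([[0.1, 0.6, 0.3]]))
id2word = {0: 'thy', 1: 'beauty', 2: 'brow'}
print(top_k_words(log_prob, id2word, k=2))
```

With the trained model, the same helper can be called on the `out` tensor from the prediction step, e.g. `top_k_words(out, id2word)`.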