import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim CONTEXT_SIZE = 2 # the same as window_size
EMBEDDING_DIM = 10
test_sentence = "When forty winters shall besiege thy brow,And dig deep trenches in thy beauty's field,Thy youth's proud livery so gazed on now,Will be a totter'd weed of small worth held:Then being asked, where all thy beauty lies,Where all the treasure of thy lusty days;To say, within thine own deep sunken eyes,Were an all-eating shame, and thriftless praise.How much more praise deserv'd thy beauty's use,If thou couldst answer 'This fair child of mineShall sum my count, and make my old excuse,'Proving his beauty by succession thine!This were to be new made when thou art old,And see thy blood warm when thou feel'st it cold.".split() vocb = set(test_sentence) # remove repeated words
word2id = {word: i for i, word in enumerate(vocb)}
id2word = {word2id[word]: word for word in word2id} # define model
class NgramModel(nn.Module):
def __init__(self, vocb_size, context_size, n_dim):
# super(NgramModel, self)._init_()
super().__init__()
self.n_word = vocb_size
self.embedding = nn.Embedding(self.n_word, n_dim)
self.linear1 = nn.Linear(context_size*n_dim, 128)
self.linear2 = nn.Linear(128, self.n_word) def forward(self, x):
# the first step: transmit words and achieve word embedding. eg. transmit two words, and then achieve (2, 100)
emb = self.embedding(x)
# the second step: word wmbedding unfold to (1,200)
emb = emb.view(1, -1)
# the third step: transmit to linear model, and then use relu, at last, transmit to linear model again
out = self.linear1(emb)
out = F.relu(out)
out = self.linear2(out)
# the output dim of last step is the number of words, wo can view as a classification problem
# if we want to predict the max probability of the words, finally we need use log softmax
log_prob = F.log_softmax(out)
return log_prob ngrammodel = NgramModel(len(word2id), CONTEXT_SIZE, 100)
criterion = nn.NLLLoss()
optimizer = optim.SGD(ngrammodel.parameters(), lr=1e-3) trigram = [((test_sentence[i], test_sentence[i+1]), test_sentence[i+2])
for i in range(len(test_sentence)-2)] for epoch in range(100):
print('epoch: {}'.format(epoch+1))
print('*'*10)
running_loss = 0
for data in trigram:
# we use 'word' to represent the two words forward the predict word, we use 'label' to represent the predict word
word, label = data # attention
word = Variable(torch.LongTensor([word2id[e] for e in word]))
label = Variable(torch.LongTensor([word2id[label]]))
# forward
out = ngrammodel(word)
loss = criterion(out, label)
running_loss += loss.data[0]
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('loss: {:.6f}'.format(running_loss/len(word2id))) # predict
word, label = trigram[3]
word = Variable(torch.LongTensor([word2id[i] for i in word]))
out = ngrammodel(word)
_, predict_label = torch.max(out, 1)
predict_word = id2word[predict_label.data[0][0]]
print('real word is {}, predict word is {}'.format(label, predict_word))

PyTorch学习笔记之n-gram模型实现的更多相关文章

  1. ArcGIS案例学习笔记-批量裁剪地理模型

    ArcGIS案例学习笔记-批量裁剪地理模型 联系方式:谢老师,135-4855-4328,xiexiaokui#qq.com 功能:空间数据的批量裁剪 优点:1.批量裁剪:任意多个目标数据,去裁剪任意 ...

  2. Java学习笔记之---单例模型

    Java学习笔记之---单例模型 单例模型分为:饿汉式,懒汉式 (一)要点 1.某个类只能有一个实例 2.必须自行创建实例 3.必须自行向整个系统提供这个实例 (二)实现 1.只提供私有的构造方法 2 ...

  3. WebGL three.js学习笔记 加载外部模型以及Tween.js动画

    WebGL three.js学习笔记 加载外部模型以及Tween.js动画 本文的程序实现了加载外部stl格式的模型,以及学习了如何把加载的模型变为一个粒子系统,并使用Tween.js对该粒子系统进行 ...

  4. ARMV8 datasheet学习笔记5:异常模型

    1.前言 2.异常类型描述 见 ARMV8 datasheet学习笔记4:AArch64系统级体系结构之编程模型(1)-EL/ET/ST 一文 3. 异常处理路由对比 AArch32.AArch64架 ...

  5. Javascript MVC 学习笔记(一) 模型和数据

    写在前面 近期在看<MVC的Javascript富应用开发>一书.本来是抱着一口气读完的想法去看的.结果才看了一点就傻眼了:太多不懂的地方了. 仅仅好看一点查一点,一点一点往下看吧,进度虽 ...

  6. PowerDesigner 15学习笔记:十大模型及五大分类

    个人认为PowerDesigner 最大的特点和优势就是1)提供了一整套的解决方案,面向了不同的人员提供不同的模型工具,比如有针对企业架构师的模型,有针对需求分析师的模型,有针对系统分析师和软件架构师 ...

  7. [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py 这篇文章来看下 ...

  8. [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...

  9. PyTorch学习笔记之CBOW模型实践

    import torch from torch import nn, optim from torch.autograd import Variable import torch.nn.functio ...

随机推荐

  1. Applied Nonparametric Statistics-lec1

    参考网址: https://onlinecourses.science.psu.edu/stat464/node/2 Binomial Distribution Normal Distribution ...

  2. POJ3216 最小路径覆盖

    首先说一下题意,Q个区域,M个任务,每个区域任务可能有多个,然后给你个到各地所需时间的矩阵,每个任务都有开始和持续时间,问最少需要多少工人? 每个工人只能同时执行一个任务. 通过题意,我的瞬间反应就是 ...

  3. python学习-- Django根据现有数据库,自动生成models模型文件

    Django引入外部数据库还是比较方便的,步骤如下 : 创建一个项目,修改seting文件,在setting里面设置你要连接的数据库类型和连接名称,地址之类,和创建新项目的时候一致 运行下面代码可以自 ...

  4. python学习-- 理解'*','*args','**','**kwargs'

    刚开始学习Python的时候,对有关args,kwargs,和*的使用感到很困惑.相信对此感到疑惑的人也有很多.我打算通过这个帖子来排解这个疑惑(希望能减少疑惑). 让我们通过以下5步来理解: 1.  ...

  5. Leetcode 436.寻找右区间

    寻找右区间 给定一组区间,对于每一个区间 i,检查是否存在一个区间 j,它的起始点大于或等于区间 i 的终点,这可以称为 j 在 i 的"右侧". 对于任何区间,你需要存储的满足条 ...

  6. ASP.NET配置设置-关于web.config各节点的讲解

    在msdn中搜索:“ASP.NET配置设置”,可以查看各个节点的配置. httpRuntime 元素:配置 ASP.NET HTTP 运行时设置,以确定如何处理对 ASP.NET 应用程序的请求.

  7. 【Vjudge】P1989Subpalindromes(线段树)

    题目链接 水题一道,用线段树维护哈希值,脑补一下加减乱搞搞……注意细节就过了 一定注意细节…… #include<cstdio> #include<cstdlib> #incl ...

  8. [luoguP2766] 最长递增子序列问题(最大流)

    传送门 题解来自网络流24题: [问题分析] 第一问时LIS,动态规划求解,第二问和第三问用网络最大流解决. [建模方法] 首先动态规划求出F[i],表示以第i位为开头的最长上升序列的长度,求出最长上 ...

  9. iOS-多线程(3)

    多线程之GCD(grand central dispatch)中心调度 为了简化多线程的操作,iOS为我们提供了GCD来实现编程. 使用GCD只要遵守两个步骤即可: 创建对列(串行队列,并行队列) 将 ...

  10. bzoj 4295 [PA2015]Hazard 贪心,暴力

    [PA2015]Hazard Time Limit: 10 Sec  Memory Limit: 512 MBSubmit: 69  Solved: 19[Submit][Status][Discus ...