pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》
论文 《 Convolutional Neural Networks for Sentence Classification》通过CNN实现了文本分类。
论文地址: 666666
模型图:
模型解释可以看论文,给出code and comment:https://github.com/graykode/nlp-tutorial
# -*- coding: utf-8 -*-
# @time : 2019/11/9 13:55 import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F dtype = torch.FloatTensor # Text-CNN Parameter
embedding_size = 2 # n-gram
sequence_length = 3
num_classes = 2 # 0 or 1
filter_sizes = [2, 2, 2] # n-gram window
num_filters = 3 # 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good. word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict) inputs = []
for sen in sentences:
inputs.append(np.asarray([word_dict[n] for n in sen.split()])) targets = []
for out in labels:
targets.append(out) # To using Torch Softmax Loss function input_batch = Variable(torch.LongTensor(inputs))
target_batch = Variable(torch.LongTensor(targets)) class TextCNN(nn.Module):
def __init__(self):
super(TextCNN, self).__init__() self.num_filters_total = num_filters * len(filter_sizes)
self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype)
self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype)
self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype) def forward(self, X):
embedded_chars = self.W[X] # [batch_size, sequence_length, sequence_length]
embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size] pooled_outputs = []
for filter_size in filter_sizes:
# conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars)
h = F.relu(conv)
# mp : ((filter_height, filter_width))
mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
# pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
pooled = mp(h).permute(0, 3, 2, 1)
pooled_outputs.append(pooled) h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)] model = torch.mm(h_pool_flat, self.Weight) + self.Bias # [batch_size, num_classes]
return model model = TextCNN() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) # Training
for epoch in range(5000):
optimizer.zero_grad()
output = model(input_batch) # output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)
loss = criterion(output, target_batch)
if (epoch + 1) % 1000 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss)) loss.backward()
optimizer.step() # Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = Variable(torch.LongTensor(tests)) # Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
print(test_text,"is Bad Mean...")
else:
print(test_text,"is Good Mean!!")
pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》的更多相关文章
- 卷积神经网络用语句子分类---Convolutional Neural Networks for Sentence Classification 学习笔记
读了一篇文章,用到卷积神经网络的方法来进行文本分类,故写下一点自己的学习笔记: 本文在事先进行单词向量的学习的基础上,利用卷积神经网络(CNN)进行句子分类,然后通过微调学习任务特定的向量,提高性能. ...
- 《Convolutional Neural Networks for Sentence Classification》 文本分类
文本分类任务中可以利用CNN来提取句子中类似 n-gram 的关键信息. TextCNN的详细过程原理图见下: keras 代码: def convs_block(data, convs=[3, 3, ...
- [NLP-CNN] Convolutional Neural Networks for Sentence Classification -2014-EMNLP
1. Overview 本文将CNN用于句子分类任务 (1) 使用静态vector + CNN即可取得很好的效果:=> 这表明预训练的vector是universal的特征提取器,可以被用于多种 ...
- CNN 文本分类
谈到文本分类,就不得不谈谈CNN(Convolutional Neural Networks).这个经典的结构在文本分类中取得了不俗的结果,而运用在这里的卷积可以分为1d .2d甚至是3d的. 下面 ...
- [转] Understanding Convolutional Neural Networks for NLP
http://www.wildml.com/2015/11/understanding-convolutional-neural-networks-for-nlp/ 讲CNN以及其在NLP的应用,非常 ...
- Understanding Convolutional Neural Networks for NLP
When we hear about Convolutional Neural Network (CNNs), we typically think of Computer Vision. CNNs ...
- How to Use Convolutional Neural Networks for Time Series Classification
How to Use Convolutional Neural Networks for Time Series Classification 2019-10-08 12:09:35 This blo ...
- Deep learning_CNN_Review:A Survey of the Recent Architectures of Deep Convolutional Neural Networks——2019
CNN综述文章 的翻译 [2019 CVPR] A Survey of the Recent Architectures of Deep Convolutional Neural Networks 翻 ...
- [转]XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks
感谢: XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks XNOR-Net ImageNet Cl ...
随机推荐
- python二维列表求解所有元素之和
2020-01-14 相信很多初学小伙伴都会遇到二维列表求解所有元素之和问题,下面给出两种两种常见的求和方法. 方法1: 思想:遍历整个二维列表元素,然后将所有元素加起来 def Sum_matrix ...
- python——pickle模块的详解
pickle模块详解 该pickle模块实现了用于序列化和反序列化Python对象结构的二进制协议. “Pickling”是将Python对象层次结构转换为字节流的过程, “unpickling”是反 ...
- ES 服务器 索引、类型仓库基类 BaseESStorage
/******************************************************* * * 作者:朱皖苏 * 创建日期:20180508 * 说明:此文件只包含一个类,具 ...
- feign架构 原理解析
什么是feign? 来自官网的解释:Feign makes writing java http clients easier 在使用feign之前,我们怎么发送请求? 拿okhttp举例: publi ...
- next_permutation 函数
next_permutation 是一个定义在 <algorithm> 中的一个全排列函数, 用于按顺序生成一个数列的全排列 基本用法 : int a[] = {1, 2, 3}; do{ ...
- 程序员Java架构师多线程面试题和回答解析
当我们在Java架构师面试的过程中常见的多线程和并发方面的问题肯定是必不可少的一部分.那么在面试之前我们更应该多准备一些关于多线程方面的问题. 面试官只是想确信面试者有足够的Java线程与并发方面的知 ...
- [bzoj3991] [洛谷P3320] [SDOI2015] 寻宝游戏
Description 小B最近正在玩一个寻宝游戏,这个游戏的地图中有 \(N\) 个村庄和 \(N-1\) 条道路,并且任何两个村庄之间有且仅有一条路径可达.游戏开始时,玩家可以任意选择一个村庄,瞬 ...
- 解决python爬虫requests.exceptions.SSLError: HTTPSConnectionPool(host='XXX', port=443)问题
爬虫时报错如下: requests.exceptions.SSLError: HTTPSConnectionPool(host='某某某网站', port=443): Max retries exce ...
- Python通过win32 com接口实现offic自动化
最近几天通过Python做一些自动生成office报表的东东,比如解析.xml文件,导出.html/WORD/PPT等格式,html不足一提,只需要简单的html静态网页知识即可,这儿要说的是怎么生成 ...
- Dynamics CRM Package Deployer 部署工具
CRM 部署工具非常有用 我们可以用部署工具来做迁移,部署 等等. Package Deployer可以同时部署多个solutions. 并且也可以勾选solution的plugin也同时部署. 三 ...