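PyTorch -- CNN Text Classification -- "Convolutional Neural Networks for Sentence Classification"
The paper "Convolutional Neural Networks for Sentence Classification" performs text classification with a CNN.
Paper link: 666666
Model diagram: (figure omitted)
For an explanation of the model, see the paper. The code and comments below are from https://github.com/graykode/nlp-tutorial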
    # -*- coding: utf-8 -*-
    # @time : 2019/11/9 13:55
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    # Text-CNN parameters
    embedding_size = 2  # dimension of each word embedding
    sequence_length = 3  # every sentence has 3 words
    num_classes = 2  # 0 or 1
    filter_sizes = [2, 2, 2]  # n-gram window sizes
    num_filters = 3  # number of filters per window size

    sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
    labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.

    # Sorted so that the word-to-index mapping is the same on every run
    word_list = sorted(set(" ".join(sentences).split()))
    word_dict = {w: i for i, w in enumerate(word_list)}
    vocab_size = len(word_dict)

    # Word indices: [batch_size, sequence_length]
    input_batch = torch.LongTensor([[word_dict[w] for w in sen.split()] for sen in sentences])
    # Class indices (not one-hot), as nn.CrossEntropyLoss expects: [batch_size]
    target_batch = torch.LongTensor(labels)


    class TextCNN(nn.Module):
        def __init__(self):
            super(TextCNN, self).__init__()
            self.num_filters_total = num_filters * len(filter_sizes)
            self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1))
            self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1))
            self.Bias = nn.Parameter(0.1 * torch.ones([num_classes]))
            # Build the convolutions once here so that their weights are registered and trained.
            # Creating them inside forward() would re-initialize them randomly on every call.
            # conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
            self.convs = nn.ModuleList(
                [nn.Conv2d(1, num_filters, (size, embedding_size), bias=True) for size in filter_sizes]
            )

        def forward(self, X):
            embedded_chars = self.W[X]  # [batch_size, sequence_length, embedding_size]
            embedded_chars = embedded_chars.unsqueeze(1)  # add channel(=1) [batch, channel(=1), sequence_length, embedding_size]
            pooled_outputs = []
            for filter_size, conv in zip(filter_sizes, self.convs):
                h = F.relu(conv(embedded_chars))  # [batch_size, output_channel(=3), sequence_length - filter_size + 1, 1]
                # mp : ((filter_height, filter_width))
                mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
                # pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
                pooled = mp(h).permute(0, 3, 2, 1)
                pooled_outputs.append(pooled)
            # Concatenate along the channel axis (dim 3 after the permute)
            h_pool = torch.cat(pooled_outputs, dim=3)  # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
            h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total])  # [batch_size(=6), output_height * output_width * (output_channel * 3)]
            logits = torch.mm(h_pool_flat, self.Weight) + self.Bias  # [batch_size, num_classes]
            return logits


    model = TextCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training
    for epoch in range(5000):
        optimizer.zero_grad()
        output = model(input_batch)
        # output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)
        loss = criterion(output, target_batch)
        if (epoch + 1) % 1000 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))
        loss.backward()
        optimizer.step()

    # Test
    test_text = 'sorry hate you'
    test_batch = torch.LongTensor([[word_dict[w] for w in test_text.split()]])

    # Predict: take the class with the highest score
    with torch.no_grad():
        predict = model(test_batch).argmax(dim=1)
    if predict[0].item() == 0:
        print(test_text, "is Bad Mean...")
    else:
        print(test_text, "is Good Mean!!")
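To make the convolution and pooling step in the listing concrete, here is a minimal pure-Python sketch of what a single filter computes for one sentence: it slides over consecutive word windows (the n-gram), sums the element-wise products with the embeddings, applies ReLU, and keeps only the largest response (max-over-time pooling). The embedding values and filter weights are made up for illustration, and `conv_max_over_time` is a hypothetical helper, not part of PyTorch or the referenced repository.

```python
def relu(value):
    """Rectified linear unit: negative responses become 0."""
    return max(0.0, value)


def conv_max_over_time(embedded, kernel, bias=0.0):
    """Apply one filter to every word window, then keep the maximum response.

    embedded: list of word vectors, shape [sequence_length][embedding_size]
    kernel:   filter weights, shape [filter_size][embedding_size]
    """
    window = len(kernel)  # filter height = n-gram size
    features = []
    for start in range(len(embedded) - window + 1):
        total = bias
        # Element-wise product of the window with the filter, summed up
        for word_vec, kernel_row in zip(embedded[start:start + window], kernel):
            total += sum(x * w for x, w in zip(word_vec, kernel_row))
        features.append(relu(total))
    return features, max(features)


# Assumed toy values: a 3-word sentence with 2-dimensional embeddings
sentence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Assumed toy filter covering 2 words (filter_size = 2)
kernel = [[1.0, -1.0], [0.5, 0.5]]

features, pooled = conv_max_over_time(sentence, kernel)
print(features)  # one response per window: [1.5, 0.0]
print(pooled)    # max-over-time pooling keeps 1.5
```

With `sequence_length = 3` and `filter_size = 2` there are `3 - 2 + 1 = 2` windows, which is why the pooling kernel height in the listing is `sequence_length - filter_size + 1`. Each of the `num_filters` filters contributes one pooled value, and the values from all filter sizes are concatenated before the final linear layer.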