The paper "Convolutional Neural Networks for Sentence Classification" uses a CNN for text classification.

Paper link: 666666

Model diagram:

[Figure: model diagram]

The model itself is explained in the paper. The code and comments here follow https://github.com/graykode/nlp-tutorial.
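As a compact reference, the per-filter computation described in the paper can be written as follows, where x_i is the embedding of the i-th word, n is the sentence length, h is the filter height (window size), w and b are the filter weights and bias, f is a nonlinearity (ReLU in the code), and ⊕ is concatenation. Each filter produces n − h + 1 values, and max-over-time pooling keeps only the largest one.

$$
\begin{aligned}
x_{1:n} &= x_1 \oplus x_2 \oplus \cdots \oplus x_n \\
c_i &= f\left(w \cdot x_{i:i+h-1} + b\right), \qquad i = 1, \dots, n-h+1 \\
\mathbf{c} &= [c_1, c_2, \dots, c_{n-h+1}], \qquad \hat{c} = \max\{\mathbf{c}\}
\end{aligned}
$$

With the settings used in the code (sequence_length n = 3, filter_size h = 2), each filter produces 3 − 2 + 1 = 2 values, which is why the pooling window height is sequence_length - filter_size + 1. Three filter sizes times num_filters = 3 give 9 pooled features for the final classification layer.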

```python
# -*- coding: utf-8 -*-
# @time : 2019/11/9 13:55

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Text-CNN parameters
embedding_size = 2  # dimension of each word vector
sequence_length = 3  # every sentence has exactly 3 words
num_classes = 2  # 0 or 1
filter_sizes = [2, 2, 2]  # filter heights, i.e. n-gram window sizes
num_filters = 3  # number of filters (output channels) per filter size

# 3-word sentences (= sequence_length)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good

# Build the vocabulary; sorted() keeps word indices the same across runs
word_list = sorted(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict)

# Each sentence becomes a row of word indices: [batch_size, sequence_length]
input_batch = torch.tensor([[word_dict[w] for w in sen.split()] for sen in sentences], dtype=torch.long)
# CrossEntropyLoss expects class indices (LongTensor), not one-hot vectors
target_batch = torch.tensor(labels, dtype=torch.long)


class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.num_filters_total = num_filters * len(filter_sizes)
        # Embedding matrix: [vocab_size, embedding_size]
        self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1))
        # Conv2d(input_channel=1, output_channel=num_filters, (filter_height, filter_width)).
        # The layers are created here, not in forward(), so their weights are
        # registered as parameters and actually updated by the optimizer.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, num_filters, (size, embedding_size), bias=True) for size in filter_sizes]
        )
        self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1))
        self.Bias = nn.Parameter(0.1 * torch.ones([num_classes]))

    def forward(self, X):
        embedded_chars = self.W[X]  # [batch_size, sequence_length, embedding_size]
        embedded_chars = embedded_chars.unsqueeze(1)  # add channel(=1): [batch_size, 1, sequence_length, embedding_size]

        pooled_outputs = []
        for filter_size, conv in zip(filter_sizes, self.convs):
            # h: [batch_size, num_filters, sequence_length - filter_size + 1, 1]
            h = F.relu(conv(embedded_chars))
            # Max-over-time pooling: the window covers the whole output height
            mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
            # pooled: [batch_size(=6), output_height(=1), output_width(=1), num_filters(=3)]
            pooled = mp(h).permute(0, 3, 2, 1)
            pooled_outputs.append(pooled)

        h_pool = torch.cat(pooled_outputs, dim=3)  # [batch_size, 1, 1, num_filters * len(filter_sizes)]
        h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total])  # [batch_size, num_filters_total]

        model = torch.mm(h_pool_flat, self.Weight) + self.Bias  # [batch_size, num_classes]
        return model


model = TextCNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5000):
    optimizer.zero_grad()
    output = model(input_batch)

    # output: [batch_size, num_classes], target_batch: [batch_size] (LongTensor, not one-hot)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

    loss.backward()
    optimizer.step()

# Test
test_text = 'sorry hate you'
test_batch = torch.tensor([[word_dict[w] for w in test_text.split()]], dtype=torch.long)

# Predict: index of the largest logit is the predicted class
with torch.no_grad():
    predict = model(test_batch).argmax(dim=1)
if predict[0].item() == 0:
    print(test_text, "is Bad Mean...")
else:
    print(test_text, "is Good Mean!!")
```
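For comparison, here is a hypothetical variant (not the nlp-tutorial code) that expresses the same architecture with nn.Embedding, nn.Conv1d and nn.Linear. Treating the embedding dimensions as Conv1d input channels makes each kernel span the whole word vector, which matches a Conv2d kernel of size (filter_size, embedding_size); taking the max over the last axis is max-over-time pooling. The class name TextCNN1d and its constructor arguments are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCNN1d(nn.Module):
    """Hypothetical TextCNN variant: nn.Embedding + nn.Conv1d + max-over-time pooling."""

    def __init__(self, vocab_size, embedding_size, num_classes, filter_sizes, num_filters):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # One Conv1d per filter size; embedding dims are the input channels,
        # so every kernel covers the full word vector (like a (h, embedding_size) Conv2d)
        self.convs = nn.ModuleList(
            [nn.Conv1d(embedding_size, num_filters, kernel_size=h) for h in filter_sizes]
        )
        # nn.Linear plays the role of the hand-written Weight/Bias pair
        self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes)

    def forward(self, x):
        # x: [batch_size, sequence_length] word indices
        emb = self.embedding(x).transpose(1, 2)  # [batch_size, embedding_size, sequence_length]
        # Each conv gives [batch_size, num_filters, sequence_length - h + 1];
        # max over the time axis leaves [batch_size, num_filters]
        pooled = [F.relu(conv(emb)).max(dim=2).values for conv in self.convs]
        return self.fc(torch.cat(pooled, dim=1))  # [batch_size, num_classes]


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TextCNN1d(vocab_size=16, embedding_size=2, num_classes=2,
                      filter_sizes=[2, 2, 2], num_filters=3)
    batch = torch.randint(0, 16, (6, 3))  # 6 sentences of 3 word indices each
    print(model(batch).shape)  # torch.Size([6, 2])
```

The six example sentences contain 16 distinct words, hence vocab_size=16 in the usage lines. Registering the convolutions in an nn.ModuleList is the important design choice: a layer constructed inside forward() gets fresh random weights on every call and is never passed to the optimizer, so it cannot learn.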
