介绍:

欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。

展示:

训练展示如下:



实际使用如下:

实现方式:

选择PyTorch作为深度学习框架,使用电影评论IMDB数据集,并结合torchtext对数据进行预处理。

环境:

Windows+Anaconda
重要库版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

实现思路:

1、数据集

本次使用的是IMDB数据集,IMDB是一个含有50000条关于电影评论的数据集

数据如下:

2、数据加载与预处理

使用torchtext加载IMDB数据集,并对数据集进行划分

具体划分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

创建一个 Field 对象,用于处理文本数据。同时使用spacy分词器对文本进行分词,由于IMDB是英文的,所以使用en_core_web_sm语言模型。

创建一个 LabelField 对象,用于处理标签数据。设置dtype 参数为 torch.float,表示标签的数据类型为浮点型。

使用 datasets.IMDB.splits 方法加载 IMDB 数据集,并将文本字段 TEXT 和标签字段 LABEL 传递给该方法。返回的 train_data 和 test_data 包含了 IMDB 数据集的训练和测试部分。

下面是train_data的输出

3、构建词汇表与加载预训练词向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中数据构建词汇表

max_size:限制词汇表的大小为 25000

vectors="glove.6B.100d":表示使用预训练的 GloVe 词向量,其中 "glove.6B.100d" 指的是包含 100 维向量的 6B 版 GloVe。

unk_init=torch.Tensor.normal_ :表示指定未知单词(UNK)的初始化方式,这里使用正态分布进行初始化。

LABEL.build_vocab(train_data):表示对标签进行类似的操作,构建标签的词汇表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 来创建数据加载器,包括训练、验证和测试集的迭代器。这将确保你能够方便地以批量的形式获取数据进行训练和评估。

4、定义神经网络

这里的网络定义比较简单,主要采用在词嵌入层(embedding)后接一个全连接层的方式完成对文本数据的分类。

具体如下:

class NetWork(nn.Module):
    def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):
        super(NetWork,self).__init__()
        self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)
        self.fc = nn.Linear(embedding_dim,output_dim)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()     def forward(self,x):
        embedded = self.embedding(x)
        embedded = embedded.permute(1,0,2)
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
        pooled = self.relu(pooled)
        pooled = self.dropout(pooled)
       
        output = self.fc(pooled)
        return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定义模型的超参数,包括词汇表大小(vocab_size)、词向量维度(embedding_dim)、输出维度(output,在这个任务中是1,因为是二元分类,所以使用1),以及 PAD 标记的索引(pad_idx)

之后需要将预训练的词向量加载到嵌入层的权重中。TEXT.vocab.vectors 包含了词汇表中每个单词的预训练词向量,然后通过 copy_ 方法将这些词向量复制到模型的嵌入层权重中对网络进行初始化。这样做确保了模型的初始化状态良好。

6、训练模型

 total_loss = 0
 train_acc = 0
model.train()
for batch in train_iterator:
        optimizer.zero_grad()
        preds = model(batch.text).squeeze(1)
        loss = criterion(preds,batch.label)
        total_loss += loss.item()         batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()
        train_acc += batch_acc
       
        loss.backward()
        optimizer.step()     average_loss = total_loss / len(train_iterator)
    train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示将模型参数的梯度清零,以准备接收新的梯度。

preds = model(batch.text).squeeze(1):表示一次前向传播的过程,由于model输出的是torch.tensor(batch_size,1)所以使用squeeze(1)给其中的1维度数据去除,以匹配标签张量的形状

criterion(preds,batch.label):定义的损失函数 criterion 计算预测值 preds 与真实标签 batch.label 之间的损失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():

通过比较模型的预测值与真实标签,计算当前批次的准确率,并将其累加到 train_acc 中

后面的就是进行反向传播更新参数,还有就是计算loss和train_acc的值了

7、模型评估:

model.eval()
    valid_loss = 0
    valid_acc = 0
    best_valid_acc = 0
    with torch.no_grad():
        for batch in valid_iterator:
            preds = model(batch.text).squeeze(1)
            loss = criterion(preds,batch.label)
            valid_loss += loss.item()
            batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())
            valid_acc += batch_acc

和训练模型的类似,这里就不解释了

8、保存模型

这里一共使用了两种保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一种方式叫做模型的全量保存

第二种方式叫做模型的参数保存

全量保存是保存了整个模型,包括模型的结构、参数、优化器状态等信息

参数量保存是保存了模型的参数(state_dict),不包括模型的结构

9、测试模型

测试模型的基本思路:

加载训练保存的模型、对待推理的文本进行预处理、将文本数据加载给模型进行推理

加载模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

输入文本:

input_text = "Great service! The staff was very friendly and helpful."

文本进行处理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():
output = saved_model(tensor_text).squeeze(1)
prediction = torch.round(torch.sigmoid(output)).item()
probability = torch.sigmoid(output).item()

由于笔者能力有限,所以在描述的过程中难免会有不准确的地方,还请多多包含!

关注公众号“陶陶name”获取更多NLP和CV文章以及完整代码!

NLP项目实战01--之电影评论分类的更多相关文章

  1. 【项目实战】Kaggle电影评论情感分析

    前言 这几天持续摆烂了几天,原因是我自己对于Kaggle电影评论情感分析的这个赛题敲出来的代码无论如何没办法运行,其中数据变换的维度我无法把握好,所以总是在函数中传错数据.今天痛定思痛,重新写了一遍代 ...

  2. 【SSH网上商城项目实战01】整合Struts2、Hibernate4.3和Spring4.2

    转自:https://blog.csdn.net/eson_15/article/details/51277324 今天开始做一个网上商城的项目,首先从搭建环境开始,一步步整合S2SH.这篇博文主要总 ...

  3. 电影评论分类:二分类问题(IMDB数据集)

    IMDB数据集是Keras内部集成的,初次导入需要下载一下,之后就可以直接用了. IMDB数据集包含来自互联网的50000条严重两极分化的评论,该数据被分为用于训练的25000条评论和用于测试的250 ...

  4. JAVAEE——SSH项目实战01:SVN介绍、安装和使用方法

    1 学习目标 1.掌握svn服务端.svn客户端.svn eclipse插件安装方法 2.掌握svn的基本使用方法 2 svn介绍 2.1 项目管理中的版本控制问题 通常软件开发由多人协作开发,如果对 ...

  5. JAVAEE——SSH项目实战01:SVN介绍、eclipse插件安装和使用方法

    1 学习目标 1.掌握svn服务端.svn客户端.svn eclipse插件安装方法 2.掌握svn的基本使用方法 2 svn介绍 2.1 项目管理中的版本控制问题 通常软件开发由多人协作开发,如果对 ...

  6. 【Robot Framework 项目实战 01】使用 RequestsLibrary 进行接口测试

    写在前面 本文我们一起来学习如何使用Robot Framework 的RequestsLibrary库,涉及POST.GET接口测试,RF用例分层封装设计等内容. 接口 接口测试是我们最常见的测试类型 ...

  7. React Native商城项目实战01 - 初始化设置

    1.创建项目 $ react-native init BuyDemo 2.导入图片资源 安卓:把文件夹放到/android/app/src/main/res/目录下,如图: iOS: Xcode打开工 ...

  8. Flask项目实战:创建电影网站(3)后台的增删改查

    添加预告 根据需求数据库创建表格 需求数据库,关键字title logo # 上映预告 class Preview(db.Model): __tablename__ = "preview&q ...

  9. Flask项目实战:创建电影网站(2)

    flask网站制作后台时候常见流程总结 安利一个神神器: 百度脑图PC版 创建数据库 下面是创建User数据库,需要导入db库 #coding:utf8 from flask import Flask ...

  10. Flask项目实战:创建电影网站-创世纪(1)

    以后要养成写博客的习惯,用来做笔记.本人看的东西很多很杂,但因为工作中很少涉及,造成看了之后就忘,或者看了就看了,但是没有融入的自己的知识体系里面. 写博客一方面是做记录,一方面是给这段时间业余学习的 ...

随机推荐

  1. AI绘画StableDiffusion:云端在线版免费使用笔记分享-Kaggle版

    玩AI绘画(SD),自己电脑配置不够?今天给大家介绍一下如何baipiao在线版AI绘画StableDiffusion. Kaggle 是世界上最大的数据科学社区,拥有强大的工具和资源,可帮助您实现数 ...

  2. 利用AI点亮副业变现:5个变现实操案例的启示

    整体思维导图: 在这里先分享五个实操案例: 宝宝起名服务 AI科技热点号 头像壁纸号 小说推广号 流量营销号 你们好,我是小梦. 最初我计划撰写一篇关于AI盈利策略的文章,对AI目前的技术走向.应用场 ...

  3. 基于 JMeter API 开发性能测试平台

    背景: JMeter 是一个功能强大的性能测试工具,若开发一个性能测试平台,用它作为底层执行引擎在合适不过.如要使用其API,就不得不对JMeter 整个执行流程,常见的类有清楚的了解. 常用的 JM ...

  4. Java NIO 图解 Netty 服务端启动的过程

    一.启动概述 了解整体Netty常用的核心组件后,并且对比了传统IO模式.在对比过程中,找到了传统IO对应Netty中是如何实现的.最后我们了解到在netty中常用的那些组件. 本文在了解下这些核心组 ...

  5. 使用PIL为图片添加水印

    使用pillow库为图片添加文件或者图片水印 下面是我们想要添加水印的图片: 图片水印: 效果图如下: ps:对图片添加字体时,需指定字体文件,如 simsun.ttc windows中在 C:\Wi ...

  6. iOS添加图片

    添加一个按钮 将图片添加到

  7. Go代码包与引入:如何有效组织您的项目

    本文深入探讨了Go语言中的代码包和包引入机制,从基础概念到高级应用一一剖析.文章详细讲解了如何创建.组织和管理代码包,以及包引入的多种使用场景和最佳实践.通过阅读本文,开发者将获得全面而深入的理解,进 ...

  8. 2017-A

    2017-A 题目描述: 输入一个字符串,要求输出能把所有的小写字符放前面,大写字符放中间,数字放后面,并且中间用空格隔开,如果同种类字符间有不同种类的字符,输出后也要用字符隔开. 例: 输入 12a ...

  9. 有关library导入的个人总结和反思

    本来帮助朋友找寻一下android的一些特效的demo,结果找到了一个,朋友试验可以,自己却是在导入项目需要的library的时候总是出问题,真的很是丢人,反省反省. 也许专业人士看来这是非常可笑的问 ...

  10. 聊聊Maven的依赖传递、依赖管理、依赖作用域

    1. 依赖传递 在Maven中,依赖是会传递的,假如在业务项目中引入了spring-boot-starter-web依赖: <dependency> <groupId>org. ...