[Python急救站]基于Transformer Models模型完成GPT2的学生AIGC学习训练模型
为了AIGC的学习,我做了一个基于Transformer Models模型完成GPT2的学生AIGC学习训练模型,指在训练模型中学习编程AI。
在编程之前需要准备一些文件:
首先,先win+R打开运行框,输入:PowerShell后
输入:
pip install -U huggingface_hub
下载完成后,指定我们的环境变量:
$env:HF_ENDPOINT = "https://hf-mirror.com"
然后下载模型:
huggingface-cli download --resume-download gpt2 --local-dir "D:\Pythonxiangmu\PythonandAI\Transformer Models\gpt-2"
这边我的目录是我要下载的工程目录地址
然后下载数据量:
huggingface-cli download --repo-type dataset --resume-download wikitext --local-dir "D:\Pythonxiangmu\PythonandAI\Transformer Models\gpt-2"
这边我的目录是我要下载的工程目录地址
所以两个地址记得更改成自己的工程目录下(建议放在创建一个名为gpt-2的文件夹)
在PowerShell中下载完这些后,可以开始我们的代码啦
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AdamW,
get_linear_schedule_with_warmup,
set_seed,
)
from torch.optim import AdamW
# 设置随机种子以确保结果可复现
set_seed(42)
class TextDataset(Dataset):
def __init__(self, tokenizer, texts, block_size=128):
self.tokenizer = tokenizer
self.examples = [
self.tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=block_size) for
text
in texts]
# 在tokenizer初始化后,确保unk_token已设置
print(f"Tokenizer's unk_token: {self.tokenizer.unk_token}, unk_token_id: {self.tokenizer.unk_token_id}")
def __len__(self):
return len(self.examples)
def __getitem__(self, i):
item = self.examples[i]
# 替换所有不在vocab中的token为unk_token_id
for key in item.keys():
item[key] = torch.where(item[key] >= self.tokenizer.vocab_size, self.tokenizer.unk_token_id, item[key])
return item
def train(model, dataloader, optimizer, scheduler, de, tokenizer):
model.train()
for batch in dataloader:
input_ids = batch['input_ids'].to(de)
# 添加日志输出检查input_ids
if torch.any(input_ids >= model.config.vocab_size):
print("Warning: Some input IDs are outside the model's vocabulary.")
print(f"Max input ID: {input_ids.max()}, Vocabulary Size: {model.config.vocab_size}")
attention_mask = batch['attention_mask'].to(de)
labels = input_ids.clone()
labels[labels[:, :] == tokenizer.pad_token_id] = -100
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
def main():
local_model_path = "D:/Pythonxiangmu/PythonandAI/Transformer Models/gpt-2"
tokenizer = AutoTokenizer.from_pretrained(local_model_path)
# 确保pad_token已经存在于tokenizer中,对于GPT-2,它通常自带pad_token
if tokenizer.pad_token is None:
special_tokens_dict = {'pad_token': '[PAD]'}
tokenizer.add_special_tokens(special_tokens_dict)
model = AutoModelForCausalLM.from_pretrained(local_model_path, pad_token_id=tokenizer.pad_token_id)
else:
model = AutoModelForCausalLM.from_pretrained(local_model_path)
model.to(device)
train_texts = [
"The quick brown fox jumps over the lazy dog.",
"In the midst of chaos, there is also opportunity.",
"To be or not to be, that is the question.",
"Artificial intelligence will reshape our future.",
"Every day is a new opportunity to learn something.",
"Python programming enhances problem-solving skills.",
"The night sky sparkles with countless stars.",
"Music is the universal language of mankind.",
"Exploring the depths of the ocean reveals hidden wonders.",
"A healthy mind resides in a healthy body.",
"Sustainability is key for our planet's survival.",
"Laughter is the shortest distance between two people.",
"Virtual reality opens doors to immersive experiences.",
"The early morning sun brings hope and vitality.",
"Books are portals to different worlds and minds.",
"Innovation distinguishes between a leader and a follower.",
"Nature's beauty can be found in the simplest things.",
"Continuous learning fuels personal growth.",
"The internet connects the world like never before."
# 更多训练文本...
]
dataset = TextDataset(tokenizer, train_texts, block_size=128)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
optimizer = AdamW(model.parameters(), lr=5e-5)
total_steps = len(dataloader) * 5 # 假设训练5个epoch
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
for epoch in range(5): # 训练5个epoch
train(model, dataloader, optimizer, scheduler, device, tokenizer) # 使用正确的变量名dataloader并传递tokenizer
# 保存微调后的模型
model.save_pretrained("path/to/save/fine-tuned_model")
tokenizer.save_pretrained("path/to/save/fine-tuned_tokenizer")
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
main()
这个代码只训练了5个epoch,有一些实例文本,记得调成直接的路径后,运行即可啦。
如果有什么问题可以随时在评论区或者是发个人邮箱:linyuanda@linyuanda.com
[Python急救站]基于Transformer Models模型完成GPT2的学生AIGC学习训练模型的更多相关文章
- 基于 Agent 的模型入门:Python 实现隔离仿真
2005 年诺贝尔经济学奖得主托马斯·谢林(Thomas Schelling)在上世纪 70 年代就纽约的人种居住分布得出了著名的 Schelling segregation model,这是一个 A ...
- python学习-- Django根据现有数据库,自动生成models模型文件
Django引入外部数据库还是比较方便的,步骤如下 : 创建一个项目,修改seting文件,在setting里面设置你要连接的数据库类型和连接名称,地址之类,和创建新项目的时候一致 运行下面代码可以自 ...
- 基于Python的信用评分卡模型分析(二)
上一篇文章基于Python的信用评分卡模型分析(一)已经介绍了信用评分卡模型的数据预处理.探索性数据分析.变量分箱和变量选择等.接下来我们将继续讨论信用评分卡的模型实现和分析,信用评分的方法和自动评分 ...
- 【tornado】系列项目(二)基于领域驱动模型的区域后台管理+前端easyui实现
本项目是一个系列项目,最终的目的是开发出一个类似京东商城的网站.本文主要介绍后台管理中的区域管理,以及前端基于easyui插件的使用.本次增删改查因数据量少,因此采用模态对话框方式进行,关于数据量大采 ...
- 【tornado】系列项目(一)之基于领域驱动模型架构设计的京东用户管理后台
本博文将一步步揭秘京东等大型网站的领域驱动模型,致力于让读者完全掌握这种网络架构中的“高富帅”. 一.预备知识: 1.接口: python中并没有类似java等其它语言中的接口类型,但是python中 ...
- 机器学习经典算法详解及Python实现--基于SMO的SVM分类器
原文:http://blog.csdn.net/suipingsp/article/details/41645779 支持向量机基本上是最好的有监督学习算法,因其英文名为support vector ...
- 02基于python玩转人工智能最火框架之TensorFlow人工智能&深度学习介绍
人工智能之父麦卡锡给出的定义 构建智能机器,特别是智能计算机程序的科学和工程. 人工智能是一种让计算机程序能够"智能地"思考的方式 思考的模式类似于人类. 什么是智能? 智能的英语 ...
- 转 Django根据现有数据库,自动生成models模型文件
Django引入外部数据库还是比较方便的,步骤如下 : 创建一个项目,修改seting文件,在setting里面设置你要连接的数据库类型和连接名称,地址之类,和创建新项目的时候一致 运行下面代码可以自 ...
- Django models模型
Django models模型 一. 所谓Django models模型,是指的对数据库的抽象模型,models在英文中的意思是模型,模板的意思,在这里的意思是通过models,将数据库的借口抽象成p ...
- 基于Distiller的模型压缩工具简介
Reference: https://github.com/NervanaSystems/distiller https://nervanasystems.github.io/distiller/in ...
随机推荐
- Oracle与MySQL的差异和对比
Oracle与MySQL的差异和对比:配套hands-on参考脚本. 方便客户针对培训课件内容进行动手实践,加强理解. --------------------------------- -- 主题: ...
- C# AES CBC模式 加密和解密
using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.S ...
- 基于logisim-D触发器设计四人抢答电路
实验1:设计一个简易4人知识竞赛抢答电路,要求是: 裁判掌握一个按钮,作用是给电路复位和发出抢答开始命令;4名竞赛者各掌握一个按钮,每人对应一个指示灯,在主持人发出开始抢答命令后,哪位参赛者先按钮其对 ...
- mysql统计所有分类下的数量,没有的也要展示
要求统计所有分类下的数量,如果分类下没有对应的数据也要展示.这种问题在日常的开发中很常见,每次写每次忘,所以在此记录下. 这种统计往往不能直接group by,因为有些类别可能没有对应的数据 这里有两 ...
- mybatis踩坑之integer类型是0的时候会被认为0!=''是假
当你的参数类型是integer类型,并且传的是0的时候,在SQL里面做if判断的时候 <if test="auditStatus != null and auditStatus != ...
- 面试官:Session和JWT有什么区别?
Session 和 JWT(JSON Web Token)都是用于在用户和服务器之间建立认证状态的机制,但它们在工作原理.存储方式和安全性等方面存在着一些差异,下面我们一起来看. 1.什么是JWT? ...
- #Manacher,并查集#洛谷 3279 [SCOI2013]密码
题目 分析 这些回文长度可以提供相等或者不等的信息, 不等的直接连边强制不等,相等用并查集合并连通块, 但是这样判断是\(O(n^2)\),考虑这些回文长度当用Manacher求时, 所有的回文长度都 ...
- JDK10的新特性:var泛型和多个接口实现
目录 简介 实现多个接口 使用多个接口 使用var 总结 简介 在JDK10的新特性:本地变量类型var中我们讲到了为什么使用var和怎么使用var. 今天我们来深入的考虑一下var和泛型,多个接口实 ...
- 帕鲁重大更新!macOS 竟然也能玩了
近日,<幻兽帕鲁>迎来了 v0.2.1.0 大版本的更新. 本次更新的最大亮点是新实装的突袭头目系统.玩家可以在 "召唤祭坛" 献祭石板,从而召唤强大的突袭头目.其中, ...
- 如何在现实场景中随心放置AR虚拟对象?
随着AR的发展和电子设备的普及,人们在生活中使用AR技术的门槛降低,比如对于不方便测量的物体使用AR测量,方便又准确:遇到陌生的路段使用AR导航,清楚又便捷:网购时拿不准的物品使用AR购物,体验更逼真 ...