huggingface vit训练CIFAR10数据集代码 ,可以改dataset训练自己的数据
上代码,使用hugging face fineturn vit模型
自己写的代码
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST,CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import resnet101
from tqdm import tqdm # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("mps")
# torch.device("cpu") # 加载 MNIST 数据集
train_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=True, transform=ToTensor(), download=True)
test_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=False, transform=ToTensor()) def collate_fn(batch):
"""
对batch数据进行处理
:param batch: [一个getitem的结果,getitem的结果,getitem的结果]
:return: 元组
"""
reviews,labels = zip(*batch)
# print(reviews)
# print(labels)
# reviews = torch.Tensor(reviews)
labels = torch.Tensor(labels) return reviews,labels
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,collate_fn=collate_fn) # url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw) processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.config.classifier = 'mlp'
model.config.num_labels = 10
# print(model.get_output_embeddings)
# print(model.classifier)
model.classifier = nn.Linear(768,10)
print(model.classifier) parameters = list(model.parameters())
for x in parameters[:-1]:
x.requires_grad = False model.to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001) def train(model, dataloader, optimizer, criterion):
model.train()
running_loss = 0.0
for inputs, labels in tqdm(dataloader, desc="Training"):
# print(inputs)
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device)
# print(inputs['pixel_values'].shape)
# print(labels.shape)
optimizer.zero_grad() outputs = model(**inputs)
logits = outputs.logits # print(logits,labels)
loss = criterion(logits, labels.long())
loss.backward()
optimizer.step()
# model predicts one of the 1000 ImageNet classes
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])
running_loss += loss.item() * inputs['pixel_values'].size(0) epoch_loss = running_loss / len(dataloader.dataset)
return epoch_loss def evaluate(model, dataloader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in tqdm(dataloader, desc="Evaluating"):
inputs = processor(images=inputs, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(device)
labels = labels.to(device) outputs = model(**inputs)
logits = outputs.logits predicted= logits.argmax(-1) total += labels.size(0)
correct += (predicted == labels).sum().item() accuracy = correct / total * 100
return accuracy # 训练和评估
num_epochs = 10 for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}")
train_loss = train(model, train_loader, optimizer, criterion)
print(f"Training Loss: {train_loss:.4f}") test_acc = evaluate(model, test_loader)
print(f"Test Accuracy: {test_acc:.2f}%")
chatgpt生成的代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm # 设置随机种子
torch.manual_seed(42) # 定义超参数
batch_size = 32
num_epochs = 10
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
]) # 加载CIFAR-10数据集
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform) # 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 加载预训练的ViT模型
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device) # 替换分类头
num_classes = 10
vit_model.config.classifier = 'mlp'
vit_model.config.num_labels = num_classes
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device) # 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate) # 微调ViT模型
for epoch in range(num_epochs):
print("epoch:",epoch)
vit_model.train()
train_loss = 0.0
train_correct = 0 bar = tqdm(train_loader,total=len(train_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) # 前向传播
outputs = vit_model(images)
loss = criterion(outputs.logits, labels) # 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step() train_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
train_correct += (predicted == labels).sum().item() # 在训练集上计算准确率
train_accuracy = 100.0 * train_correct / len(train_dataset) # 在测试集上进行评估
vit_model.eval()
test_loss = 0.0
test_correct = 0 with torch.no_grad():
bar = tqdm(test_loader,total=len(test_loader))
for images, labels in bar:
images = images.to(device)
labels = labels.to(device) outputs = vit_model(images)
loss = criterion(outputs.logits, labels) test_loss += loss.item()
_, predicted = torch.max(outputs.logits, 1)
test_correct += (predicted == labels).sum().item() # 在测试集上计算准确率
test_accuracy = 100.0 * test_correct / len(test_dataset) # 打印每个epoch的训练损失、训练准确率和测试准确率
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')
huggingface vit训练CIFAR10数据集代码 ,可以改dataset训练自己的数据的更多相关文章
- Ubuntu+caffe训练cifar-10数据集
1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...
- Keras学习:试用卷积-训练CIFAR-10数据集
import numpy as np import cPickle import keras as ks from keras.layers import Dense, Activation, Fla ...
- MXNet学习:试用卷积-训练CIFAR-10数据集
第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...
- 使用caffe训练mnist数据集 - caffe教程实战(一)
个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...
- CaffeExample 在CIFAR-10数据集上训练与测试
本文主要来自Caffe作者Yangqing Jia网站给出的examples. @article{jia2014caffe, Author = {Jia, Yangqing and Shelhamer ...
- 仿照CIFAR-10数据集格式,制作自己的数据集
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50801226 前一篇博客:C/C++ ...
- TensorFlow CNN 测试CIFAR-10数据集
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...
- caffe︱cifar-10数据集quick模型的官方案例
准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...
- 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集
上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...
- CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】
CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...
随机推荐
- hutool,真香!
前言 今天给大家介绍一个能够帮助大家提升开发效率的开源工具包:hutool. Hutool是一个小而全的Java工具类库,通过静态方法封装,降低相关API的学习成本,提高工作效率,使Java拥有函数式 ...
- 适用于AbpBoilerplate的阿里云腾讯云Sms短信服务
Sms 适用于AbpBoilerplate的短信服务(Short Message Service,SMS)模块,通过简单配置即可使用,仅更改一处代码即可切换短信服务提供商. Aliyun.Sms由阿里 ...
- Zabbix自动发现:python-json模块应用介绍
一.JSON模块介绍 json模块是python内置的库,其主要功能是将序列化数据从文件里读取出来或者存入文件.该模块有四个方法:dump().load().dumps().loads(),其中dum ...
- 感慨 vscode 支持win7最后一个版本 1.70.3 于2022年7月发布
为什么 家里电脑一直是win7,也懒的升级,nodejs也不能用最新的,没想到vscode也停产了 https://code.visualstudio.com/updates/v1_70 后记 别用u ...
- redis三主三从详细搭建过程
搭建Redis三主三从集群的详细步骤如下: 准备环境: 确保你有六台服务器或虚拟机,每台服务器上都已经安装了Redis.这些服务器将用于搭建三主三从的Redis集群. 确保所有服务器之间的网络连接正常 ...
- 关于初始化page入参的设计思路
最近在重构老的代码,在写的过程中发现之前的逻辑如果遇到没有入参pageNo会Npe,于是乎我想找找公司项目有啥方式处理page入参的有两种如下 使用三元表达式直接判断是否null,然后赋值 使用map ...
- window.showModalDialog与opener及returnValue
首先来看看 window.showModalDialog 的参数 vReturnValue = window.showModalDialog(sURL [, vArguments] [, sFeatu ...
- MySQL varchar详解
说明:以下结果都是在mysql8.2及Innodb环境下测试. varcahr(255)是什么含义? varchar(255) 表示可以存储最大255个字符,至于占多少个字节由字符集决定. varch ...
- koa2整合mysql
引入mysql包 npm install mysql 封装mysql 创建mysql.js文件放在utils(工具包)中 使用pool连接池 mysql.js //封装mysql const mysq ...
- Win10 如何在桌面显示我的电脑
Win10桌面右键鼠标,然后在弹出来的选项中选择个性化. 选择了个性化后会弹出设置界面,在设置中选择[主题] 找到[桌面图标设置] 点击[桌面图标设置],会弹出一个对话框,该对话框有可以设置显示的图标 ...