import cv2
import numpy as np
import os
import pickle

data_dir = os.path.join("data", "cifar-10-batches-py")
train_o_dir = os.path.join("data", "train")
test_o_dir = os.path.join("data", "test")

Train = True  # set to False to skip the training set and extract only the test set


# Unpickle a batch file and return the decoded dict
def unpickle(file):
    with open(file, 'rb') as fo:
        dict_ = pickle.load(fo, encoding='bytes')
    return dict_


def my_mkdir(my_dir):
    if not os.path.isdir(my_dir):
        os.makedirs(my_dir)


# Generate the training-set images
if __name__ == '__main__':
    if Train:
        for j in range(1, 6):
            data_path = os.path.join(data_dir, "data_batch_" + str(j))  # data_batch_1 .. data_batch_5
            train_data = unpickle(data_path)
            print(data_path + " is loading...")

            for i in range(0, 10000):
                img = np.reshape(train_data[b'data'][i], (3, 32, 32))
                img = img.transpose(1, 2, 0)  # CHW -> HWC
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # CIFAR-10 stores RGB; OpenCV writes BGR

                label_num = str(train_data[b'labels'][i])
                o_dir = os.path.join(train_o_dir, "data_batch_" + str(j), label_num)
                my_mkdir(o_dir)

                img_name = label_num + '_' + str(i + (j - 1) * 10000) + '.png'
                img_path = os.path.join(o_dir, img_name)
                cv2.imwrite(img_path, img)
            print(data_path + " loaded.")

    print("test_batch is loading...")

    # Generate the test-set images
    test_data_path = os.path.join(data_dir, "test_batch")
    test_data = unpickle(test_data_path)
    for i in range(0, 10000):
        img = np.reshape(test_data[b'data'][i], (3, 32, 32))
        img = img.transpose(1, 2, 0)  # CHW -> HWC
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # CIFAR-10 stores RGB; OpenCV writes BGR

        label_num = str(test_data[b'labels'][i])
        o_dir = os.path.join(test_o_dir, label_num)
        my_mkdir(o_dir)

        img_name = label_num + '_' + str(i) + '.png'
        img_path = os.path.join(o_dir, img_name)
        cv2.imwrite(img_path, img)
    print("test_batch loaded.")
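
After the conversion, a quick sanity check is worthwhile: the two output trees should hold 50,000 and 10,000 png files respectively. Below is a minimal sketch; the count_pngs helper is my own addition, not part of the original script:

import os

def count_pngs(root):
    # walk the whole tree and count the .png files under root
    total = 0
    for _, _, files in os.walk(root):
        total += sum(1 for f in files if f.endswith('.png'))
    return total

print("train:", count_pngs(os.path.join("data", "train")))  # expect 50000
print("test:", count_pngs(os.path.join("data", "test")))    # expect 10000

With the images extracted, the next script builds the txt index files that the PyTorch dataloader will read.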
import os

os.makedirs("data/traintxt", exist_ok=True)

# Generate one txt index file per training batch
data_dir = "data/train/"
datat = "data/traintxt"
for j in range(1, 6):
    data_path = os.path.join(data_dir, "data_batch_" + str(j))  # data_batch_1 .. data_batch_5
    datatraint = os.path.join(datat, "data_batch_" + str(j) + ".txt")
    ft = open(datatraint, 'w')
    print(data_path)
    for root, s_dirs, _ in os.walk(data_path, topdown=True):  # the class folders under each batch folder
        print(s_dirs)
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)  # path of one class folder
            img_list = os.listdir(i_dir)  # all files in that class folder
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):  # skip anything that is not a png
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                ft.write(line)
    ft.close()

# Generate a single txt for the whole training set
data_dir = "data/train/"
datat = "data"
datatraint = os.path.join(datat, "train.txt")
ft = open(datatraint, 'w')
for j in range(1, 6):
data_path = os.path.join(data_dir, "data_batch_" + str(j)) # data_batch_12345
print(data_path)
for root, s_dirs, _ in os.walk(data_path, topdown=True): # 获取 train文件下各文件夹名称
print(s_dirs)
for sub_dir in s_dirs:
i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径
img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径
for i in range(len(img_list)):
if not img_list[i].endswith('png'): # 若不是png文件,跳过
continue
label = img_list[i].split('_')[0]
img_path = os.path.join(i_dir, img_list[i])
line = img_path + ' ' + label + '\n'
ft.write(line)
ft.close() #test的txt
data_dir = "data"
datat = "data" data_path = os.path.join(data_dir, "test")
datatraint = os.path.join(datat, "test.txt")
ft = open(datatraint, 'w') print(data_path)
for root, s_dirs, _ in os.walk(data_path, topdown=True): # 获取 test文件下各文件夹名称
print(s_dirs)
for sub_dir in s_dirs:
i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径
img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径
for i in range(len(img_list)):
if not img_list[i].endswith('png'): # 若不是png文件,跳过
continue
label = img_list[i].split('_')[0]
img_path = os.path.join(i_dir, img_list[i])
line = img_path + ' ' + label + '\n'
ft.write(line)
ft.close()
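
Each index file holds one "path label" pair per line, with the label recovered from the file-name prefix. Before training, it pays to parse a few lines exactly the way the dataloader will. A minimal sketch, assuming the full conversion above has run:

with open("data/train.txt") as f:
    lines = [line.rstrip().split() for line in f]

print(len(lines))   # expect 50000, one line per training image
path, label = lines[0]
print(path, label)  # a path like data/train/data_batch_1/<label>/<label>_0.png, then the label
assert label.isdigit() and 0 <= int(label) <= 9  # labels are the ten CIFAR-10 class ids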

The training script below, main.py, is adapted from another author's GitHub repository.

'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset
from PIL import Image

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar


# Custom dataset that reads "path label" lines from a txt index file
class Mydataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# No random augmentation at test time: only ToTensor and Normalize
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = Mydataset(txt_path='/work/aiit/warming/cifar-10-batches-py/train.txt',
                     transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = Mydataset(txt_path='/work/aiit/warming/cifar-10-batches-py/test.txt',
                    transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
# net = vgg.VGG('VGG19')
# net = ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
# net = MobileNet()
# net = MobileNetV2()
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()
# net = ShuffleNetV2(1)
# net = EfficientNetB0()
net = RegNetX_200MF()
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # make sure the checkpoint directory exists before saving the whole model
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(net, './checkpoint/RegNetX_200MF.pth')


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        # torch.save(state, './checkpoint/ckpt.pth')  # uncomment to write the checkpoint that --resume loads
        best_acc = acc


for epoch in range(start_epoch, start_epoch + 100):
    train(epoch)
    test(epoch)
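
One caveat: train() pickles the entire model object, while the best-accuracy checkpoint (the commented-out save in test()) stores a state dict under the 'net' key. To reload that state-dict checkpoint into a fresh network, something like the sketch below should work. It assumes the same models package is importable; note that a model wrapped in DataParallel saves its parameters with a 'module.' prefix, which must be stripped for a bare model:

import torch
from models import *  # must provide the same architecture that was trained

net = RegNetX_200MF()
checkpoint = torch.load('./checkpoint/ckpt.pth', map_location='cpu')
state = checkpoint['net']
# strip the 'module.' prefix added by DataParallel, if present
state = {k.replace('module.', '', 1): v for k, v in state.items()}
net.load_state_dict(state)
net.eval()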

Prediction

import torch
import cv2
import torch.nn.functional as F
import sys
sys.path.append('/work/aiit/warming/pytorch-cifar-master/models')
# import vgg
# import torchvision.models as models
# from vgg2 import vgg  # Important: the IDE greys this out as unused, but torch.load needs the model's class definition on the import path to unpickle a saved model
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # net = models.vgg19(pretrained=False)
    model = torch.load('/work/aiit/warming/pytorch-cifar-master/checkpoint/RegNetX_200MF.pth')  # load the pickled model
    model = model.to(device)
    model.eval()  # switch the model to evaluation mode

    img = cv2.imread("/work/aiit/warming/cifar-10-batches-py/test/1/1_6.png")  # the image to predict
    img = cv2.resize(img, (32, 32))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR; training read RGB via PIL, so convert
    trans = transforms.Compose([
        transforms.ToTensor(),
        # use the same normalization as training (the original post used mean/std of 0.5 here, which mismatches training)
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    img = trans(img)
    img = img.to(device)
    img = img.unsqueeze(0)  # add a batch dimension: the model expects 4-D input [batch_size, channels, height, width], while a single image tensor is only 3-D [channels, height, width]
    # after unsqueeze the shape is [1, 3, 32, 32]
    output = model(img)
    prob = F.softmax(output, dim=1)  # probabilities over the 10 classes
    print(prob)
    value, predicted = torch.max(output.data, 1)
    # print(predicted.item())
    # print(value)
    pred_class = classes[predicted.item()]
    print(pred_class)

    '''prob = F.softmax(output, dim=1)
    prob = Variable(prob)
    prob = prob.cpu().numpy()  # GPU tensors must be moved back to the CPU before converting to numpy
    print(prob)  # probabilities over the 10 classes
    pred = np.argmax(prob)  # pick the class with the highest probability
    print(pred)
    print(pred.item())
    pred_class = classes[pred]
    print(pred_class)'''
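
To score several images at once, it is cheaper to stack them into one batch than to call the model per image. A minimal sketch reusing model, trans, device, and classes from the script above; the glob pattern is only an assumption about the folder layout created earlier:

import glob
import torch

paths = glob.glob("/work/aiit/warming/cifar-10-batches-py/test/1/*.png")[:8]
batch = []
for p in paths:
    im = cv2.imread(p)
    im = cv2.resize(im, (32, 32))
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # match the single-image pipeline
    batch.append(trans(im))
batch = torch.stack(batch).to(device)  # shape [len(paths), 3, 32, 32]

with torch.no_grad():
    outputs = model(batch)
preds = outputs.argmax(dim=1)
print([classes[i] for i in preds.tolist()])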
