目录结构

dogsData.py

import json

import torch
import os, glob
import random, csv from PIL import Image
from torch.utils.data import Dataset, DataLoader from torchvision import transforms
from torchvision.transforms import InterpolationMode class Dogs(Dataset): def __init__(self, root, resize, mode):
super().__init__()
self.root = root
self.resize = resize
self.nameLable = {}
for name in sorted(os.listdir(os.path.join(root))):
if not os.path.isdir(os.path.join(root, name)):
continue
self.nameLable[name] = len(self.nameLable.keys()) if not os.path.exists(os.path.join(self.root, 'label.txt')):
with open(os.path.join(self.root, 'label.txt'), 'w', encoding='utf-8') as f:
f.write(json.dumps(self.nameLable, ensure_ascii=False)) # print(self.nameLable)
self.images, self.labels = self.load_csv('images.csv')
# print(self.labels) if mode == 'train':
self.images = self.images[:int(0.8*len(self.images))]
self.labels = self.labels[:int(0.8*len(self.labels))]
elif mode == 'val':
self.images = self.images[int(0.8*len(self.images)):int(0.9*len(self.images))]
self.labels = self.labels[int(0.8*len(self.labels)):int(0.9*len(self.labels))]
else:
self.images = self.images[int(0.9*len(self.images)):]
self.labels = self.labels[int(0.9*len(self.labels)):] def load_csv(self, filename): if not os.path.exists(os.path.join(self.root, filename)):
images = []
for name in self.nameLable.keys():
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# print(len(images)) random.shuffle(images)
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images:
name = img.split(os.sep)[-2]
label = self.nameLable[name]
writer.writerow([img, label])
print('csv write succesful') images, labels = [], []
with open(os.path.join(self.root, filename)) as f:
reader = csv.reader(f)
for row in reader:
img, label = row
label = int(label)
images.append(img)
labels.append(label) assert len(images) == len(labels) return images, labels def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x_hot = (x-mean)/std
# x = x_hat * std = mean
# x : [c, w, h]
# mean [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1) x = x_hat * std + mean
return x def __len__(self):
return len(self.images) def __getitem__(self, idx):
# print(idx, len(self.images), len(self.labels))
img, label = self.images[idx], self.labels[idx] # 将字符串路径转换为tensor数据
# print(self.resize, type(self.resize))
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'),
transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
transforms.RandomRotation(15),
transforms.CenterCrop(self.resize),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = tf(img) label = torch.tensor(label) return img, label def main(): import visdom
import time viz = visdom.Visdom() # func1 通用
db = Dogs('Images_Data_Dog', 224, 'train')
# 取一张
# x,y = next(iter(db))
# print(x.shape, y)
# viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x')) # 取一个batch
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
print(len(loader))
print(db.nameLable)
# for x, y in loader:
# # print(x.shape, y)
# viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
# viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
# time.sleep(10) # # fun2
# import torchvision
# tf = transforms.Compose([
# transforms.Resize((64, 64)),
# transforms.RandomRotation(15),
# transforms.ToTensor(),
# ])
# db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
# loader = DataLoader(db, batch_size=32, shuffle=True)
# print(len(loader))
# for x, y in loader:
# # print(x.shape, y)
# viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
# viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
# time.sleep(10) if __name__ == '__main__':
main()

utils.py

import torch
from torch import nn class Flatten(nn.Module):
def __init__(self):
super().__init__() def forward(self, x):
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.view(-1, shape)

train.py

import os
import sys
base_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base_path)
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)
import torch
import visdom
from torch import optim, nn
import torchvision from torch.utils.data import DataLoader from dogs_train.utils import Flatten
from dogsData import Dogs from torchvision.models import resnet18 viz = visdom.Visdom() batchsz = 32
lr = 1e-3
epochs = 20 device = torch.device('cuda')
torch.manual_seed(1234) train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test') train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2) def evalute(model, loader):
correct = 0
total = len(loader.dataset)
for x, y in loader:
x = x.to(device)
y = y.to(device)
with torch.no_grad():
logist = model(x)
pred = logist.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
return correct/total def main(): # model = ResNet18(5).to(device)
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],
Flatten(), # [b, 512, 1, 1] => [b, 512]
nn.Linear(512, 27)
).to(device) x = torch.randn(2, 3, 224, 224).to(device)
print(model(x).shape) optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss() best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
for epoch in range(epochs): for step, (x, y) in enumerate(train_loader):
x = x.to(device)
y = y.to(device) logits = model(x)
loss = criteon(logits, y) optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append')
global_step += 1
if epoch % 2 == 0:
val_acc = evalute(model, val_loader)
if val_acc > best_acc:
best_acc = val_acc
best_epoch = epoch
torch.save(model.state_dict(), 'best.mdl') viz.line([val_acc], [global_step], win='val_acc', update='append') print('best acc', best_acc, 'best epoch', best_epoch) model.load_state_dict(torch.load('best.mdl'))
print('loader from ckpt') test_acc = evalute(model, test_loader)
print(test_acc) if __name__ == '__main__':
main()

resnet18训练自定义数据集的更多相关文章

  1. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  2. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  3. [炼丹术]YOLOv5训练自定义数据集

    YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...

  4. yolov5训练自定义数据集

    yolov5训练自定义数据 step1:参考文献及代码 博客 https://blog.csdn.net/weixin_41868104/article/details/107339535 githu ...

  5. Tensorflow2 自定义数据集图片完成图片分类任务

    对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...

  6. torch_13_自定义数据集实战

    1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...

  7. tensorflow从训练自定义CNN网络模型到Android端部署tflite

    网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型 ...

  8. Yolo训练自定义目标检测

    Yolo训练自定义目标检测 参考darknet:https://pjreddie.com/darknet/yolo/ 1. 下载darknet 在 https://github.com/pjreddi ...

  9. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  10. PyTorch 自定义数据集

    准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...

随机推荐

  1. 字节过滤流 缓冲流-->BufferedInputStream用法

    1创建字节输入节点流FileInputStream fis = new FileInputStream("文件读取的路径");2创建字节输入过滤流,包装一个字节输入节点流Buffe ...

  2. test.sh 监听进程是否存在

    监听myloader进程是否结束,结束后把时间输出到 /root/time.log vim test.sh #!/bin/bash #确保PRO查询进程唯一 PRO="myloader&qu ...

  3. android系统上编写、运行C#代码

    最近找到个好玩的APP,C#Shell (Compiler REPL),可以在安卓系统上编写和运行C#代码,配合sqlite数据库,写了个小爬虫,运行还不错: 运行一些小爬虫或者定时任务可以用这个,毕 ...

  4. windows消息机制_PostMessage和SendMessage

    1.子线程中建立一个窗口 为了在后面比较这两个函数,先使用win32 windows程序中建立子线程,在子线程中建立一个窗口. (1)新建一个 win32 windows应用程序 (2)定义子窗口的窗 ...

  5. 初识MPC

    MPC调研报告 ​ 这是一篇关于MPC的调研报告,主要介绍了我对MPC领域的一些基础认识.全文按照这样的方式组织:第一节我介绍了什么是MPC以及MPC的起源:第二节介绍了MPC领域常用的一些符号和安全 ...

  6. FIRE2023:殁亡漫谈

    FIRE2023:殁亡漫谈 读书的时候,想到殁亡,脑海涌出一则喜欢的遗言: 钱花完了,我走了.签名 如果可能牵涉到旁人(比如殁在旅馆里),就再立一则: 我的殁与店家无关. 签名 然后放下Kindle, ...

  7. DNS CNAME limitations cname 在哪些情况下不能配置

    https://www.rfc-editor.org/rfc/rfc1912.html https://www.rfc-editor.org/rfc/rfc2181.html 说明: domain n ...

  8. 在Windows上访问linux的共享文件夹

    1. https://blog.csdn.net/weixin_44147924/article/details/123692155

  9. github fork 别人的项目源作者更新后如何同步更新

    如下 左边选择我们拷贝的库  右边选择原工程 如下 点击箭头指向的位置 然后选择右边原工程目录

  10. docker学习3

    docker的启动流程 docker run -t -i <name:tag> /bin/bash -t 把1个伪终端绑定到容器的标准输入 -i 保持容器的标准输入始终打开不关闭 启动流程 ...