目录结构

dogsData.py

  1. import json
  2.  
  3. import torch
  4. import os, glob
  5. import random, csv
  6.  
  7. from PIL import Image
  8. from torch.utils.data import Dataset, DataLoader
  9.  
  10. from torchvision import transforms
  11. from torchvision.transforms import InterpolationMode
  12.  
  13. class Dogs(Dataset):
  14.  
  15. def __init__(self, root, resize, mode):
  16. super().__init__()
  17. self.root = root
  18. self.resize = resize
  19. self.nameLable = {}
  20. for name in sorted(os.listdir(os.path.join(root))):
  21. if not os.path.isdir(os.path.join(root, name)):
  22. continue
  23. self.nameLable[name] = len(self.nameLable.keys())
  24.  
  25. if not os.path.exists(os.path.join(self.root, 'label.txt')):
  26. with open(os.path.join(self.root, 'label.txt'), 'w', encoding='utf-8') as f:
  27. f.write(json.dumps(self.nameLable, ensure_ascii=False))
  28.  
  29. # print(self.nameLable)
  30. self.images, self.labels = self.load_csv('images.csv')
  31. # print(self.labels)
  32.  
  33. if mode == 'train':
  34. self.images = self.images[:int(0.8*len(self.images))]
  35. self.labels = self.labels[:int(0.8*len(self.labels))]
  36. elif mode == 'val':
  37. self.images = self.images[int(0.8*len(self.images)):int(0.9*len(self.images))]
  38. self.labels = self.labels[int(0.8*len(self.labels)):int(0.9*len(self.labels))]
  39. else:
  40. self.images = self.images[int(0.9*len(self.images)):]
  41. self.labels = self.labels[int(0.9*len(self.labels)):]
  42.  
  43. def load_csv(self, filename):
  44.  
  45. if not os.path.exists(os.path.join(self.root, filename)):
  46. images = []
  47. for name in self.nameLable.keys():
  48. images += glob.glob(os.path.join(self.root, name, '*.png'))
  49. images += glob.glob(os.path.join(self.root, name, '*.jpg'))
  50. images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
  51. # print(len(images))
  52.  
  53. random.shuffle(images)
  54. with open(os.path.join(self.root, filename), mode='w', newline='') as f:
  55. writer = csv.writer(f)
  56. for img in images:
  57. name = img.split(os.sep)[-2]
  58. label = self.nameLable[name]
  59. writer.writerow([img, label])
  60. print('csv write succesful')
  61.  
  62. images, labels = [], []
  63. with open(os.path.join(self.root, filename)) as f:
  64. reader = csv.reader(f)
  65. for row in reader:
  66. img, label = row
  67. label = int(label)
  68. images.append(img)
  69. labels.append(label)
  70.  
  71. assert len(images) == len(labels)
  72.  
  73. return images, labels
  74.  
  75. def denormalize(self, x_hat):
  76. mean = [0.485, 0.456, 0.406]
  77. std = [0.229, 0.224, 0.225]
  78. # x_hot = (x-mean)/std
  79. # x = x_hat * std = mean
  80. # x : [c, w, h]
  81. # mean [3] => [3, 1, 1]
  82. mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
  83. std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
  84.  
  85. x = x_hat * std + mean
  86. return x
  87.  
  88. def __len__(self):
  89. return len(self.images)
  90.  
  91. def __getitem__(self, idx):
  92. # print(idx, len(self.images), len(self.labels))
  93. img, label = self.images[idx], self.labels[idx]
  94.  
  95. # 将字符串路径转换为tensor数据
  96. # print(self.resize, type(self.resize))
  97. tf = transforms.Compose([
  98. lambda x: Image.open(x).convert('RGB'),
  99. transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
  100. transforms.RandomRotation(15),
  101. transforms.CenterCrop(self.resize),
  102. transforms.ToTensor(),
  103. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  104. ])
  105. img = tf(img)
  106.  
  107. label = torch.tensor(label)
  108.  
  109. return img, label
  110.  
  111. def main():
  112.  
  113. import visdom
  114. import time
  115.  
  116. viz = visdom.Visdom()
  117.  
  118. # func1 通用
  119. db = Dogs('Images_Data_Dog', 224, 'train')
  120. # 取一张
  121. # x,y = next(iter(db))
  122. # print(x.shape, y)
  123. # viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
  124.  
  125. # 取一个batch
  126. loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
  127. print(len(loader))
  128. print(db.nameLable)
  129. # for x, y in loader:
  130. # # print(x.shape, y)
  131. # viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
  132. # viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
  133. # time.sleep(10)
  134.  
  135. # # fun2
  136. # import torchvision
  137. # tf = transforms.Compose([
  138. # transforms.Resize((64, 64)),
  139. # transforms.RandomRotation(15),
  140. # transforms.ToTensor(),
  141. # ])
  142. # db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
  143. # loader = DataLoader(db, batch_size=32, shuffle=True)
  144. # print(len(loader))
  145. # for x, y in loader:
  146. # # print(x.shape, y)
  147. # viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
  148. # viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
  149. # time.sleep(10)
  150.  
# Run the dataset smoke test only when executed as a script.
if __name__ == '__main__':
    main()

utils.py

  1. import torch
  2. from torch import nn
  3.  
  4. class Flatten(nn.Module):
  5. def __init__(self):
  6. super().__init__()
  7.  
  8. def forward(self, x):
  9. shape = torch.prod(torch.tensor(x.shape[1:])).item()
  10. return x.view(-1, shape)

train.py

  1. import os
  2. import sys
  3. base_path = os.path.dirname(os.path.abspath(__file__))
  4. sys.path.append(base_path)
  5. base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  6. sys.path.append(base_path)
  7. import torch
  8. import visdom
  9. from torch import optim, nn
  10. import torchvision
  11.  
  12. from torch.utils.data import DataLoader
  13.  
  14. from dogs_train.utils import Flatten
  15. from dogsData import Dogs
  16.  
  17. from torchvision.models import resnet18
  18.  
# A visdom server must already be running (`python -m visdom.server`)
# for the loss/accuracy curves below to render.
viz = visdom.Visdom()

# Training hyperparameters.
batchsz = 32
lr = 1e-3
epochs = 20

device = torch.device('cuda')  # assumes a CUDA-capable GPU is present
torch.manual_seed(1234)  # fixed seed for reproducible runs

# The Dogs dataset slices itself into 80/10/10 train/val/test splits
# based on the mode argument.
train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test')

# Only the training loader shuffles; val/test keep deterministic order.
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
  35.  
  36. def evalute(model, loader):
  37. correct = 0
  38. total = len(loader.dataset)
  39. for x, y in loader:
  40. x = x.to(device)
  41. y = y.to(device)
  42. with torch.no_grad():
  43. logist = model(x)
  44. pred = logist.argmax(dim=1)
  45. correct += torch.eq(pred, y).sum().float().item()
  46. return correct/total
  47.  
def main():
    """Fine-tune an ImageNet-pretrained ResNet-18 on the dog dataset.

    Transfer learning: reuse every ResNet-18 layer except the final fc,
    flatten the pooled features, and train a fresh 27-way linear head.
    Checkpoints the best validation model to 'best.mdl' and reports test
    accuracy with the reloaded best weights.
    """

    # model = ResNet18(5).to(device)
    # Drop resnet18's own fc layer and append Flatten + a new 27-class head.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 27)
                          ).to(device)

    # Sanity check: a dummy forward pass should print torch.Size([2, 27]).
    x = torch.randn(2, 3, 224, 224).to(device)
    print(model(x).shape)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    # Seed the visdom windows so update='append' below has a line to extend.
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):

        for step, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Stream the per-step loss to visdom.
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        # Validate every other epoch and checkpoint on improvement.
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')

            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc', best_acc, 'best epoch', best_epoch)

    # Reload the best validation checkpoint before the final test pass.
    model.load_state_dict(torch.load('best.mdl'))
    print('loader from ckpt')

    test_acc = evalute(model, test_loader)
    print(test_acc)
  97.  
# Launch training only when executed as a script.
if __name__ == '__main__':
    main()

resnet18训练自定义数据集的更多相关文章

  1. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  2. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  3. [炼丹术]YOLOv5训练自定义数据集

    YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...

  4. yolov5训练自定义数据集

    yolov5训练自定义数据 step1:参考文献及代码 博客 https://blog.csdn.net/weixin_41868104/article/details/107339535 githu ...

  5. Tensorflow2 自定义数据集图片完成图片分类任务

    对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...

  6. torch_13_自定义数据集实战

    1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...

  7. tensorflow从训练自定义CNN网络模型到Android端部署tflite

    网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型 ...

  8. Yolo训练自定义目标检测

    Yolo训练自定义目标检测 参考darknet:https://pjreddie.com/darknet/yolo/ 1. 下载darknet 在 https://github.com/pjreddi ...

  9. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  10. PyTorch 自定义数据集

    准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...

随机推荐

  1. 数据库ip被锁了怎么办

    由于多次访问失败,导致ip被限制,登录时会报错 Internal error/check (Not system error) 如何解决: 找一台同事的机子,(或者修改自己的ip),然后打开sql 的 ...

  2. MongoDB:内嵌文档查询匹配 查询集合中的文档

    1.db.getCollection('Notification').find({ Title:{$regex:/班/}, "Message.TargetUrl":{$regex: ...

  3. linux学习之grep

    grep 可进行查找内容 如 cat logs/anyproxy.log | grep "2020080321000049" 还可以通过-v 反向过滤 如 tail -f  log ...

  4. 调用mglearn时的报错 TypeError: __init__() got an unexpected keyword argument 'cachedir'

    import mglearn的时候发生的报错 原因是调用了joblib包中的memory类,但是cachedir这个参数已经弃用了 查到下面帖子之后改掉cachedir解决问题 https://blo ...

  5. turtle绘制风轮

    题目要求: 使用turtle库,绘制一个风轮效果,其中,每个风轮内角为45度,风轮边长150像素. 我的代码: import turtle turtle.setup(500,500,100,200) ...

  6. VUE项目中检测网页滑动注意事项

    一.this.$nextTick(function () {             window.addEventListener('scroll', this.onScroll, true)   ...

  7. 树莓派 wiringPi的BCM与BOARD编码

    一.基础命令使用wiringPi库 1.1.获取管教信息 gpio readall ---获取管脚信息   1.2.BOARD编码和BCM一般都在python库中使用 import RPi.GPIO ...

  8. 在.NET中使用JWT

    1.配置文件添加 //jwt配置文件 "JWT": { "SigningKey": "14fa5f2rrwsg627fs256fdgff2r5rf52 ...

  9. fetch请求方式

    Fetch请求的方式 1:GET 请求 // 未传参数 const getData = async () => { const res = await fetch('http://www.xxx ...

  10. Pytorch基础复习

    项目推进中期,重新到头来学Pytorch.five落泪了.(╬▔皿▔)凸 笑死,憋不住了,边更边学. 整篇博客整体采用总分总形式.首先将介绍内容(加黑部分)之间关系进行概括,后拆解,最后以图总结. 全 ...