VisualPytorch beta发布了!

功能概述:通过可视化拖拽网络层方式搭建模型,可选择不同数据集、损失函数、优化器生成可运行pytorch代码

扩展功能:1. 模型搭建支持模块的嵌套;2. 模型市场中能共享及克隆模型;3. 模型推理助你直观的感受神经网络在语义分割、目标探测上的威力;4.添加图像增强、快速入门、参数弹窗等辅助性功能

修复缺陷:1.大幅改进UI界面,提升用户体验;2.修改注销不跳转、图片丢失等已知缺陷;3.实现双服务器访问,缓解访问压力

访问地址http://sunie.top:9000

发布声明详见https://www.cnblogs.com/NAG2020/p/13030602.html

一、Dataloader与Dataset

1. DataLoader

Epoch: 所有训练样本都已输入到模型中,称为一个Epoch

Iteration:一批样本输入到模型中,称之为一个Iteration

Batchsize:批大小,决定一个Epoch有多少个Iteration

样本总数:87, Batchsize:8

1 Epoch = 10 Iteration ? drop_last = True

1 Epoch = 11 Iteration ? drop_last = False

2. Dataset



人民币二分类为例:将RMB_data按照8:1:1分为train, test, valid三组,构建Dataset类

class RMBDataset(Dataset):
def __init__(self, data_dir, transform=None):
"""
rmb面额分类任务的Dataset
:param data_dir: str, 数据集所在路径
:param transform: torch.transform,数据预处理
"""
self.label_name = {"1": 0, "100": 1}
self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform = transform def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255 if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等 return img, label def __len__(self):
return len(self.data_info) @staticmethod
def get_img_info(data_dir):
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names)) # 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = rmb_label[sub_dir]
data_info.append((path_img, int(label))) return data_info

实例化Dataset与DataLoader:

# ============================ step 1/5 数据 ============================

split_dir = os.path.join("..", "..", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid") norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225] train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
]) valid_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
]) # 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform) # 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # 每一个epoch样本顺序不同
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

二、图像预处理——transform

torchvision.transforms : 常用的图像预处理方法

  • 数据中心化
  • 数据标准化
  • 缩放
  • 裁剪
  • 旋转
  • 翻转
  • 填充
  • 噪声添加
  • 灰度变换
  • 线性变换
  • 仿射变换
  • 亮度、饱和度及对比度变换

图像增强:丰富数据集,提高模型的泛化能力

# 对图像进行有序的组合与包装
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std), # 标准化
]) valid_transform = transforms.Compose([
transforms.Resize((32, 32)), # 验证时不需要进行图像增强
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])



进行标准化能加快模型的收敛!

三、transform图像增强(一)

数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力.

1. 裁剪



  • padding_mode:填充模式,有4种模式

    • 1、constant:像素值由fill设定
    • 2、edge:像素值由图像边缘像素决定
    • 3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2]
    • 4、symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3]
    • fill:constant时,设置填充的像素值

2. 翻转和旋转





center=(0,0) # 左上角旋转

expand仅针对center,没办法找回左上角旋转丢失的信息

四、transform图像增强(二)

1. 图像变换



padding_mode为镜像时,fill不起作用







直接在tensor上进行操作:

transforms.ToTensor()
transforms.RandomErasing(p=1, scale=(0.02,0.33), ratio=(0.5,1), value=(254/255,0,0))

其中value只要是字符串,随机色彩

2. transform操作

3. 自定义transform

class AddPepperNoise(object):
"""增加椒盐噪声
Args:
snr (float): Signal Noise Rate
p (float): 概率值,依概率执行该操作
""" def __init__(self, snr, p=0.9):
assert isinstance(snr, float) or (isinstance(p, float))
self.snr = snr
self.p = p def __call__(self, img):
"""
Args:
img (PIL Image): PIL Image
Returns:
PIL Image: PIL image.
"""
if random.uniform(0, 1) < self.p:
img_ = np.array(img).copy()
h, w, c = img_.shape
signal_pct = self.snr
noise_pct = (1 - self.snr)
mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[signal_pct, noise_pct/2., noise_pct/2.])
mask = np.repeat(mask, c, axis=2)
img_[mask == 1] = 255 # 盐噪声:白
img_[mask == 2] = 0 # 椒噪声:黑
return Image.fromarray(img_.astype('uint8')).convert('RGB')
else:
return img

4. 实战

原则:让训练集与测试集更接近

  • 空间位置:平移
  • 色彩:灰度图,色彩抖动
  • 形状:仿射变换
  • 上下文场景:遮挡,填充
  • ......





分析原图,主要是色相的问题,将输入图片直接转为灰度图

train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.RandomGrayscale(p=1),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])

源代码:

# -*- coding: utf-8 -*-
"""
# @file name : RMB_data_augmentation.py
# @author : tingsongyu
# @date : 2019-09-16 10:08:00
# @brief : 人民币分类模型数据增强实验
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert def set_seed(seed=1):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed) set_seed() # 设置随机种子
rmb_label = {"1": 0, "100": 1} # 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1 # ============================ step 1/5 数据 ============================ split_dir = os.path.join("..", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid") norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225] train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.RandomGrayscale(p=1),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
]) valid_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomGrayscale(p=1),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
]) # 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform) # 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE) # ============================ step 2/5 模型 ============================ net = LeNet(classes=2)
net.initialize_weights() # ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss() # 选择损失函数 # ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 设置学习率下降策略 # ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list() for epoch in range(MAX_EPOCH): loss_mean = 0.
correct = 0.
total = 0. net.train()
for i, data in enumerate(train_loader): # forward
inputs, labels = data
outputs = net(inputs) # backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward() # update weights
optimizer.step() # 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().sum().numpy() # 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0. scheduler.step() # 更新学习率 # validate the model
if (epoch+1) % val_interval == 0: correct_val = 0.
total_val = 0.
loss_val = 0.
net.eval()
with torch.no_grad():
for j, data in enumerate(valid_loader):
inputs, labels = data
outputs = net(inputs)
loss = criterion(outputs, labels) _, predicted = torch.max(outputs.data, 1)
total_val += labels.size(0)
correct_val += (predicted == labels).squeeze().sum().numpy() loss_val += loss.item() valid_curve.append(loss_val)
print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total)) train_x = range(len(train_curve))
train_y = train_curve train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid') plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show() # ============================ inference ============================ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data") test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1) for i, data in enumerate(valid_loader):
# forward
inputs, labels = data
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1) rmb = 1 if predicted.numpy()[0] == 0 else 100 img_tensor = inputs[0, ...] # C H W
img = transform_invert(img_tensor, train_transform)
plt.imshow(img)
plt.title("LeNet got {} Yuan".format(rmb))
plt.show()
plt.pause(0.5)
plt.close()

Pytorch_Part2_数据模块的更多相关文章

  1. 关于ofbiz加载数据模块的文件参数配置

    1,在applications文件夹下新建一个数据模块meetingroom 2, 要让ofbiz加载这个数据模块就需要在applications下的配置文件里修改参数 (1)在application ...

  2. Qt中数据模块学习

    QtSql模块 驱动类型和数据库:不同的数据库用不同的驱动连接(接口不同) QDB2->DB2 QOCI->orcle QODBC->SQLServer等 QSqlDataBase类 ...

  3. Python-读写csv数据模块 csv

    案例: 通过股票网站,我们获取了中国股市数据集,它以csv数据格式存储 Data,Open,High,Low,Close,Volume,Adj Close 2016-06-28,8.63,8.47,8 ...

  4. 构建通用的 React 和 Node 应用

    这是一篇非常优秀的 React 教程,这篇文章对 React 组件.React Router 以及 Node 做了很好的梳理.我是 9 月份读的该文章,当时跟着教程做了一遍,收获很大.但是由于时间原因 ...

  5. [django]数据导出excel升级强化版(很强大!)

    不多说了,原理采用xlwt导出excel文件,所谓的强化版指的是实现在网页上选择一定条件导出对应的数据 之前我的博文出过这类文章,但只是实现导出数据,这次左思右想,再加上网上的搜索,终于找出方法实现条 ...

  6. 【腾讯Bugly干货分享】React Native项目实战总结

    本文来自于腾讯bugly开发者社区,非经作者同意,请勿转载,原文地址:http://dev.qq.com/topic/577e16a7640ad7b4682c64a7 “8小时内拼工作,8小时外拼成长 ...

  7. [翻译]Orchard如何工作

    Orchard一直是博主心中神一般的存在,由于水平比较菜,Orchard代码又比较复杂看了几次都不了了之了.这次下定决心要搞懂其工作原理,争取可以在自己的项目中有所应用.为了入门先到官网去学习一下相关 ...

  8. 升讯威ADO.NET增强组件(源码):送给喜欢原生ADO.NET的你

    目前我们所接触到的许多项目开发,大多数都应用了 ORM 技术来实现与数据库的交互,ORM 虽然有诸多好处,但是在实际工作中,特别是在大型项目开发中,容易发现 ORM 存在一些缺点,在复杂场景下,反而容 ...

  9. JS模块化开发:使用SeaJs高效构建页面

    一.扯淡部分 很久很久以前,也就是刚开始接触前端的那会儿,脑袋里压根没有什么架构.重构.性能这些概念,天真地以为前端===好看的页面,甚至把js都划分到除了用来写一些美美的特效别无它用的阴暗角落里,就 ...

随机推荐

  1. 前瞻|Amundsen的数据血缘功能

    目前,Amundsen并不支持表级别和列级别的数据血缘功能,也没有办法展示数据的来龙去脉. 作为Amundsen一项非常核心的功能,Lineage功能早已经提上日程,并进入设计与研发阶段.本位将展示此 ...

  2. Android Studio中批量注释 Java代码

    •ctrl+/ 选中需要注释的多行代码,然后按 ctrl + / 实现多行快速注释: 再次按下 ctrl + / 取消注释. •ctrl+shift+/ 选中一行或几行代码,按 ctrl + shif ...

  3. Codeforces-121C(逆康托展开)

    题目大意: 给你两个数n,k求n的全排列的第k小,有多少满足如下条件的数: 首先定义一个幸运数字:只由4和7构成 对于排列p[i]满足i和p[i]都是幸运数字 思路: 对于n,k<=1e9 一眼 ...

  4. PaddleOCR详解

    @ 目录 PaddleOCR简介 环境配置 PaddleOCR2.0的配置环境 Docker 数据集 文本检测 使用自己的数据集 文本识别 使用自己的数据集 字典 自定义字典 添加空格类别 文本角度分 ...

  5. [BFS]骑士旅行

    骑士旅行 Description 在一个n m 格子的棋盘上,有一只国际象棋的骑士在棋盘的左下角 (1;1)(如图1),骑士只能根据象棋的规则进行移动,要么横向跳动一格纵向跳动两格,要么纵向跳动一格横 ...

  6. 安装mongoDB出现的问题:无法启动

    在我的电脑- 管理 - 服务-中会出现一个MongoDB Server的服务,你需要去手动删除这个服务删除指令: 在cmd管理员模式下使用: sc delete MongoDB Server 然后再配 ...

  7. 02-MySQL主要配置文件

    一.二进制日志log-bin 作用:主从复制 二.错误日志 log-err 默认关闭,记录严重的警告和错误信息,每次启动和关闭的详细信息 三.慢查询日志log 默认关闭,记录查询的sql语句,如果开启 ...

  8. Apache Hudi 0.8.0版本重磅发布

    1. 重点特性 1.1 Flink集成 自从Hudi 0.7.0版本支持Flink写入后,Hudi社区又进一步完善了Flink和Hudi的集成.包括重新设计性能更好.扩展性更好.基于Flink状态索引 ...

  9. mvn 报错 - The POM for <name> is invalid, transitive dependencies (if any) will not be available

    核心:  通过 mvn dependency:tree -X 分析依赖解决方案:  解决依赖冲突版本 1. MILGpController 编译突然报错 14:10:28 [ERROR] Failed ...

  10. js--原型和原型链相关问题

    前言 阅读本文前先来思考一个问题,我们在 js 中创建一个变量,我们并没有给这个变量添加一些方法,比如 toString() 方法,为什么我们可以直接使用这个方法呢?如以下代码,带着这样的问题,我们来 ...