前言

上一节大概讲了一下LeNet的内容,这一章就直接来用,实际上用一下LeNet来进行训练和分类试试。

调用的数据集:

https://aistudio.baidu.com/datasetdetail/19065

说明:

如今近视已经成为困扰人们健康的一项全球性负担,在近视人群中,有超过35%的人患有重度近视。近视会拉长眼睛的光轴,也可能引起视网膜或者络网膜的病变。随着近视度数的不断加深,高度近视有可能引发病理性病变,这将会导致以下几种症状:视网膜或者络网膜发生退化、视盘区域萎缩、漆裂样纹损害、Fuchs斑等。因此,及早发现近视患者眼睛的病变并采取治疗,显得非常重要。

数据集实际上用了这么几个部分:



其中DATADIR目录下的图片都是以下形式:



其中照片的成分是按照文件的名称来进行区分的,当其命名规则为P开头的,则代表其为病理性近视,N开头的为正常眼睛,而H开头的则代表为高度近视。

我们将病理性患者的图片作为正样本,标签为1; 非病理性患者的图片作为负样本,标签为0。从数据集中选取两张图片,通过LeNet提取特征,构建分类器,对正负样本进行分类,并将图片显示出来。

流程

正式开始之前,我们还是要将任务按照流程划分

我们在这次开发中,不仅要训练模型,计算Loss,还需要用数据集对成果进行准确性分析。

  1. 定义数据读取器

  2. 定义LeNet模型

  3. 编写训练过程

  4. 定义评估过程

  5. 进行模型计算

实际开发

定义数据读取器

我们需要读取两部分数据,分别是训练集和评估集,这两个集是分开的

  1. 首先我们需要进行一个预处理部分:
# 对读入的图像数据进行预处理
def transform_img(img):
# 将图片尺寸缩放道 224x224
img = cv2.resize(img, (224, 224))
# 读入的图像数据格式是[H, W, C]
# 使用转置操作将其变成[C, H, W]
img = np.transpose(img, (2,0,1))
img = img.astype('float32')
# 将数据范围调整到[-1.0, 1.0]之间
img = img / 255.
img = img * 2.0 - 1.0
return img
  1. 定义一个训练集数据读取器

类似之前的,我们需要将训练集划分batch,还需要将其打乱进行。

至于Label,则是由名称决定的

分组读取完毕之后,

# 定义训练集数据读取器
def data_loader(datadir, batch_size=10, mode = 'train'):
# 将datadir目录下的文件列出来,每条文件都要读入
filenames = os.listdir(datadir)
def reader():
if mode == 'train':
# 训练时随机打乱数据顺序
random.shuffle(filenames)
batch_imgs = []
batch_labels = []
for name in filenames:
filepath = os.path.join(datadir, name)
img = cv2.imread(filepath)
img = transform_img(img)
if name[0] == 'H' or name[0] == 'N':
# H开头的文件名表示高度近似,N开头的文件名表示正常视力
# 高度近视和正常视力的样本,都不是病理性的,属于负样本,标签为0
label = 0
elif name[0] == 'P':
# P开头的是病理性近视,属于正样本,标签为1
label = 1
else:
raise('Not excepted file name')
# 每读取一个样本的数据,就将其放入数据列表中
batch_imgs.append(img)
batch_labels.append(label)
if len(batch_imgs) == batch_size:
# 当数据列表的长度等于batch_size的时候,
# 把这些数据当作一个mini-batch,并作为数据生成器的一个输出
imgs_array = np.array(batch_imgs).astype('float32')
labels_array = np.array(batch_labels).astype('float32').reshape(-1, 1)
yield imgs_array, labels_array
batch_imgs = []
batch_labels = [] if len(batch_imgs) > 0:
# 剩余样本数目不足一个batch_size的数据,一起打包成一个mini-batch
imgs_array = np.array(batch_imgs).astype('float32')
labels_array = np.array(batch_labels).astype('float32').reshape(-1, 1)
yield imgs_array, labels_array return reader

定义一个验证集数据读取器

训练集读取时通过文件名来确定样本标签,验证集则通过csvfile来读取每个图片对应的标签

请查看解压后的验证集标签数据,观察csvfile文件里面所包含的内容

需要注意的是,原先的文件不是csv,而是一个xlsx,不能直接改个后缀就直接用,而是需要使用office或者wps重新保存一下,而且需要注意的是,请使用UTF-8或者gbk格式打开,否则可能会导致无法正确读取文件。

csvfile文件所包含的内容格式如下,每一行代表一个样本,其中第一列是图片id,第二列是文件名,第三列是图片标签,第四列和第五列是Fovea的坐标,与分类任务无关

ID,imgName,Label,Fovea_X,Fovea_Y

1,V0001.jpg,0,1157.74,1019.87

2,V0002.jpg,1,1285.82,1080.47

打开包含验证集标签的csvfile,并读入其中的内容

# 定义验证集数据读取器
def valid_data_loader(datadir, csvfile, batch_size=10, mode='valid'):
# 训练集读取时通过文件名来确定样本标签,验证集则通过csvfile来读取每个图片对应的标签
# 请查看解压后的验证集标签数据,观察csvfile文件里面所包含的内容
# csvfile文件所包含的内容格式如下,每一行代表一个样本,
# 其中第一列是图片id,第二列是文件名,第三列是图片标签,
# 第四列和第五列是Fovea的坐标,与分类任务无关
# ID,imgName,Label,Fovea_X,Fovea_Y
# 1,V0001.jpg,0,1157.74,1019.87
# 2,V0002.jpg,1,1285.82,1080.47
# 打开包含验证集标签的csvfile,并读入其中的内容
filelists = open(csvfile).readlines()
def reader():
batch_imgs = []
batch_labels = []
for line in filelists[1:]:
line = line.strip().split(',')
name = line[1]
label = int(line[2])
# 根据图片文件名加载图片,并对图像数据作预处理
filepath = os.path.join(datadir, name)
img = cv2.imread(filepath)
img = transform_img(img)
# 每读取一个样本的数据,就将其放入数据列表中
batch_imgs.append(img)
batch_labels.append(label)
if len(batch_imgs) == batch_size:
# 当数据列表的长度等于batch_size的时候,
# 把这些数据当作一个mini-batch,并作为数据生成器的一个输出
imgs_array = np.array(batch_imgs).astype('float32')
labels_array = np.array(batch_labels).astype('float32').reshape(-1, 1)
yield imgs_array, labels_array
batch_imgs = []
batch_labels = [] if len(batch_imgs) > 0:
# 剩余样本数目不足一个batch_size的数据,一起打包成一个mini-batch
imgs_array = np.array(batch_imgs).astype('float32')
labels_array = np.array(batch_labels).astype('float32').reshape(-1, 1)
yield imgs_array, labels_array return reader

定义LeNet模型

这个地方其实上一节也说过了,就是LeNet是如何定义的,详情可以参考

简易机器学习笔记(八)关于经典的图像分类问题-常见经典神经网络LeNet

这里就不过多介绍了,简单放一下代码:

# -*- coding:utf-8 -*-

# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear, Dropout
import paddle.nn.functional as F # 定义 LeNet 网络结构
class LeNet(paddle.nn.Layer):
def __init__(self, num_classes=1):
super(LeNet, self).__init__()
self.num_classes = num_classes
# 创建卷积和池化层块,每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化
self.conv1 = Conv2D(in_channels=3, out_channels=6, kernel_size=5)
self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5)
self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
# 创建第3个卷积层
self.conv3 = Conv2D(in_channels=16, out_channels=120, kernel_size=4)
# 创建全连接层,第一个全连接层的输出神经元个数为64
self.fc1 = Linear(in_features=300000, out_features=64)
# 第二个全连接层输出神经元个数为分类标签的类别数
self.fc2 = Linear(in_features=64, out_features=num_classes) # 网络的前向计算过程
def forward(self, x, label=None):
x = self.conv1(x)
x = F.sigmoid(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.sigmoid(x)
x = self.max_pool2(x)
x = self.conv3(x)
x = F.sigmoid(x)
x = paddle.reshape(x, [x.shape[0], -1])
x = self.fc1(x)
x = F.sigmoid(x)
x = self.fc2(x)
if label is not None:
if self.num_classes == 1:
pred = F.sigmoid(x)
pred = paddle.concat([1.0 - pred, pred], axis=1)
acc = paddle.metric.accuracy(pred, paddle.cast(label, dtype='int64'))
else:
acc = paddle.metric.accuracy(x, paddle.cast(label, dtype='int64'))
return x, acc
else:
return x

编写训练过程

训练过程实际上和之前文章中提到的训练过程并无二至,实际上还是老一套:

  1. 读数据
  2. 前向计算预测
  3. 计算loss函数
  4. 反向传播
  5. 更新权重
  6. 清除梯度

当然了,这次的训练主要目的不是为了进行实际的工作,而是来进行模型准确度的测算,这也是我们在上面为什么读取数据集的时候除了基本的训练集,还添加了一个验证集。

验证集的验证工作其实比较简单,就是把model和验证集的参数传进去,然后让模型的预测和实际结果进行比较,计算出预测值和实际的label值的binary_cross_entropy_with_logits,再求出平均的损失值和准确度

代码如下:

# -*- coding: utf-8 -*-
# LeNet 识别眼疾图片
import os
import random
import paddle
import numpy as np DATADIR = '/home/aistudio/work/palm/PALM-Training400/PALM-Training400'
DATADIR2 = '/home/aistudio/work/palm/PALM-Validation400'
CSVFILE = '/home/aistudio/labels.csv' def train_pm(model, optimizer):
print('start training ... ')
model.train() #定义数据读取器,训练数据读取器和验证数据读取器
train_loader = data_loader(DATADIR, batch_size=10, mode='train')
valid_loader = valid_data_loader(DATADIR2, CSVFILE)
for epoch in range(EPOCH_NUM):
for batch_id,data in enumerate(train_loader()):
x_data,y_data = data
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
# 运行模型前向计算,得到预测值
logits = model(img)
loss = F.binary_cross_entropy_with_logits(logits, label)
avg_loss = paddle.mean(loss) if batch_id % 10 == 0:
print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy()))) #反向传播,更新权重,清除梯度
avg_loss.backward()
optimizer.step()
optimizer.clear_grad() model.eval()
accuracies = []
losses = [] for batch_id,data in enumerate(valid_loader()):
x_data, y_data = data
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
# 运行模型前向计算,得到预测值
logits = model(img)
# 二分类,sigmoid计算后的结果以0.5为阈值分两个类别
# 计算sigmoid后的预测概率,进行loss计算
pred = F.sigmoid(logits)
loss = F.binary_cross_entropy_with_logits(logits, label) # 计算预测概率小于0.5的类别
pred2 = pred * (-1.0) + 1.0 # 得到两个类别的预测概率,并沿第一个维度级联
pred = paddle.concat([pred2, pred], axis=1)
acc = paddle.metric.accuracy(pred, paddle.cast(label, dtype='int64')) accuracies.append(acc.numpy())
losses.append(loss.numpy())
print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))
model.train() paddle.save(model.state_dict(), 'palm.pdparams')
paddle.save(optimizer.state_dict(), 'palm.pdopt') #定义评估过程
def evaluation(model,params_file_path): print('start evaluation.....') #加载模型参数
model_state_dict = paddle.load(params_file_path)
model.load_dict(model_state_dict) model.eval()
eval_loader = data_loader(DATADIR,
batch_size=10, mode='eval')
acc_set = []
avg_loss_set = [] for batch_id, data in enumerate(eval_loader()):
x_data,y_data = data
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
y_data = y_data.astype(np.int64)
label_64 = paddle.to_tensor(y_data) # 计算预测和精度
prediction, acc = model(img, label_64) # 计算损失函数值
loss = F.binary_cross_entropy_with_logits(prediction, label)
avg_loss = paddle.mean(loss)
acc_set.append(float(acc.numpy()))
avg_loss_set.append(float(avg_loss.numpy()))
# 求平均精度
acc_val_mean = np.array(acc_set).mean()
avg_loss_val_mean = np.array(avg_loss_set).mean() print('loss={:.4f}, acc={:.4f}'.format(avg_loss_val_mean, acc_val_mean))

上述就是LeNet在实际验证中的总全部代码,稍微看懂整理一下即可。我们可以跑一下看看结果

结果

他奶奶的,本来就是个三分法的问题,算出来的准确度才0.5几,那不就和我瞎猜准确度高一点点.....

不过这也算是一个全流程的设计与开发,可以参考一下流程,

start training ...
epoch: 0, batch_id: 0, loss is: 0.8100
epoch: 0, batch_id: 10, loss is: 0.6131
epoch: 0, batch_id: 20, loss is: 0.7744
epoch: 0, batch_id: 30, loss is: 0.7073
[validation] accuracy/loss: 0.5275/0.6923
epoch: 1, batch_id: 0, loss is: 0.7042
epoch: 1, batch_id: 10, loss is: 0.6933
epoch: 1, batch_id: 20, loss is: 0.6831
epoch: 1, batch_id: 30, loss is: 0.6810
[validation] accuracy/loss: 0.5275/0.6920
epoch: 2, batch_id: 0, loss is: 0.7451
epoch: 2, batch_id: 10, loss is: 0.6951
epoch: 2, batch_id: 20, loss is: 0.7227
epoch: 2, batch_id: 30, loss is: 0.6579
[validation] accuracy/loss: 0.5275/0.6918
epoch: 3, batch_id: 0, loss is: 0.6808
epoch: 3, batch_id: 10, loss is: 0.6888
epoch: 3, batch_id: 20, loss is: 0.6944
epoch: 3, batch_id: 30, loss is: 0.6829
[validation] accuracy/loss: 0.5275/0.6917
epoch: 4, batch_id: 0, loss is: 0.6855
epoch: 4, batch_id: 10, loss is: 0.6458
epoch: 4, batch_id: 20, loss is: 0.7227
epoch: 4, batch_id: 30, loss is: 0.7857
[validation] accuracy/loss: 0.5275/0.6917
start evaluation.....
loss=0.6912, acc=0.5325

简易机器学习笔记(九)LeNet实例 - 在眼疾识别数据集iChallenge-PM上的应用的更多相关文章

  1. 吴恩达机器学习笔记61-应用实例:图片文字识别(Application Example: Photo OCR)【完结】

    最后一章内容,主要是OCR的实例,很多都是和经验或者实际应用有关:看完了,总之,善始善终,继续加油!! 一.图像识别(店名识别)的步骤: 图像文字识别应用所作的事是,从一张给定的图片中识别文字.这比从 ...

  2. 【原】Coursera—Andrew Ng机器学习—课程笔记 Lecture 18—Photo OCR 应用实例:图片文字识别

    Lecture 18—Photo OCR 应用实例:图片文字识别 18.1 问题描述和流程图 Problem Description and Pipeline 图像文字识别需要如下步骤: 1.文字侦测 ...

  3. Elasticsearch笔记九之优化

    Elasticsearch笔记九之优化 ).get(); } curl命令可以在linux中建立一个定时任务每天执行一次,同样java代码也可以建立一个定时器来执行. 2:内存设置之前介绍过es集群有 ...

  4. Python机器学习笔记:使用Keras进行回归预测

    Keras是一个深度学习库,包含高效的数字库Theano和TensorFlow.是一个高度模块化的神经网络库,支持CPU和GPU. 本文学习的目的是学习如何加载CSV文件并使其可供Keras使用,如何 ...

  5. Python机器学习笔记:sklearn库的学习

    网上有很多关于sklearn的学习教程,大部分都是简单的讲清楚某一方面,其实最好的教程就是官方文档. 官方文档地址:https://scikit-learn.org/stable/ (可是官方文档非常 ...

  6. Python机器学习笔记:不得不了解的机器学习面试知识点(1)

    机器学习岗位的面试中通常会对一些常见的机器学习算法和思想进行提问,在平时的学习过程中可能对算法的理论,注意点,区别会有一定的认识,但是这些知识可能不系统,在回答的时候未必能在短时间内答出自己的认识,因 ...

  7. 多线程学习笔记九之ThreadLocal

    目录 多线程学习笔记九之ThreadLocal 简介 类结构 源码分析 ThreadLocalMap set(T value) get() remove() 为什么ThreadLocalMap的键是W ...

  8. 【转】机器学习笔记之(3)——Logistic回归(逻辑斯蒂回归)

    原文链接:https://blog.csdn.net/gwplovekimi/article/details/80288964 本博文为逻辑斯特回归的学习笔记.由于仅仅是学习笔记,水平有限,还望广大读 ...

  9. Python机器学习笔记:K-Means算法,DBSCAN算法

    K-Means算法 K-Means 算法是无监督的聚类算法,它实现起来比较简单,聚类效果也不错,因此应用很广泛.K-Means 算法有大量的变体,本文就从最传统的K-Means算法学起,在其基础上学习 ...

  10. Python机器学习笔记:SVM(1)——SVM概述

    前言 整理SVM(support vector machine)的笔记是一个非常麻烦的事情,一方面这个东西本来就不好理解,要深入学习需要花费大量的时间和精力,另一方面我本身也是个初学者,整理起来难免思 ...

随机推荐

  1. .NET企业应用安全开发动向-概览

    太长不读版:试图从安全的全局视角触发,探讨安全的重要性,讨论如何识别安全问题的方法,介绍.NET提供的与安全相关的基础设施,以及一些与时俱进的安全问题,为读者建立体系化的安全思考框架. 引言 关于&q ...

  2. Spring整合Quartz简单入门

    创建一个Web项目 导入相关jar包 <?xml version="1.0" encoding="UTF-8"?> <project xmln ...

  3. django-celery-beat插件使用

    该插件从 Django 管理界面管理celery的定期任务,您可以在其中动态****创建.编辑和删除定期任务以及它们的运行频率. django-celery-beat提供了几种添加定时或周期性任务的方 ...

  4. Python——第二章:元组

    元组 tuple 使用小括号组成 特点: 元组是不可变的,固定了某些数据. t = ("张无忌", "赵敏", "呵呵哒") print(t ...

  5. 部署堡垒机4——CentOS7 编译安装 Python 3.8.12

    1.去python3的官方网站下载源代码 https://www.python.org/downloads/ 下载安装Python 3.8.12到/opt/python3 cd /opt wget h ...

  6. 云MSP技本功|redis的5种对象与8种数据结构之字符串对象(下)

    简介: 引言 本文是对<redis设计与实现(第二版)>中数据结构与对象相关内容的整理与说明.本篇文章只对对象结构,1种对象--字符串对象.以及字符串对象所对应的两种编码--raw和emb ...

  7. 云MSP服务案例|互联网商城的上云改造之旅

    简介: 在中国,经过十年的发展,云计算产业已走过概念普及的1.0时期,进入"上云"和落地的2. 0阶段,企业上云意识不断增强,越来越多的企业选择部署多云和混合IT. 如今,云计算生 ...

  8. 1.elasticsearch运行

    在docker中运行elasticsearch.kibana 一.MacOs 首先需要安装doceker,提供两种方式,选一种方便的就好 1.命令行安装方式 安装命令行 xcode-select -- ...

  9. 【csharp】抽象类与接口有哪些不同?什么时候应该使用抽象类?

    抽象类与接口有哪些不同? 抽象类和接口是在面向对象编程中两个不同的概念,它们有一些重要的区别.以下是抽象类和接口的主要不同点: 抽象类(Abstract Class): 成员类型: 抽象类可以包含抽象 ...

  10. Git 的底层原理

    前言 ​ 基于 Git 的使用,已经在前文有过相关的介绍,使用 Git 用作日常的开发基本上是足够的.现在,本文将详细介绍一些有关 Git 的实现原理. 底层命令与上层命令 ​ 一般情况下,正常使用的 ...