(第三章)TF框架之实现验证码识别
这里实现一个用神经网络(卷积神经网络也可以)实现验证码识别的小案例,主要记录本人做这个案例的流程,不会像之前那么详细,主要用作个人记录用。。。
- 这里是验证码的四个字母,被one-hot编码后形成的四个一维数组,[1, 26] * 4 ----> 可以转变成[4, 26] ----> [1, 104]
第一个位置:[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0]
第二个位置:[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]
第三个位置:[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]
第四个位置:[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]
字母验证码识别设计:
这两个(真实值和预测值)104的一阶张量进行交叉熵损失计算,得出损失大小。会提高四个位置的概率,使得4组中每组26个目标值中为1的位置对应的预测概率值越来越大,在预测的四组当中概率值最大。这样得出预测中每组的字母位置。所有104个概率相加为1
流程设计
1、把图片的特征值和目标值,-----> 转换成tfrecords格式,方便数据特征值、目标值统一读取
[b'NZPP' b'WKHK' b'WPSJ' ..., b'FVQJ' b'BQYA' b'BCHR'] -----> [[13, 25, 15, 15], [22, 10, 7, 10], [22, 15, 18, 9], [16, 6, 13, 10]]
"ABCD……Z" —>"0, 1, …, 25"
2、训练验证码、准确率的计算
将原来的图片数据(特征)和csv数据(标签)------> 转变为tfrecords格式的数据,注意example协议(序列化后)
代码如下:
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("tfrecords_dir", "./tfrecords/captcha.tfrecords", "验证码tfrecords文件")
tf.app.flags.DEFINE_string("captcha_dir", "../data/Genpics/", "验证码图片路径")
tf.app.flags.DEFINE_string("letter", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "验证码字符的种类") def dealwithlabel(label_str): # 构建字符索引 {0:'A', 1:'B'......}
num_letter = dict(enumerate(list(FLAGS.letter))) # 键值对反转 {'A':0, 'B':1......}
letter_num = dict(zip(num_letter.values(), num_letter.keys())) print(letter_num) # 构建标签的列表
array = [] # 给标签数据进行处理[[b"NZPP"], ......]
for string in label_str: letter_list = []# [1,2,3,4] # 修改编码,b'FVQJ'到字符串,并且循环找到每张验证码的字符对应的数字标记
for letter in string.decode('utf-8'):
letter_list.append(letter_num[letter]) array.append(letter_list) # [[13, 25, 15, 15], [22, 10, 7, 10], [22, 15, 18, 9], [16, 6, 13, 10], [1, 0, 8, 17], [0, 9, 24, 14].....]
print(array) # 将array转换成tensor类型
label = tf.constant(array) return label def get_captcha_image():
"""
获取验证码图片数据
:param file_list: 路径+文件名列表
:return: image
"""
# 构造文件名
filename = [] for i in range(6000):
string = str(i) + ".jpg"
filename.append(string) # 构造路径+文件
file_list = [os.path.join(FLAGS.captcha_dir, file) for file in filename] # 构造文件队列
file_queue = tf.train.string_input_producer(file_list, shuffle=False) # 构造阅读器
reader = tf.WholeFileReader() # 读取图片数据内容
key, value = reader.read(file_queue) # 解码图片数据
image = tf.image.decode_jpeg(value) image.set_shape([20, 80, 3]) # 批处理数据 [6000, 20, 80, 3]
image_batch = tf.train.batch([image], batch_size=6000, num_threads=1, capacity=6000) return image_batch def get_captcha_label():
"""
读取验证码图片标签数据
:return: label
"""
file_queue = tf.train.string_input_producer(["../data/Genpics/labels.csv"], shuffle=False) reader = tf.TextLineReader() key, value = reader.read(file_queue) records = [[1], ["None"]] number, label = tf.decode_csv(value, record_defaults=records) # [["NZPP"], ["WKHK"], ["ASDY"]]
label_batch = tf.train.batch([label], batch_size=6000, num_threads=1, capacity=6000) return label_batch def write_to_tfrecords(image_batch, label_batch):
"""
将图片内容和标签写入到tfrecords文件当中
:param image_batch: 特征值
:param label_batch: 标签值
:return: None
"""
# 转换类型
label_batch = tf.cast(label_batch, tf.uint8) print(label_batch) # 建立TFRecords 存储器
writer = tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir) # 循环将每一个图片上的数据构造example协议块,序列化后写入
for i in range(6000):
# 取出第i个图片数据,转换相应类型,图片的特征值要转换成字符串形式
image_string = image_batch[i].eval().tostring() # 标签值,转换成整型
label_string = label_batch[i].eval().tostring() # 构造协议块
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])),
"label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_string]))
})) writer.write(example.SerializeToString()) # 关闭文件
writer.close() return None if __name__ == "__main__": # 获取验证码文件当中的图片
image_batch = get_captcha_image() # 获取验证码文件当中的标签数据
label = get_captcha_label() print(image_batch, label) with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) # [b'NZPP' b'WKHK' b'WPSJ' ..., b'FVQJ' b'BQYA' b'BCHR']
label_str = sess.run(label) print(label_str) # 处理字符串标签到数字张量
label_batch = dealwithlabel(label_str) print(label_batch) # 将图片数据和内容写入到tfrecords文件当中
write_to_tfrecords(image_batch, label_batch) coord.request_stop() coord.join(threads)
- 训练验证码,得到准确率的代码
import tensorflow as tf class CaptchaIdentification(object):
"""
验证码的读取数据、网络训练
"""
def __init__(self): # 验证码图片的属性
self.height = 20
self.width = 80
self.channel = 3
# 每个验证码的目标值个数(4个字符)
self.label_num = 4
self.feature_num = 26 # 每批次训练样本个数
self.train_batch = 100 @staticmethod
def weight_variables(shape):
w = tf.Variable(tf.random_normal(shape=shape, mean=0.0, stddev=0.1))
return w @staticmethod
def bias_variables(shape):
b = tf.Variable(tf.random_normal(shape=shape, mean=0.0, stddev=0.1))
return b def read_captcha_tfrecords(self):
"""
读取验证码特征值和目标值数据
:return:
"""
# 1、构造文件的队列
file_queue = tf.train.string_input_producer(["./tfrecords/captcha.tfrecords"]) # 2、tf.TFRecordReader 读取TFRecords数据
reader = tf.TFRecordReader() # 单个样本数据
key, value = reader.read(file_queue) # 3、解析example协议
feature = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.string)
}) # 4、解码操作、数据类型、形状
image = tf.decode_raw(feature["image"], tf.uint8)
label = tf.decode_raw(feature["label"], tf.uint8) # 确定类型和形状
# 图片形状 [20, 80, 3]
# 目标值 [4]
image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
label_reshape = tf.reshape(label, [self.label_num]) # 类型
image_type = tf.cast(image_reshape, tf.float32)
label_type = tf.cast(label_reshape, tf.int32) # 5、 批处理
# print(image_type, label_type)
# 提供每批次多少样本去进行训练
image_batch, label_batch = tf.train.batch([image_type, label_type],
batch_size=self.train_batch,
num_threads=1,
capacity=self.train_batch)
print(image_batch, label_batch)
return image_batch, label_batch def captcha_model(self, image_batch):
"""
建立全连接层网络
:param image_batch: 验证码图片特征值
:return: 预测结果
"""
# 全连接层
# [100, 20, 80, 3] --->[100, 20 * 80 * 3]
# [100, 20 * 80 * 3] * [20 * 80 * 3, 104] + [104] = [None, 104] 104 = 4*26
with tf.variable_scope("captcha_fc_model"):
# 初始化权重和偏置参数
self.weight = self.weight_variables([20 * 80 * 3, 104]) self.bias = self.bias_variables([104]) # 4维---->2维做矩阵运算
x_reshape = tf.reshape(image_batch, [self.train_batch, 20 * 80 * 3]) # [self.train_batch, 104]
y_predict = tf.matmul(x_reshape, self.weight) + self.bias return y_predict def loss(self, y_true, y_predict):
"""
建立验证码4个目标值的损失
:param y_true: 真实值
:param y_predict: 预测值
:return: loss
"""
with tf.variable_scope("loss"):
# 先进行网络输出的值的概率计算softmax,在进行交叉熵损失计算
# y_true:[100, 4, 26]------>[None, 104]
# y_predict:[100, 104]
y_reshape = tf.reshape(y_true,
[self.train_batch, self.label_num * self.feature_num]) all_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_reshape,
logits=y_predict,
name="compute_loss")
# 求出平均损失
loss = tf.reduce_mean(all_loss) return loss def turn_to_onehot(self, label_batch):
"""
目标值转换成one_hot编码
:param label_batch: 目标值 [None, 4]
:return:
"""
with tf.variable_scope("one_hot"): # [None, 4]--->[None, 4, 26]
y_true = tf.one_hot(label_batch,
depth=self.feature_num,
on_value=1.0)
return y_true def sgd(self, loss):
"""
梯度下降优化损失
:param loss:
:return: train_op
"""
with tf.variable_scope("sgd"): train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss) return train_op def acc(self, y_true, y_predict):
"""
计算准确率
:param y_true: 真实值
:param y_predict: 预测值
:return: accuracy
"""
with tf.variable_scope("acc"): # y_true:[None, 4, 26]
# y_predict:[None, 104]
y_predict_reshape = tf.reshape(y_predict, [self.train_batch, self.label_num, self.feature_num]) # 先对最大值的位置去求解 这里的2指的是维度
euqal_list = tf.equal(tf.argmax(y_true, 2), tf.argmax(y_predict_reshape, 2)) # 需要对每个样本进行判断 这里的1指的是维度
# x = tf.constant([[True, True], [False, False]])
# tf.reduce_all(x, 1) # [True, False]
accuracy = tf.reduce_mean(tf.cast(tf.reduce_all(euqal_list, 1), tf.float32)) return accuracy def train(self):
"""
模型训练逻辑
:return:
"""
# 1、通过接口获取特征值和目标值
# image_batch:[100, 20, 80, 3]
# label_batch: [100, 4]
# [[13, 25, 15, 15], [22, 10, 7, 10]]
image_batch, label_batch = self.read_captcha_tfrecords() # 2、建立验证码识别的模型
# 全连接层神经网络
# y_predict [100, 104]
y_predict = self.captcha_model(image_batch) # 转换label_batch 到one_hot编码
# y_true:[None, 4, 26]
y_true = self.turn_to_onehot(label_batch) # 3、利用真实值和目标值建立损失
loss = self.loss(y_true, y_predict) # 4、对损失进行梯度下降优化
train_op = self.sgd(loss) # 5、计算准确率
accuracy = self.acc(y_true, y_predict) # 会话训练
with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 生成线程的管理
coord = tf.train.Coordinator() # 指定开启子线程去读取数据
threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 循环训练打印结果
for i in range(1000): _, acc_run = sess.run([train_op, accuracy]) print("第 %d 次训练的准确率为:%f " % (i, acc_run)) # 回收线程
coord.request_stop() coord.join(threads) return None if __name__ == '__main__':
ci = CaptchaIdentification()
ci.train()
(第三章)TF框架之实现验证码识别的更多相关文章
- jQuery系列 第三章 jQuery框架操作CSS
第三章 jQuery框架操作CSS 3.1 jQuery框架的CSS方法 jQuery框架提供了css方法,我们通过调用该方法传递对应的参数,可以方便的来批量设置标签的CSS样式. 使用JavaScr ...
- NancyFX 第三章 Web框架
如果使用Nancy作为一个WEB框架而言,会有什么不同?实际上很多. 在使用Nancy框架为网页添加Rest节点和路由和之前的Rest框架中是相同的,这方面没有什么需要学习的了.Nancy采用一贯的处 ...
- 第三章、drf框架 - 序列化组件 | Serializer
目录 第三章.drf框架 - 序列化组件 | Serializer 序列化组件 知识点:Serializer(偏底层).ModelSerializer(重点).ListModelSerializer( ...
- ASP.NET Core 中文文档 第三章 原理(17)为你的服务器选择合适版本的.NET框架
原文:Choosing the Right .NET For You on the Server 作者:Daniel Roth 翻译:王健 校对:谢炀(Kiler).何镇汐.许登洋(Seay).孟帅洋 ...
- 《Django By Example》第三章 中文 翻译 (个人学习,渣翻)
书籍出处:https://www.packtpub.com/web-development/django-example 原作者:Antonio Melé (译者注:第三章滚烫出炉,大家请不要吐槽文中 ...
- 《Entity Framework 6 Recipes》中文翻译系列 (11) -----第三章 查询之异步查询
翻译的初衷以及为什么选择<Entity Framework 6 Recipes>来学习,请看本系列开篇 第三章 查询 前一章,我们展示了常见数据库场景的建模方式,本章将向你展示如何查询实体 ...
- 《Entity Framework 6 Recipes》中文翻译系列 (19) -----第三章 查询之使用位操作和多属性连接(join)
翻译的初衷以及为什么选择<Entity Framework 6 Recipes>来学习,请看本系列开篇 3-16 过滤中使用位操作 问题 你想在查询的过滤条件中使用位操作. 解决方案 假 ...
- 第19章 集合框架(3)-Map接口
第19章 集合框架(3)-Map接口 1.Map接口概述 Map是一种映射关系,那么什么是映射关系呢? 映射的数学解释 设A,B是两个非空集合,如果存在一个法则,使得对A中的每一个元素a,按法则f,在 ...
- 第18章 集合框架(2)-Set接口
第18章 集合框架(2)-Set接口 Set是Collection子接口,模拟了数学上的集的概念 Set集合存储特点 1.不允许元素重复 2.不会记录元素的先后添加顺序 Set只包含从Collecti ...
随机推荐
- Vue2和Vue3技术整理1 - 入门篇 - 更新完毕
Vue2 0.前言 首先说明:要直接上手简单得很,看官网熟悉大概有哪些东西.怎么用的,然后简单练一下就可以做出程序来了,最多两天,无论Vue2还是Vue3,就都完全可以了,Vue3就是比Vue2多了一 ...
- HTML 基础1
HTML 超文本标记语言 文件后缀html,htm 标签成对出现:开始标签--结束标签 元素内容位于开始标签--结束标签之间(可以有空内容) 空元素<a/> 大小写不敏感 元素,属性 &l ...
- Spring @SessionAttributes注解 @ModelAttribute注解
一.@SessionAttribute详解 如果多个请求之间需要共享数据,就可以使用@SessionAttribute. 配置的方法: 在控制器类上标注@SessionAttribute. 配置需要共 ...
- python小兵之时间模块
Python 日期和时间 Python 程序能用很多方式处理日期和时间,转换日期格式是一个常见的功能. Python 提供了一个 time 和 calendar 模块可以用于格式化日期和时间. 时间 ...
- Pandas 秘籍·翻译完成
协议:CC BY-NC-SA 4.0 欢迎任何人参与和完善:一个人可以走的很快,但是一群人却可以走的更远. 在线阅读 ApacheCN 面试求职交流群 724187166 ApacheCN 学习资源 ...
- Halcon视觉入门芯片识别
Halcon视觉入门芯片识别 需求 有如下图的一个摆盘,摆盘的方格中摆放芯片,一个格子中只放一个,我们需要知道每个方格中是否有芯片去指导我们将芯片放到空的方格中. 分析 通过图片分析得出 我们感兴趣的 ...
- (DDS)正弦波形发生器——幅值、频率、相位可调(二)
(DDS)正弦波形发生器--幅值.频率.相位可调(二) 主要关于调相方面 一.项目任务: 设计一个幅值.频率.相位均可调的正弦波发生器. 频率每次增加10kHz 相位每次增加 PI/2 幅值每次增加两 ...
- IAAS, SAAS, PAAS
原文是Pizza‐as‐a‐Service: a detailed view,用来类比Cloud Service Models.出处来自于Data Sovereignty and the Cloud ...
- Mac搭建Git服务器—开启SSH
SSH开启 在osx中开启ssh访问非常简单,只需要打开"系统偏好设置"并且点击"共享"图标即可. 选中下图中的check box即允许远程登陆.server处 ...
- MySQL事务以及存储引擎
MySQL事务以及存储引擎 目录 MySQL事务以及存储引擎 一.事务 1. 事务的概念 2. 事务的ACID特点 (1)原子性 (2)一致性 (3)隔离性 ①事务之间的相互影响 ②MySQL事务支持 ...